Something about how Metric state variables are aggregated across replicas is behaving oddly. Here is a toy example that simply counts the number of inputs:
import tensorflow as tf
import keras
# import tf_keras as keras
n_replicas = 4
gpus = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(
    gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * n_replicas
)
class CountInputs(keras.metrics.Metric):
    def __init__(self):
        super().__init__()
        self.var = self.add_weight(name="var", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        val = tf.shape(y_pred)[0]
        self.var.assign_add(val)

    def reset_state(self):
        self.var.assign(0)

    def result(self):
        return tf.cast(self.var, "float32")
batch_size = 12
x = tf.zeros((batch_size * 10, 1))
y = tf.zeros((batch_size * 10, 1))
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    inp = keras.Input(shape=(1,))
    layer = keras.layers.Dense(10)
    model = keras.Model(inp, layer(inp))
    model.compile(loss="mse", optimizer="sgd", metrics=[CountInputs()])
model.evaluate(x, y, batch_size=batch_size)
In tf-keras this produces the expected output:
10/10 [==============================] - 1s 5ms/step - loss: 0.0000e+00 - count_inputs: 120.0000
But in Keras 3 this produces:
10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - count_inputs: 889696.3750 - loss: 0.0000e+00
Comment From: sachinprasadhs
Hi, could you please try Keras 3 with TensorFlow as the backend and, instead of MirroredStrategy, use the multi-device distribution API available at https://keras.io/api/distribution/? Please let us know what you observe. Thanks.
Comment From: hertschuh
Investigation ideas: add aggregation="only_first_replica" in the add_weight call, which is typically what is used for counters.
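To make the two aggregation modes concrete, here is a toy model in plain Python. It does not use TensorFlow or Keras: `simulate_count` is a hypothetical helper, and it assumes the global batch is split evenly across replicas and that each replica issues one update per step. Neither assumption is verified in this thread, so treat it as a sketch of the arithmetic rather than a description of what the library does.

```python
# Toy model only: plain Python, no TensorFlow. It mimics how per-replica
# counter updates could be combined under two aggregation modes.

def simulate_count(n_steps, batch_size, n_replicas, aggregation):
    """Return the final counter value after `n_steps` evaluation steps."""
    # Assumption: every replica receives an equal slice of the global batch.
    per_replica = batch_size // n_replicas
    total = 0
    for _ in range(n_steps):
        updates = [per_replica] * n_replicas  # one update per replica
        if aggregation == "sum":
            total += sum(updates)  # updates from all replicas are added
        elif aggregation == "only_first_replica":
            total += updates[0]  # only the first replica's update is kept
        else:
            raise ValueError(f"unknown aggregation: {aggregation!r}")
    return total

# Same numbers as the report: 10 steps, batch size 12, 4 replicas.
summed = simulate_count(10, 12, 4, "sum")
first_only = simulate_count(10, 12, 4, "only_first_replica")
print(summed, first_only)
```

Under this simplified model, "sum" gives 120, matching the tf-keras output, while "only_first_replica" gives 30, one replica's share. Neither matches the 889696.3750 reported for Keras 3, so the model only helps frame what to compare when trying the suggested change.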