Metric state variables seem to be aggregated incorrectly across replicas. Here is a toy example that simply counts the number of inputs:
import tensorflow as tf
import keras
# import tf_keras as keras
# Split the first physical GPU into `n_replicas` logical GPUs (each with
# memory_limit=1000) so a single-GPU machine exposes several devices to
# replicate across. Per the TF docs, this must run before the GPU runtime
# is initialized.
n_replicas = 4
gpus = tf.config.list_physical_devices("GPU")
if not gpus:
    # Fail with a clear message instead of a bare IndexError on gpus[0].
    raise RuntimeError("This reproduction requires at least one physical GPU.")
tf.config.set_logical_device_configuration(
    gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * n_replicas
)
class CountInputs(keras.metrics.Metric):
    """Metric that counts how many samples it has been updated with.

    The count lives in a single scalar int32 state variable and is returned
    as float32 by `result()`.
    """

    def __init__(self, **kwargs):
        # Forward standard Metric keyword arguments (e.g. `name`) so the
        # metric can be configured like a built-in one.
        super().__init__(**kwargs)
        self.var = self.add_weight(name="var", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the size of `y_pred`'s leading (batch) axis to the count.

        `y_true` and `sample_weight` are accepted for API compatibility and
        ignored.
        """
        val = tf.shape(y_pred)[0]
        self.var.assign_add(val)

    def reset_state(self):
        """Reset the running count to zero."""
        self.var.assign(0)

    def result(self):
        """Return the running count as a float32 tensor."""
        return tf.cast(self.var, "float32")
# 120 all-zero samples evaluated in global batches of 12. The expected
# final `count_inputs` equals the number of samples, i.e. 120.
batch_size = 12
x = tf.zeros((batch_size * 10, 1))
y = tf.zeros((batch_size * 10, 1))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Model, optimizer and metric variables are created under the strategy
    # scope so they are distributed across replicas.
    inp = keras.Input(shape=(1,))
    layer = keras.layers.Dense(10)
    model = keras.Model(inp, layer(inp))
    model.compile(loss="mse", optimizer="sgd", metrics=[CountInputs()])

# Variables were created under the scope above; evaluate() itself can run
# outside it.
model.evaluate(x, y, batch_size=batch_size)
In tf-keras this produces the expected output:
10/10 [==============================] - 1s 5ms/step - loss: 0.0000e+00 - count_inputs: 120.0000
But in Keras 3 this produces:
10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - count_inputs: 889696.3750 - loss: 0.0000e+00
Comment From: sachinprasadhs
Hi, could you please try Keras 3 with the TensorFlow backend, using the multi-device distribution API (https://keras.io/api/distribution/) instead of MirroredStrategy, and let us know what you observe? Thanks.
Comment From: hertschuh
Investigation idea: pass `aggregation="only_first_replica"` to the `add_weight` call, which is what is typically used for counters.
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Comment From: amitsrivastava78
I tested with TensorFlow 2.19.0 and Keras 3.10.0, and the test produces consistent results with both tf-keras and Keras 3:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import tensorflow as tf
# Report the library versions the reproduction was run with.
print(tf.__version__)
print(keras.__version__)

# Split the first physical GPU into `n_replicas` logical GPUs (each with
# memory_limit=1000) so a single-GPU machine exposes several devices to
# replicate across. Per the TF docs, this must run before the GPU runtime
# is initialized.
n_replicas = 4
gpus = tf.config.list_physical_devices("GPU")
if not gpus:
    # Fail with a clear message instead of a bare IndexError on gpus[0].
    raise RuntimeError("This reproduction requires at least one physical GPU.")
tf.config.set_logical_device_configuration(
    gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * n_replicas
)
class CountInputs(keras.metrics.Metric):
    """Metric that counts how many samples it has been updated with.

    The count lives in a single scalar int32 state variable and is returned
    as float32 by `result()`. Keyword arguments are forwarded to the
    `Metric` base class.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.var = self.add_weight(name="var", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the size of `y_pred`'s leading (batch) axis to the count.

        `y_true` and `sample_weight` are accepted for API compatibility and
        ignored.
        """
        # Use keras.ops.shape, which works across all backends; the cast
        # makes the increment match the int32 state variable.
        batch_size = keras.ops.shape(y_pred)[0]
        self.var.assign_add(keras.ops.cast(batch_size, "int32"))

    def reset_state(self):
        """Reset the running count to zero."""
        self.var.assign(0)

    def result(self):
        """Return the running count as a float32 tensor."""
        return keras.ops.cast(self.var, "float32")
# 120 all-zero samples evaluated in global batches of 12. The expected
# final `count_inputs` equals the number of samples, i.e. 120.
batch_size = 12
x = tf.zeros((batch_size * 10, 1))
y = tf.zeros((batch_size * 10, 1))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Model, optimizer and metric variables are created under the strategy
    # scope so they are distributed across replicas.
    inp = keras.Input(shape=(1,))
    layer = keras.layers.Dense(10)
    model = keras.Model(inp, layer(inp))
    model.compile(loss="mse", optimizer="sgd", metrics=[CountInputs()])

# Variables were created under the scope above; evaluate() itself can run
# outside it.
model.evaluate(x, y, batch_size=batch_size)
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Comment From: github-actions[bot]
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.