Metric state variables seem to be aggregated incorrectly across replicas in Keras 3. Here is a toy example with a metric that simply counts the number of inputs, evaluated under `tf.distribute.MirroredStrategy` with 4 replicas:

```python
import tensorflow as tf

import keras

# import tf_keras as keras  # uncomment to compare against tf-keras

n_replicas = 4

# Split the first physical GPU into `n_replicas` logical GPUs so that
# MirroredStrategy runs with 4 replicas on a single device.
gpus = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(
    gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * n_replicas
)


class CountInputs(keras.metrics.Metric):
    """Counts how many samples have been passed to `update_state`."""

    def __init__(self):
        super().__init__()
        self.var = self.add_weight(name="var", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Number of samples in this replica's shard of the batch.
        val = tf.shape(y_pred)[0]
        self.var.assign_add(val)

    def reset_state(self):
        self.var.assign(0)

    def result(self):
        return tf.cast(self.var, "float32")


batch_size = 12
x = tf.zeros((batch_size * 10, 1))
y = tf.zeros((batch_size * 10, 1))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    inp = keras.Input(shape=(1,))
    layer = keras.layers.Dense(10)
    model = keras.Model(inp, layer(inp))
    model.compile(loss="mse", optimizer="sgd", metrics=[CountInputs()])
    model.evaluate(x, y, batch_size=batch_size)
```

In tf-keras this produces the expected output (10 batches × 12 samples = 120):

```
10/10 [==============================] - 1s 5ms/step - loss: 0.0000e+00 - count_inputs: 120.0000
```

But in Keras 3 this produces:

```
10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - count_inputs: 889696.3750 - loss: 0.0000e+00
```
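For reference, the expected value of 120 can be reproduced without any framework. The sketch below is a hypothetical model, not Keras or `tf.distribute` code, and the helper names `shard_sizes` and `local_counters` are made up for illustration. It assumes each global batch of 12 is split evenly across the 4 replicas, that each replica counts only its own shard, and that reading the metric sums the per-replica counters (tf-keras metric weights default to sync-on-read with SUM aggregation).

```python
def shard_sizes(global_batch, n_replicas):
    """Split one global batch across replicas as evenly as possible (assumed split)."""
    base, extra = divmod(global_batch, n_replicas)
    return [base + (1 if i < extra else 0) for i in range(n_replicas)]


def local_counters(n_samples, global_batch, n_replicas):
    """Per-replica sample counts after one pass over the data (hypothetical helper)."""
    counters = [0] * n_replicas
    remaining = n_samples
    while remaining > 0:
        batch = min(global_batch, remaining)  # the last batch may be partial
        for i, size in enumerate(shard_sizes(batch, n_replicas)):
            counters[i] += size  # each replica only sees its own shard
        remaining -= batch
    return counters


counters = local_counters(n_samples=120, global_batch=12, n_replicas=4)
print(counters)       # [30, 30, 30, 30]
print(sum(counters))  # 120, i.e. the value tf-keras reports
```

Each replica sees 3 samples per step for 10 steps, so summing the four local counters on read gives 120.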

Comment From: sachinprasadhs

Hi, could you please try Keras 3 with TensorFlow as the backend, but use the multi-device distribution API available here https://keras.io/api/distribution/ instead of MirroredStrategy, and let us know what you observe? Thanks

Comment From: hertschuh

Investigation idea: pass `aggregation="only_first_replica"` to `add_weight`, which is what is typically used for counters.
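To see why `only_first_replica` suits counters, here is a toy model in plain Python of a mirrored variable that is updated on write: in each step every replica proposes an increment, the increments are combined according to the aggregation mode, and the combined delta is applied once to the shared value. This is an assumption-laden sketch with hypothetical helper names (`combine`, `run_counter`), not how Keras or `tf.distribute` is actually implemented.

```python
def combine(deltas, aggregation):
    """Combine per-replica increments into one delta (toy semantics)."""
    if aggregation == "sum":
        return sum(deltas)
    if aggregation == "mean":
        return sum(deltas) / len(deltas)
    if aggregation == "only_first_replica":
        return deltas[0]  # ignore every replica except the first
    raise ValueError(f"unknown aggregation: {aggregation!r}")


def run_counter(per_step_deltas, aggregation):
    """Apply one combined delta per step to a single shared value."""
    value = 0
    for deltas in per_step_deltas:
        value += combine(deltas, aggregation)
    return value


# Per-sample counter as in the repro: 4 replicas, 3 samples each, 10 steps.
per_sample = [[3, 3, 3, 3]] * 10
# Per-step counter (e.g. an iteration count): every replica adds 1 per step.
per_step = [[1, 1, 1, 1]] * 10

for agg in ("sum", "mean", "only_first_replica"):
    print(agg, run_counter(per_sample, agg), run_counter(per_step, agg))
# sum 120 40
# mean 30.0 10.0
# only_first_replica 30 10
```

In this model, `only_first_replica` gives the true step count (10) for a counter whose increment is identical on every replica, while `sum` overcounts it (40). For the per-sample counter in this issue, `sum` gives 120 and the other two modes give 30. Whether Keras 3 applies exactly these semantics to metric variables under MirroredStrategy is an assumption that would need to be checked.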

Comment From: github-actions[bot]

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.