Metric state variables appear to be aggregated incorrectly across replicas when running under tf.distribute.MirroredStrategy with Keras 3. Here is a toy example that simply counts the number of inputs:

import tensorflow as tf

import keras

# import tf_keras as keras

n_replicas = 4

gpus = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(
    gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * n_replicas
)


class CountInputs(keras.metrics.Metric):
    def __init__(self):
        super().__init__()
        self.var = self.add_weight(name="var", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        val = tf.shape(y_pred)[0]
        self.var.assign_add(val)

    def reset_state(self):
        self.var.assign(0)

    def result(self):
        return tf.cast(self.var, "float32")


batch_size = 12
x = tf.zeros((batch_size * 10, 1))
y = tf.zeros((batch_size * 10, 1))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    inp = keras.Input(shape=(1,))
    layer = keras.layers.Dense(10)
    model = keras.Model(inp, layer(inp))
    model.compile(loss="mse", optimizer="sgd", metrics=[CountInputs()])
    model.evaluate(x, y, batch_size=batch_size)

In tf-keras this produces the expected output:

10/10 [==============================] - 1s 5ms/step - loss: 0.0000e+00 - count_inputs: 120.0000

But in Keras 3 this produces:

10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - count_inputs: 889696.3750 - loss: 0.0000e+00

Comment From: sachinprasadhs

Hi, could you please try Keras 3 with TensorFlow as the backend and, instead of MirroredStrategy, use the multi-device distribution API available here: https://keras.io/api/distribution/. Let us know your observation. Thanks!
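
For reference, a minimal sketch of what the suggested keras.distribution setup might look like (names taken from the linked docs page; note that, at the time of writing, this API is primarily implemented for the JAX backend, so the backend choice here is an assumption, not something confirmed in this thread):

import keras

# Discover the available accelerator devices and shard the data across them.
devices = keras.distribution.list_devices()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Set this as the global distribution before building and compiling the model.
keras.distribution.set_distribution(data_parallel)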

Comment From: hertschuh

Investigation idea: add aggregation="only_first_replica" to the add_weight call, which is typically what is used for counters.
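
A sketch of what that change would look like in the reproducer's __init__, assuming the Keras version in use accepts and forwards an aggregation argument from Metric.add_weight (if not, the same argument can be tried on the underlying variable-creation method):

class CountInputs(keras.metrics.Metric):
    def __init__(self):
        super().__init__()
        # "only_first_replica" reads the counter from a single replica instead
        # of summing or averaging the per-replica copies.
        self.var = self.add_weight(
            name="var",
            initializer="zeros",
            dtype="int32",
            aggregation="only_first_replica",
        )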

Comment From: github-actions[bot]

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

Comment From: amitsrivastava78

I tested with TF version 2.19.0 and Keras version 3.10.0; the test produces consistent results on tf-keras and Keras 3.

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf

print(tf.__version__)
print(keras.__version__)

n_replicas = 4

gpus = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(
    gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * n_replicas
)

class CountInputs(keras.metrics.Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.var = self.add_weight(name="var", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Use keras.ops.shape which works across all backends
        batch_size = keras.ops.shape(y_pred)[0]
        self.var.assign_add(keras.ops.cast(batch_size, "int32"))

    def reset_state(self):
        self.var.assign(0)

    def result(self):
        return keras.ops.cast(self.var, "float32")


batch_size = 12
x = tf.zeros((batch_size * 10, 1))
y = tf.zeros((batch_size * 10, 1))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    inp = keras.Input(shape=(1,))
    layer = keras.layers.Dense(10)
    model = keras.Model(inp, layer(inp))
    model.compile(loss="mse", optimizer="sgd", metrics=[CountInputs()])

model.evaluate(x, y, batch_size=batch_size)

Comment From: github-actions[bot]

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

Comment From: github-actions[bot]

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.