Hi! I have stumbled upon a pretty annoying issue when training a model with `tf.distribute.MirroredStrategy`.
Basically, after a certain number of steps in each epoch, all metrics go `nan`, even though the model is actually training fine under the hood.
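For context, here is a rough sketch of the kind of setup I mean; the model, data, and sizes below are just placeholders, not my actual code:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# One replica per visible GPU; metric variables become mirrored variables.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = keras.Sequential([
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(
        optimizer="adam",
        loss="mse",
        metrics=[keras.metrics.MeanAbsoluteError()],
    )

# Placeholder data just to drive the training loop.
x = np.random.rand(4096, 32).astype("float32")
y = np.random.rand(4096, 1).astype("float32")

# After enough steps the displayed metrics turn into nan,
# while the weights themselves keep improving.
model.fit(x, y, epochs=10, batch_size=64)
```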
After looking through and debugging the Keras code that implements metrics, I think I found the culprit.
`keras.metrics.Mean` has `total` and `count` mirrored variables that are reduced with SUM: `total` accumulates the state of the metric, and `count` is used to divide `total` and give back the correct average result. When running on multiple GPUs it seems that `total` can grow too quickly and overflow, which results in the `nan` metrics. `keras.metrics.Sum` might have the same issue.
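To illustrate the bookkeeping I am describing (here outside of any distribution strategy, just the two state variables):

```python
from tensorflow import keras

m = keras.metrics.Mean()
m.update_state([1.0, 2.0, 3.0])
m.update_state([5.0])

print(m.result())  # 2.75, i.e. total / count

# The two state variables mentioned above; under MirroredStrategy each one
# is mirrored per replica and reduced with SUM.
for v in m.variables:
    print(v.name, v.numpy())  # total -> 11.0, count -> 4.0
```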
If you guys at Keras have any idea of which path to follow to fix this, I would be happy to contribute if necessary.
P.S.: this could be related to some TensorFlow issues describing the same behaviour, like https://github.com/tensorflow/tensorflow/issues/90686
Comment From: dhantule
Hi @gianlucasama, thanks for reporting this.
Could you please provide some reproducible code?