LossScaleOptimizer crash under jit_compile=True with MirroredStrategy and mixed_precision

Summary

When training a model using:

  • tf.keras.mixed_precision.set_global_policy("mixed_float16")
  • tf.distribute.MirroredStrategy()
  • model.compile(..., jit_compile=True)

Keras crashes with a RuntimeError related to merge_call() and LossScaleOptimizer.

This happens even with a trivial model and dummy data.


System Information

TensorFlow version: 2.19
CUDA/cuDNN: appropriate versions for the TF 2.19 build (exact versions can be printed with the diagnostic snippet below)
OS: Ubuntu 22.04 / WSL2 / Arch Linux
Python: 3.12 (the traceback below was captured under 3.11.12)
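
To pin down the exact versions on each machine, a quick diagnostic like the sketch below can be used; it is only a suggestion for collecting the info, not part of the reproduction (tf.sysconfig.get_build_info() reports the CUDA/cuDNN versions the installed TF binary was built against):

import sys
import tensorflow as tf
import keras

# Interpreter and framework versions
print("Python:", sys.version)
print("TensorFlow:", tf.__version__)
print("Keras:", keras.__version__)

# CUDA/cuDNN versions the installed TF binary was built against
build_info = tf.sysconfig.get_build_info()
print("CUDA:", build_info.get("cuda_version"))
print("cuDNN:", build_info.get("cudnn_version"))

# GPUs visible to TensorFlow
print("GPUs:", tf.config.list_physical_devices("GPU"))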

Minimal Reproducible Example

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import set_global_policy

# Set mixed precision policy
set_global_policy("mixed_float16")

# Create a distribution strategy
strategy = tf.distribute.MirroredStrategy()

# Build model under strategy
with strategy.scope():
    model = keras.Sequential([
        layers.Input(shape=(32,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])
    optimizer = keras.optimizers.Adam()
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
        jit_compile=True,  # Enable XLA
    )

# Dummy data
x_train = np.random.randn(1024, 32).astype("float32")
y_train = np.random.randint(0, 10, size=(1024,))

# Train
model.fit(x_train, y_train, batch_size=64, epochs=5)

Error Output

Traceback (most recent call last):
  File "/home/user/projects/local_chatbot/test.py", line 34, in <module>
    model.fit(x_train, y_train, batch_size=64, epochs=5)
  File "/home/user/.pyenv/versions/3.11.12/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/user/.pyenv/versions/3.11.12/lib/python3.11/site-packages/keras/src/backend/tensorflow/core.py", line 66, in _direct_assign
    self._value.assign(tf.cast(value, self._value.dtype))
RuntimeError: Exception encountered when calling Cond.call().

merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn. If you are subclassing a tf.keras.Model, please avoid decorating overridden methods test_step and train_step in tf.function.

Arguments received by Cond.call():
  • args=('tf.Tensor(shape=(), dtype=bool)', '<function LossScaleOptimizer._common_apply.<locals>.<lambda> at 0x72834df60d60>', '<bound method LossScaleOptimizer._stateful_handle_non_finite_grads of <keras.src.optimizers.loss_scale_optimizer.LossScaleOptimizer object at 0x72835bc67190>>')
  • kwargs=<class 'inspect._empty'>

Expected Behavior

The model should train normally with XLA enabled (jit_compile=True) under MirroredStrategy and mixed precision, just as it does when MirroredStrategy is not used.


Workarounds Tried

  • Disabling jit_compile=True → works, but no XLA, slow
  • Disabling mixed_precision → works, but no float16
  • Disabling MirroredStrategy → works
  • Disabling LossScaleOptimizer in compile → works (see the sketch after this list)
  • Custom training loop → slow / inconsistent behavior
  • jit_compile=False + mixed_precision + MirroredStrategy → works, but no compiler benefits
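
For clarity on the "disabling LossScaleOptimizer in compile" item above: what I mean is turning off automatic loss scaling when compiling. Below is a minimal sketch, assuming the auto_scale_loss flag that Keras 3 (bundled with TF 2.19) exposes on Model.compile(). Note that without loss scaling, small float16 gradients can underflow, so this only sidesteps the crash rather than replacing loss scaling:

# Same model as the MRE, but with automatic loss scaling disabled.
with strategy.scope():
    model = keras.Sequential([
        layers.Input(shape=(32,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
        jit_compile=True,        # XLA stays enabled
        auto_scale_loss=False,   # do not wrap the optimizer in LossScaleOptimizer
    )

model.fit(x_train, y_train, batch_size=64, epochs=5)  # trains without the merge_call error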

Conclusion

There seems to be an internal conflict between LossScaleOptimizer, XLA (jit_compile=True), and MirroredStrategy.

Please confirm if:

  • This combination is supposed to be supported
  • There is a planned fix
  • Or if a documented workaround exists

Let me know if you need a Colab notebook or system diagnostics.

Thanks!