LossScaleOptimizer crash under jit_compile=True with MirroredStrategy and mixed_precision
Summary
When training a model using:
- tf.keras.mixed_precision.set_global_policy("mixed_float16")
- tf.distribute.MirroredStrategy()
- model.compile(..., jit_compile=True)
Keras crashes with a RuntimeError related to merge_call() and LossScaleOptimizer.
This happens even with a trivial model and dummy data.
System Information
TensorFlow version: 2.19
CUDA/cuDNN: appropriate version for TF
OS: Ubuntu 22.04 / WSL2 / Arch Linux
Python: 3.12
Minimal Reproducible Example
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import set_global_policy
# Set mixed precision policy
set_global_policy("mixed_float16")
# Create a distribution strategy
strategy = tf.distribute.MirroredStrategy()
# Build model under strategy
with strategy.scope():
    model = keras.Sequential([
        layers.Input(shape=(32,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])
    optimizer = keras.optimizers.Adam()
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
        jit_compile=True,  # Enable XLA
    )
# Dummy data
import numpy as np
x_train = np.random.randn(1024, 32).astype("float32")
y_train = np.random.randint(0, 10, size=(1024,))
# Train
model.fit(x_train, y_train, batch_size=64, epochs=5)
Error Output
Traceback (most recent call last):
File "/home/user/projects/local_chatbot/test.py", line 34, in <module>
model.fit(x_train, y_train, batch_size=64, epochs=5)
File "/home/user/.pyenv/versions/3.11.12/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/user/.pyenv/versions/3.11.12/lib/python3.11/site-packages/keras/src/backend/tensorflow/core.py", line 66, in _direct_assign
self._value.assign(tf.cast(value, self._value.dtype))
RuntimeError: Exception encountered when calling Cond.call().
merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn. If you are subclassing a tf.keras.Model, please avoid decorating overridden methods test_step and train_step in tf.function.
Arguments received by Cond.call():
• args=('tf.Tensor(shape=(), dtype=bool)', '<function LossScaleOptimizer._common_apply.<locals>.<lambda> at 0x72834df60d60>', '<bound method LossScaleOptimizer._stateful_handle_non_finite_grads of <keras.src.optimizers.loss_scale_optimizer.LossScaleOptimizer object at 0x72835bc67190>>')
• kwargs=<class 'inspect._empty'>
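For reference, the error message above suggests wrapping the entire strategy.run call in a single tf.function rather than nesting tf.functions around the gradient-aggregation point. Below is a minimal sketch of that pattern as a custom training loop (the "custom training loop" workaround listed later). It assumes the legacy tf.keras / tf_keras (Keras 2) LossScaleOptimizer API (get_scaled_loss / get_unscaled_gradients); the Keras 3 LossScaleOptimizer exposes a different interface, so treat this as illustrative rather than a drop-in fix.

import numpy as np
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(32,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    # Wrap the optimizer in LossScaleOptimizer explicitly instead of
    # relying on compile() to do it.
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam()
    )

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

def step_fn(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        per_example_loss = loss_fn(y, logits)
        loss = tf.nn.compute_average_loss(per_example_loss)
        # Scale the loss so small float16 gradients do not underflow.
        scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
    grads = optimizer.get_unscaled_gradients(scaled_grads)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# One outer tf.function around the whole strategy.run call, with no
# nested tf.function around the gradient-aggregation point.
@tf.function
def train_step(x, y):
    per_replica_loss = strategy.run(step_fn, args=(x, y))
    # Per-replica losses are already divided by the global batch size,
    # so summing across replicas gives the global average loss.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

x_train = np.random.randn(1024, 32).astype("float32")
y_train = np.random.randint(0, 10, size=(1024,))
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(64)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for x, y in dist_dataset:
    train_step(x, y)

Note that jit_compile is deliberately left off the outer tf.function here; whether XLA can then be layered on top of this pattern is exactly the open question in this issue.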
Expected Behavior
The model should train normally with XLA (jit_compile=True) under MirroredStrategy and mixed precision, just as it does when MirroredStrategy is not used.
Workarounds Tried
- Disabling jit_compile=True → works, but no XLA, slow
- Disabling mixed_precision → works, but no float16
- Disabling MirroredStrategy → works
- Disabling LossScaleOptimizer in compile → works (see the sketch after this list)
- Custom training loop → slow / inconsistent behavior
- jit_compile=False + mixed_precision + MirroredStrategy → works, but no compiler benefits
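For the "Disabling LossScaleOptimizer in compile" workaround, a minimal sketch is below, reusing the strategy and model names from the reproducer above. It assumes Keras 3's auto_scale_loss argument to Model.compile(), which skips the automatic LossScaleOptimizer wrapping under mixed_float16; without dynamic loss scaling, small float16 gradients may underflow, so this trades numerical safety for avoiding the crash.

# Sketch of the "disable LossScaleOptimizer" workaround, assuming
# Keras 3's auto_scale_loss flag on Model.compile(); reuses the
# strategy and model built in the reproducer above.
with strategy.scope():
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
        jit_compile=True,
        auto_scale_loss=False,  # keep the plain Adam optimizer, no loss scaling
    )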
Conclusion
There appears to be an internal conflict between LossScaleOptimizer, XLA (jit_compile=True), and MirroredStrategy.
Please confirm whether:
- This combination is supposed to be supported
- There is a planned fix
- A documented workaround exists
Let me know if you need a Colab notebook or system diagnostics.
Thanks!
Comment From: amitsrivastava78
Hi, I have tested this on tf-keras (Keras 2.0) and your test works fine. Sharing my notebook for the same; kindly check https://colab.sandbox.google.com/gist/amitsrivastava78/e0d05bf81932f98421fbb29363bb814a/issue21366fixed-ipnb.ipynb
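For anyone wanting to reproduce that comparison locally, a sketch of pointing the original reproducer at legacy Keras 2 (tf-keras) follows. It assumes TF 2.16+ behaviour: the tf_keras package must be installed (pip install tf_keras) and the environment variable must be set before TensorFlow is imported.

import os
# Make tf.keras resolve to the legacy tf_keras (Keras 2) package;
# this must run before "import tensorflow".
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf
print(tf.keras.__version__)  # expect a 2.x version, then run the reproducer unchanged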
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Comment From: github-actions[bot]
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.