LossScaleOptimizer crash under jit_compile=True with MirroredStrategy and mixed_precision
Summary
When training a model with all of the following:
- tf.keras.mixed_precision.set_global_policy("mixed_float16")
- tf.distribute.MirroredStrategy()
- model.compile(..., jit_compile=True)
Keras crashes with a RuntimeError about merge_call(), raised from a Cond op inside LossScaleOptimizer. This happens even with a trivial model and dummy data.
System Information
TensorFlow version: 2.19
CUDA/cuDNN: the versions matching the TF 2.19 build requirements
OS: Ubuntu 22.04 / WSL2 / Arch Linux
Python: 3.12 (the traceback below was captured under Python 3.11.12)
Minimal Reproducible Example
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import set_global_policy

# Set mixed precision policy
set_global_policy("mixed_float16")

# Create a distribution strategy
strategy = tf.distribute.MirroredStrategy()

# Build and compile the model under the strategy scope
with strategy.scope():
    model = keras.Sequential([
        layers.Input(shape=(32,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])
    optimizer = keras.optimizers.Adam()
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
        jit_compile=True,  # Enable XLA
    )

# Dummy data
x_train = np.random.randn(1024, 32).astype("float32")
y_train = np.random.randint(0, 10, size=(1024,))

# Train
model.fit(x_train, y_train, batch_size=64, epochs=5)
Error Output
Traceback (most recent call last):
File "/home/user/projects/local_chatbot/test.py", line 34, in <module>
model.fit(x_train, y_train, batch_size=64, epochs=5)
File "/home/user/.pyenv/versions/3.11.12/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/user/.pyenv/versions/3.11.12/lib/python3.11/site-packages/keras/src/backend/tensorflow/core.py", line 66, in _direct_assign
self._value.assign(tf.cast(value, self._value.dtype))
RuntimeError: Exception encountered when calling Cond.call().
merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn. If you are subclassing a tf.keras.Model, please avoid decorating overridden methods test_step and train_step in tf.function.
Arguments received by Cond.call():
• args=('tf.Tensor(shape=(), dtype=bool)', '<function LossScaleOptimizer._common_apply.<locals>.<lambda> at 0x72834df60d60>', '<bound method LossScaleOptimizer._stateful_handle_non_finite_grads of <keras.src.optimizers.loss_scale_optimizer.LossScaleOptimizer object at 0x72835bc67190>>')
• kwargs=<class 'inspect._empty'>
Expected Behavior
The model should train normally using XLA with jit_compile=True
, under MirroredStrategy
and mixed_precision
, just as it does when not using MirroredStrategy
.
Workarounds Tried
- Disabling jit_compile=True → works, but no XLA (slow)
- Disabling mixed_precision → works, but no float16
- Disabling MirroredStrategy → works
- Disabling LossScaleOptimizer in compile → works
- Custom training loop → slow / inconsistent behavior
- jit_compile=False + mixed_precision + MirroredStrategy → works, but no compiler benefits
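Conclusion
There seems to be an internal conflict between LossScaleOptimizer, XLA (jit_compile=True), and MirroredStrategy. Judging from the traceback, the Cond created in LossScaleOptimizer._common_apply (which branches to _stateful_handle_non_finite_grads) appears to contain a synchronization point (merge_call), which the error message says is not yet supported inside control flow.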
Please confirm whether:
- This combination is supposed to be supported
- There is a planned fix
- A documented workaround exists
Let me know if you need a Colab notebook or system diagnostics.
Thanks!