Using the JAX backend for LSTM / GRU models, I'm unable to see any speed-up when training on 2 Nvidia 3090s versus a single Nvidia 3090 (using keras-nightly and JAX 0.5.2). Distributed training across the 2 GPUs appears to work correctly, but it is not faster and may even be slower. See the attached file for a modified version of the Keras timeseries weather forecasting example that reproduces the problem.

I also can't find any "official" Keras / Keras-IO example that shows distributed training together with a measurement of the training time. Shouldn't there be such an "official" example to demonstrate the gain from multi-device training?

timeseries_weather_forecasting_LC.zip

Comment From: larschristensen

After looking into this further, I do see a speed-up with very large batch sizes, e.g. 65536. However, such large batch sizes are likely impractical for most training runs.

The lack of speed-up for "normal" batch sizes seems to result from the way lax.scan is implemented in JAX / XLA; see e.g. https://github.com/jax-ml/jax/discussions/25336 and the links therein for a good overview. So this looks like a bottleneck in JAX / XLA rather than in Keras. Still, it is probably worth monitoring developments in JAX / XLA, since improvements made there could directly benefit Keras.
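A minimal plain-JAX sketch of why a scanned RNN may not speed up under data parallelism. It assumes, as the comment above suggests, that the recurrent loop is expressed with jax.lax.scan; the tanh cell, the sizes, and the helper names rnn_forward and time_forward are hypothetical and are not Keras's actual LSTM/GRU code. Each step needs the previous hidden state, so the T steps run one after another. With 2-way data parallelism each device still runs all T steps, just on half the batch, so if the fixed per-step cost dominates, halving the per-device batch barely shortens the step time. Comparing batch 256 against 128 on a single device is a rough proxy for that effect.

```python
import time

import jax
import jax.numpy as jnp


def rnn_forward(params, xs, h0):
    """Minimal tanh RNN over a (T, B, F) sequence using lax.scan."""
    w_x, w_h, b = params

    def step(h, x_t):
        # Each step depends on the previous hidden state, so the
        # iterations cannot run in parallel across time.
        h_new = jnp.tanh(x_t @ w_x + h @ w_h + b)
        return h_new, h_new

    h_last, hs = jax.lax.scan(step, h0, xs)
    return h_last, hs


def time_forward(batch, timesteps=120, features=8, units=32, reps=10):
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    params = (
        jax.random.normal(k1, (features, units)) * 0.1,
        jax.random.normal(k2, (units, units)) * 0.1,
        jnp.zeros((units,)),
    )
    xs = jax.random.normal(k3, (timesteps, batch, features))
    h0 = jnp.zeros((batch, units))
    fwd = jax.jit(rnn_forward)
    fwd(params, xs, h0)[0].block_until_ready()  # compile outside the timing
    start = time.perf_counter()
    for _ in range(reps):
        out = fwd(params, xs, h0)[0]
    out.block_until_ready()  # wait for asynchronously dispatched work
    return (time.perf_counter() - start) / reps


for b in (256, 128):
    print(f"batch={b}: {time_forward(b) * 1e3:.3f} ms per forward pass")
```

If the two timings are close on a GPU or TPU, the per-step overhead rather than the batch-proportional compute is what dominates, which would be consistent with seeing a speed-up only at very large batch sizes.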

Comment From: sonali-kumari1

Hi @larschristensen -

Please provide a reproduction of the issue as a Colab notebook; this ensures that the issue can easily be reproduced on standard hardware.

Comment From: larschristensen

@sonali-kumari1 Here is a Colab notebook showing the lack of speed-up when distributing RNN training on a v2-8 TPU: https://colab.research.google.com/drive/1fOask5O9zNfbM4SGfFjpYs4qtYub8HdD?usp=sharing

Comment From: sonali-kumari1

Hi @larschristensen -

I tested the code with the latest version of Keras (3.10.0) on a v2-8 TPU in this gist. The model trains successfully with both 1 and 2 TPU cores, but there was no significant speed-up when using 2 TPU cores, which points to the known JAX / XLA bottleneck in lax.scan. I also tested the code on a T4 GPU and found it to be slower than the TPU.

Additionally, when trying large batch sizes (32768 and 65536), I encountered AttributeError: module 'jax._src.interpreters.partial_eval' has no attribute 'ReadWrite', which appears to be an internal compatibility issue in JAX. We will look into this and update you. Thanks!

Comment From: mattdangerw

@sonali-kumari1 I don't see the AttributeError: module 'jax._src.interpreters.partial_eval' has no attribute 'ReadWrite'. Could your runtime have gotten into a bad state? Or was this an issue with a different version of JAX than the one currently preinstalled in Colab?

As for the original bug, thanks @larschristensen for filing. It sounds like the issue lies with JAX here, but we can leave this open to keep monitoring it.