Using the JAX backend for LSTM / GRU models, I'm unable to see any speed-up when training with two Nvidia 3090s versus a single Nvidia 3090 (using keras-nightly and JAX 0.5.2). The distributed training across the two GPUs appears to work correctly, but it is not faster and may even be slower. See the attached file for a modified version of the Keras timeseries weather forecasting example that demonstrates the problem.
I also can't find any "official" Keras / Keras-IO example showing distributed training together with a measurement of the training time. Shouldn't there be such an "official" example to showcase the gains from multi-device training?
timeseries_weather_forecasting_LC.zip
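For reference, a minimal sketch of the kind of setup used here, i.e. data-parallel LSTM training on the JAX backend via `keras.distribution`. This is not the attached example; the data shapes, layer sizes, and hyperparameters are illustrative stand-ins for the weather-forecasting sequences:

```python
# Minimal sketch (not the attached example): data-parallel LSTM training
# with the JAX backend. Shapes and hyperparameters are illustrative only.
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import time
import numpy as np
import keras

# Use all visible accelerators (e.g. both 3090s) for data-parallel training.
keras.distribution.set_distribution(
    keras.distribution.DataParallel(devices=keras.distribution.list_devices())
)

# Dummy data standing in for the timeseries weather forecasting inputs.
x = np.random.rand(4096, 120, 7).astype("float32")
y = np.random.rand(4096, 1).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(120, 7)),
    keras.layers.LSTM(32),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

start = time.time()
model.fit(x, y, batch_size=256, epochs=3, verbose=0)
print(f"Training time: {time.time() - start:.1f} s")
```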
Comment From: larschristensen
Having looked more into this issue, it turns out I am able to see a speed-up for very large batch sizes, e.g. 65536. However, using such a large batch size is unlikely to be practical for most model training.
The lack of speed-up for "normal" batch sizes appears to be a result of the way `lax.scan` is implemented in JAX / XLA; see e.g. https://github.com/jax-ml/jax/discussions/25336 and the links therein for a good overview. It therefore looks like this is really a bottleneck in JAX / XLA and not in Keras. However, it is probably worth monitoring the development of this in JAX / XLA to see whether any improvements made there can directly benefit Keras.
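To make the bottleneck concrete, here is a minimal sketch (my own illustration, not Keras internals) of an RNN loop expressed with `jax.lax.scan`. The scan carries the hidden state from step to step, so the time dimension is executed sequentially no matter how the batch is sharded across devices; only the per-step batch math parallelizes, which is why adding devices helps little at normal batch sizes:

```python
# Illustrative only: a toy RNN over jax.lax.scan showing the sequential
# dependency along the time axis.
import jax
import jax.numpy as jnp

def rnn_cell(h, x_t, w_h, w_x):
    # One recurrent step: the new hidden state depends on the previous one.
    h_new = jnp.tanh(h @ w_h + x_t @ w_x)
    return h_new, h_new

def run_rnn(x, h0, w_h, w_x):
    # lax.scan runs the time dimension sequentially at runtime;
    # only the per-step (batch x hidden) math is parallel across devices.
    step = lambda h, x_t: rnn_cell(h, x_t, w_h, w_x)
    h_final, outputs = jax.lax.scan(step, h0, x)
    return h_final, outputs

key = jax.random.PRNGKey(0)
seq_len, batch, features, hidden = 120, 256, 7, 32
x = jax.random.normal(key, (seq_len, batch, features))  # time-major input
h0 = jnp.zeros((batch, hidden))
w_h = jax.random.normal(key, (hidden, hidden)) * 0.1
w_x = jax.random.normal(key, (features, hidden)) * 0.1

h_final, outputs = jax.jit(run_rnn)(x, h0, w_h, w_x)
print(outputs.shape)  # (seq_len, batch, hidden)
```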
Comment From: sonali-kumari1
Hi @larschristensen -
Could you please provide the reproduction in the form of a Colab notebook? This ensures that the issue is easily reproducible on standard hardware.
Comment From: larschristensen
@sonali-kumari1 Here is a colab notebook showing the lack of RNN distribution speedup on a v2-8 TPU: https://colab.research.google.com/drive/1fOask5O9zNfbM4SGfFjpYs4qtYub8HdD?usp=sharing
Comment From: sonali-kumari1
Hi @larschristensen -
I have tested the code with the latest version of Keras (3.10.0) using a v2-8 TPU in this gist. The model trains successfully with both 1 and 2 TPU cores, but there was no significant speedup when using 2 TPU cores, which is a known JAX / XLA bottleneck for `lax.scan`. I also tested the code with a T4 GPU and observed it to be slower than the TPU.
Additionally, when trying large batch sizes (32768 and 65536), I encountered `AttributeError: module 'jax._src.interpreters.partial_eval' has no attribute 'ReadWrite'`, which appears to be an internal compatibility issue with JAX. We will look into this and update you. Thanks!
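For anyone wanting to repeat the 1-core vs 2-core comparison, a rough sketch of how it can be timed is below. The model, data shapes, and batch size are illustrative, not the exact gist code, and switching the distribution within one process is a convenience here; running each configuration in a fresh runtime may be more reliable:

```python
# Rough sketch: time data-parallel training on different device subsets.
# Shapes, model, and batch size are illustrative only.
import os
os.environ["KERAS_BACKEND"] = "jax"

import time
import numpy as np
import keras

def train_on(devices, batch_size, x, y):
    # Restrict data-parallel training to an explicit list of devices.
    keras.distribution.set_distribution(
        keras.distribution.DataParallel(devices=devices)
    )
    model = keras.Sequential([
        keras.Input(shape=x.shape[1:]),
        keras.layers.LSTM(32),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    start = time.time()
    model.fit(x, y, batch_size=batch_size, epochs=2, verbose=0)
    return time.time() - start

x = np.random.rand(8192, 120, 7).astype("float32")
y = np.random.rand(8192, 1).astype("float32")

all_devices = keras.distribution.list_devices()
print("1 device :", train_on(all_devices[:1], 256, x, y))
print("2 devices:", train_on(all_devices[:2], 256, x, y))
```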
Comment From: mattdangerw
@sonali-kumari1 I don't see the `AttributeError: module 'jax._src.interpreters.partial_eval' has no attribute 'ReadWrite'` error. Could your runtime have gotten into a bad state, or was this an issue on a different version of JAX than is currently preinstalled with Colab?
As for the original bug, thanks @larschristensen for filing. It sounds like the issue resides with JAX here, but we can leave this open to continue monitoring.