Issue: keras.distribution.ModelParallel with from_preset() causes OOM errors when loading models larger than single-device memory.

Root Cause: Model weights are downloaded and fully instantiated in memory before distribution/sharding is applied, rather than being distributed during the loading process.
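The root cause above comes down to ordering: materialize first and split later, or split while reading. The minimal NumPy sketch below contrasts the two orders. It assumes a single .npy file as a stand-in for preset weights (keras_hub presets are not actually stored this way). The helpers save_weight, load_then_shard and shard_while_loading are hypothetical, and the sketch tracks only the largest single host buffer, not true peak memory. It does not reflect how keras or keras_hub implement loading. In JAX, the comparable per-shard entry point is jax.make_array_from_callback, which invokes a callback with each device's index slice.

```python
import os
import tempfile

import numpy as np


def save_weight(path, shape, dtype=np.float32):
    """Write a fake weight tensor to disk (hypothetical stand-in for a preset file)."""
    weight = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
    np.save(path, weight)


def load_then_shard(path, num_shards):
    """Current order: read the whole tensor into host memory, then split it."""
    full = np.load(path)  # the entire tensor is materialized at once
    return np.array_split(full, num_shards, axis=0), full.nbytes


def shard_while_loading(path, num_shards):
    """Alternative order: read one shard's rows at a time from a memory map."""
    mapped = np.load(path, mmap_mode="r")  # maps the file; no bulk read yet
    rows = mapped.shape[0]
    bounds = [i * rows // num_shards for i in range(num_shards + 1)]
    shards, largest = [], 0
    for start, stop in zip(bounds[:-1], bounds[1:]):
        shard = np.array(mapped[start:stop])  # copy just this slice into memory
        largest = max(largest, shard.nbytes)
        # A real loader would place `shard` on its device here and free the host copy.
        shards.append(shard)
    return shards, largest


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embedding.npy")
    save_weight(path, (1024, 64))
    naive_shards, naive_bytes = load_then_shard(path, 8)
    lazy_shards, lazy_bytes = shard_while_loading(path, 8)
    print(f"largest single host buffer: {naive_bytes} vs {lazy_bytes} bytes")
```

Both orders produce the same shards. With a 1024 x 64 float32 tensor split 8 ways, the largest single buffer drops from 262144 bytes (the whole tensor) to 32768 bytes (one shard).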

Evidence
Working Example: [Gemma2 2B notebook from last week] - Memory properly distributed across 8 TPU devices.
Failing Example: [Gemma2 9B notebook] - OOM during loading despite sufficient total TPU memory.
Cross-Platform Confirmation: A similar issue occurs on the Kaggle runtime (different Python and library versions), suggesting it is not environment-specific.
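To make the gap between the two notebooks concrete, here is a back-of-the-envelope estimate of weight memory per device. It assumes nominal parameter counts of 2e9 and 9e9 (the real checkpoints are somewhat larger), bfloat16 weights at 2 bytes each, and an even split across 8 devices. It counts weights only, with no activations or runtime overhead. The helper names are hypothetical.

```python
# Back-of-the-envelope weight memory per device (weights only).
# Assumptions: nominal parameter counts, bfloat16 weights, even split over 8 devices.
NOMINAL_PARAMS = {"Gemma2 2B": 2e9, "Gemma2 9B": 9e9}
BYTES_PER_PARAM = 2  # bfloat16
NUM_DEVICES = 8
GIB = 1024 ** 3


def weight_bytes(num_params, bytes_per_param=BYTES_PER_PARAM):
    """Total bytes needed to hold all weights once."""
    return num_params * bytes_per_param


def per_device_bytes(num_params, num_devices, sharded):
    """Bytes one device must hold: everything, or an even 1/num_devices slice."""
    total = weight_bytes(num_params)
    return total / num_devices if sharded else total


for name, params in NOMINAL_PARAMS.items():
    whole = per_device_bytes(params, NUM_DEVICES, sharded=False) / GIB
    split = per_device_bytes(params, NUM_DEVICES, sharded=True) / GIB
    print(f"{name}: {whole:.1f} GiB on one device vs {split:.1f} GiB per device when sharded")
```

Under these assumptions, the 9B weights need roughly 16.8 GiB if fully materialized on one device, but only about 2.1 GiB per device once split across 8 devices. For this reason, when the weights are materialized matters as much as the total capacity available.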

Error Trace Pattern
Consistent failure path across platforms:
keras_hub/src/models/task.py
→ keras_hub/src/utils/preset_utils.py
→ keras/src/saving/serialization_lib.py
→ keras/src/layers/__init__.py
→ keras_hub/src/models/gemma/gemma_backbone.py
→ keras_hub/src/layers/modeling/reversible_embedding.py
→ jax/_src/numpy/array_methods.py
→ jax/_src/interpreters/pxla.py:1362: results = self.xla_executable.execute_sharded(input_bufs)

Additional Issues
Kaggle TPU: Consistently available, but the TensorBoard profiling plugin is incompatible with the runtime.
Colab TPU: The libtpu library crashes the runtime (logs attached).

Impact: Prevents loading models larger than 2B parameters despite sufficient distributed memory capacity.

Through debugging, I believe I've traced this to a JAX canonicalization change that's breaking distributed model loading.

Possible Evidence of JAX Root Cause: I traced the issue to JAX commit 19a5055 by Peter Hawkins, "Remove most canonicalization from MLIR lowering". The commit message explicitly states: "There is still one place where we need the older canonicalization behavior that forms a NumPy array for a Python scalar, namely in pxla.shard_args." This points to the same module (pxla.py) where our failure occurs: the sharding code that runs during distributed loading.

Request: This looks like a JAX regression that is affecting Keras distributed loading. Could you help confirm that this isn't actually a Keras issue and, if so, file a JAX bug report together?

Comment From: MalyalaKarthik66

I tried loading a large model with ModelParallel and ran into the same OOM error. It looks like the weights are fully loaded on one device before being distributed. Can I work on changing this so the model is sharded while loading, to avoid the OOM?

Comment From: amitsrivastava78

@MalyalaKarthik66 Please go ahead and submit a fix; your contribution is welcome!

Comment From: amitsrivastava78

@MalyalaKarthik66 Any update on this issue?

Comment From: MalyalaKarthik66

I sincerely apologize for the delay. I'm currently taking exams, so I haven't been able to work on this. I'll get back to it once I'm free and post my progress here. Thanks for your patience.

Comment From: amitsrivastava78

@MalyalaKarthik66 Since you were busy, I have attempted a fix in PR https://github.com/keras-team/keras/pull/21712.