Issue: keras.distribution.ModelParallel with from_preset() causes OOM errors when loading models larger than single-device memory.

Root Cause: Model weights are downloaded and fully instantiated in memory before distribution/sharding is applied, rather than being distributed during the loading process.
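For reference, a minimal sketch of the loading pattern that hits this (not the exact notebook code). The mesh shape, the `gemma2_9b_en` preset name, and the keyword-only ModelParallel signature are assumptions and may need adjusting for your Keras/keras_hub versions:

```python
# Sketch only: assumes a single host with 8 TPU devices and a recent Keras 3
# where ModelParallel takes layout_map as a keyword argument.
import jax
import keras
import keras_hub

devices = jax.devices()  # e.g. 8 TPU cores
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)

# Shard Gemma weights along the "model" axis of the mesh.
layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)

# OOM happens here: weights appear to be materialized on a single device
# before the sharding layout is applied.
model = keras_hub.models.GemmaCausalLM.from_preset("gemma2_9b_en")
```

With the 2B preset the same pattern distributes weights across the mesh as expected; with the 9B preset it OOMs inside from_preset().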

Evidence
Working Example: [Gemma2 2B notebook from last week] - memory properly distributed across 8 TPU devices.
Failing Example: [Gemma2 9B notebook] - OOM during loading despite sufficient total TPU memory.
Cross-Platform Confirmation: a similar failure occurs on the Kaggle runtime (different Python and library versions), suggesting the issue is not environment-specific.

Error Trace Pattern
Consistent failure path across platforms:
keras_hub/src/models/task.py → keras_hub/src/utils/preset_utils.py → keras/src/saving/serialization_lib.py → keras/src/layers/__init__.py → keras_hub/src/models/gemma/gemma_backbone.py → keras_hub/src/layers/modeling/reversible_embedding.py → jax/_src/numpy/array_methods.py → jax/_src/interpreters/pxla.py:1362 → results = self.xla_executable.execute_sharded(input_bufs)

Additional Issues
Kaggle TPU: consistently available, but the TensorBoard profiling plugin is incompatible with the runtime.
Colab TPU: the libtpu library crashes the runtime (logs attached).

Impact: prevents loading models >2B parameters despite sufficient distributed memory capacity.

Through debugging, I believe I've traced this to a JAX canonicalization change that's breaking distributed model loading.

Possible Evidence of JAX Root Cause
I traced the issue to JAX commit 19a5055 by Peter Hawkins, "Remove most canonicalization from MLIR lowering". The commit message explicitly states: "There is still one place where we need the older canonicalization behavior that forms a NumPy array for a Python scalar, namely in pxla.shard_args." This matches exactly where our error occurs - in the sharding code (pxla.py) during distributed loading.
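If it helps with triage, a small scalar-sharding probe along these lines could check whether the Python-scalar path differs from the NumPy path on an affected JAX build. This is a hypothetical isolation test, not the code path keras_hub actually executes, and the mesh shape is an assumption:

```python
# Probe whether sharding a plain Python scalar behaves the same as sharding
# a NumPy scalar (the case the commit message calls out in pxla.shard_args).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("batch", "model"))
sharding = NamedSharding(mesh, P())  # fully replicated layout

out_scalar = jax.device_put(1.0, sharding)              # Python scalar path
out_array = jax.device_put(np.float32(1.0), sharding)   # NumPy scalar path
print(out_scalar.sharding, out_array.sharding)
```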

Request: This looks like a JAX regression that's hitting Keras distributed loading. Could you help validate this isn't actually a Keras issue and maybe file a JAX bug report together?