I get the following error when running examples/demo_jax_distributed.py on a Cloud TPU VM (TPU v2):

```
Traceback (most recent call last):
  File "/home/colby/jax_test.py", line 341, in <module>
    loss, accuracy = model.evaluate(eval_data)
  File "/home/colby/.local/lib/python3.10/site-packages/keras_core/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/colby/.local/lib/python3.10/site-packages/keras_core/src/backend/jax/trainer.py", line 597, in evaluate
    logs, state = self.test_function(state, data)
ValueError: Received incompatible devices for jitted computation. Got argument state[0][0] of JAXTrainer.make_test_function.<locals>.compiled_test_step with shape float32[3,3,1,12] and device ids [0, 1, 2, 3, 6, 7, 4, 5] on platform TPU and sharding_constraint inside jit with device ids [0] on platform TPU at /home/colby/.local/lib/python3.10/site-packages/keras_core/src/backend/jax/trainer.py:910 (_enforce_jax_state_sharding)
```

Package versions:
- Keras-core: 0.1.7
- jax[tpu]: 0.4.19
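For context on what the error is saying: the model state (here a `float32[3,3,1,12]` weight) is sharded across all eight TPU cores, while a `sharding_constraint` inside the jitted test step is pinned to device id `[0]` only, so JAX rejects the computation. A minimal, hypothetical sketch of that device-set comparison (plain Python, not the actual JAX internals):

```python
def shardings_compatible(arg_device_ids, constraint_device_ids):
    """Return True when both placements cover the same set of devices.

    Hypothetical helper illustrating the consistency check behind the
    "Received incompatible devices for jitted computation" error; the
    real check lives inside JAX's jit machinery.
    """
    return set(arg_device_ids) == set(constraint_device_ids)


# Mirrors the traceback: weights laid out on 8 cores vs. a constraint
# applied by _enforce_jax_state_sharding that expects a single core.
print(shardings_compatible([0, 1, 2, 3, 6, 7, 4, 5], [0]))  # False
```

In other words, the evaluation path appears to enforce a single-device sharding on state that was initialized for the full 8-core mesh, which is why `evaluate` fails while the distributed setup itself succeeds.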

Comment From: qlzh727

Thanks for the report. Let me take a look.

Comment From: sonali-kumari1

Hi @colbybanbury - Could you please confirm whether this issue is still reproducible with the latest versions of Keras and JAX? Thanks!

Comment From: github-actions[bot]

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.