It seems that Keras needs a JIT API. Since torch.compile, tf.function, and jax.jit all provide JIT compilation, implementing keras.jit on top of them should not be difficult. The open question is how to define its interface.
References:
- jax.jit: https://docs.jax.dev/en/latest/_autosummary/jax.jit.html
- tf.function: https://tensorflow.google.cn/api_docs/python/tf/function
- torch.compile: https://docs.pytorch.org/docs/stable/generated/torch.compile.html
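As a starting point for the interface discussion, here is a minimal sketch of one possible shape: a decorator that looks up the active backend and forwards to that backend's compiler. Everything here is hypothetical. `jit`, `_COMPILERS`, and the `backend=` override are illustrative names, not an existing Keras API. The only real calls are `jax.jit`, `tf.function(..., jit_compile=...)`, `torch.compile`, and `keras.config.backend()`. The NumPy backend has no JIT, so the sketch assumes it falls back to returning the function unchanged.

```python
import functools
from typing import Any, Callable, Optional


def _compile_jax(fn: Callable, **kwargs: Any) -> Callable:
    import jax  # imported lazily so other backends don't require JAX
    return jax.jit(fn, **kwargs)


def _compile_tensorflow(fn: Callable, **kwargs: Any) -> Callable:
    import tensorflow as tf
    # Assumption: keras.jit should mean XLA compilation, so default jit_compile=True.
    kwargs.setdefault("jit_compile", True)
    return tf.function(fn, **kwargs)


def _compile_torch(fn: Callable, **kwargs: Any) -> Callable:
    import torch
    return torch.compile(fn, **kwargs)


def _compile_numpy(fn: Callable, **kwargs: Any) -> Callable:
    # Assumption: the NumPy backend has no JIT, so return fn unchanged.
    return fn


# Hypothetical registry mapping Keras backend names to compiler functions.
_COMPILERS = {
    "jax": _compile_jax,
    "tensorflow": _compile_tensorflow,
    "torch": _compile_torch,
    "numpy": _compile_numpy,
}


def _current_backend() -> str:
    import keras
    return keras.config.backend()


def jit(fn: Optional[Callable] = None, *, backend: Optional[str] = None,
        **compile_kwargs: Any) -> Callable:
    """Hypothetical keras.jit: compile fn with the active backend's JIT."""
    if fn is None:
        # Support both the bare @jit and the parameterized @jit(...) forms.
        return functools.partial(jit, backend=backend, **compile_kwargs)
    name = backend or _current_backend()
    try:
        compiler = _COMPILERS[name]
    except KeyError:
        raise ValueError(f"Unsupported backend: {name!r}") from None
    # compile_kwargs are forwarded verbatim, so they are backend-specific.
    return compiler(fn, **compile_kwargs)


@jit(backend="numpy")
def add(x, y):
    return x + y


print(add(2, 3))  # 5
```

Forwarding `compile_kwargs` verbatim keeps the sketch small but pushes backend differences onto the user (e.g. `static_argnums` only makes sense for `jax.jit`, `fullgraph` only for `torch.compile`). Deciding which options keras.jit should normalize across backends is probably the core of the interface question.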