Hello, I tried writing a training loop from scratch with JAX in Keras 3, but the model doesn't converge during training. I tested the model.fit() API to check the issue, and it trains the model perfectly and updates the weights.
Here is a gist that reproduces the issue: https://gist.github.com/heydaari/0645956a773fac764cd974f5f1850ea3
The first set of logs is from the last epoch of the training loop from scratch; the second set is from model.fit().