Hello, I tried writing a training loop from scratch with JAX in Keras 3, but the model doesn't converge during training. To check, I trained the same model with the model.fit() API, and it trains correctly and updates the weights.

This is the gist of the issue: https://gist.github.com/heydaari/0645956a773fac764cd974f5f1850ea3


The first set of logs is from the last epoch of the from-scratch training loop; the second set is from model.fit().

Comment From: heydaari

Are there any updates on this one? @abheesht17

Comment From: fchollet

Took a quick look. There's no bug -- your training loop is 100% correct.

You do make one mistake, though: loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True) should be loss_fn = keras.losses.CategoricalCrossentropy(), because the softmax is already included in the model. (Alternatively, you could remove the softmax from the model and keep from_logits=True.)
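To see why the mismatch matters, here is a minimal pure-Python sketch of the math (not Keras's actual implementation; the function names and the example logits are hypothetical). With from_logits=True the loss applies its own softmax, so a model that already ends in softmax gets softmax applied twice. Because the probabilities fed in lie in [0, 1], the second softmax can never become sharp, and the loss has a floor well above zero even for a perfect prediction.

```python
import math

def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def cce_probs(y_true, probs, eps=1e-7):
    """Cross-entropy on probabilities (what from_logits=False expects).
    Simplified: clips to eps but does not renormalize like Keras does."""
    return -sum(t * math.log(min(max(p, eps), 1 - eps))
                for t, p in zip(y_true, probs))

def cce_logits(y_true, logits):
    """Cross-entropy on raw logits (what from_logits=True expects):
    softmax is applied internally before taking the log."""
    return cce_probs(y_true, softmax(logits))

y_true = [1.0, 0.0, 0.0]
logits = [4.0, 0.0, 0.0]   # hypothetical pre-softmax model output
probs = softmax(logits)    # what a model ending in softmax returns

correct = cce_probs(y_true, probs)      # softmax applied once: ~0.036
mismatched = cce_logits(y_true, probs)  # softmax applied twice: ~0.574

# Even a perfectly confident, correct prediction cannot push the
# mismatched loss below ln(1 + (C - 1) / e) for C classes (~0.551 here).
floor = cce_logits(y_true, [1.0, 0.0, 0.0])

print(f"correct loss:     {correct:.3f}")
print(f"mismatched loss:  {mismatched:.3f}")
print(f"mismatched floor: {floor:.3f}")
```

The training loop itself can be fine while the loss still plateaus, since the optimizer is minimizing a quantity that cannot reach zero.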

After fixing this, I tried training a smaller, better model (Xception) without changing anything else in your code, and it converges quickly.