Hello, I tried writing a training loop from scratch with JAX in Keras 3, but the model doesn't converge during training. I tested the model.fit() API to check the issue, and it trains the model and updates the weights correctly.
This is the gist of the issue: https://gist.github.com/heydaari/0645956a773fac764cd974f5f1850ea3
The first logs are from the last epoch of the from-scratch training loop, followed by the logs from model.fit().
Comment From: heydaari
Are there any updates on this one? @abheesht17
Comment From: fchollet
Took a quick look. There's no bug -- your training loop is 100% correct.
You do make one mistake though -- loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True) should be loss_fn = keras.losses.CategoricalCrossentropy(). The softmax is included in the model (you could also remove the softmax from the model instead).
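To illustrate why this mismatch matters, here is a minimal NumPy sketch (not the code from the gist). The `softmax` and `cross_entropy` helpers and the example logits are made up for illustration; the assumption is that from_logits=True makes the loss apply a softmax to whatever the model outputs, so a model that already ends in a softmax gets normalized twice.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()


def cross_entropy(probs, one_hot):
    """Categorical cross-entropy for a single example, given probabilities."""
    return float(-np.sum(one_hot * np.log(probs)))


# Hypothetical 3-class example where the model is very confident and correct.
logits = np.array([8.0, 0.0, 0.0])
target = np.array([1.0, 0.0, 0.0])

# Output of a model whose last layer is a softmax.
probs = softmax(logits)

# Matching setup: the loss treats the model output as probabilities.
loss_correct = cross_entropy(probs, target)

# Mismatched setup: the loss applies a second softmax to the probabilities.
loss_mismatch = cross_entropy(softmax(probs), target)

# Inputs to the second softmax lie in [0, 1], so even a perfect one-hot
# prediction cannot push the loss below log(1 + (K - 1) / e) for K classes.
num_classes = logits.shape[0]
floor = float(np.log(1.0 + (num_classes - 1) / np.e))

print("loss with matching setup:   ", loss_correct)
print("loss with double softmax:   ", loss_mismatch)
print("lower bound, double softmax:", floor)
```

In this sketch the correctly configured loss is close to zero, while the double-softmax loss stays above a floor that depends only on the number of classes, which would be consistent with a training loss that stalls.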
I tried training a smaller/better model after fixing this (Xception) without changing anything else in your code and I see that it converges quickly.