https://github.com/keras-team/keras/blob/c7b66fc23fbeb7197d7c4a6491d7fda6d3502c6b/keras/src/trainers/data_adapters/init.py#L93-L97
The class_weight argument of model.fit() doesn't work with torch DataLoaders.
Having dug into the source code, I understand what's going on internally:
1) TFDatasetAdapter receives a tf.data.Dataset
2) On top of it, TFDatasetAdapter maps a function that:
- extracts the respective weight for the label
- adds it as a third element of the data pipeline
Having looked into the inner workings of the training step, I can tell that it simply unpacks the data into three elements: input tensor, label, and class weight. So to solve this issue I have to redefine __getitem__ of the torch Dataset so that it outputs the correct class weight as a third element.
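A minimal sketch of that workaround, under a few assumptions: labels are integer class indices (not one-hot), the base dataset yields (x, y) pairs, and the name WeightedDataset is hypothetical. It is plain Python so it stays dependency-free; it only implements the map-style protocol (__len__ and __getitem__) that torch's DataLoader consumes, and in real use x and y would normally be tensors or arrays.

```python
class WeightedDataset:
    """Hypothetical wrapper: turns (x, y) samples into (x, y, weight)."""

    def __init__(self, base, class_weight):
        self.base = base                  # any map-style dataset yielding (x, y)
        self.class_weight = class_weight  # dict: class index -> weight

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        # Look up the weight for this sample's label and emit it
        # as a third element, next to the input and the label.
        weight = float(self.class_weight[int(y)])
        return x, y, weight


# Toy data standing in for a real torch Dataset (assumption for illustration).
samples = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 0)]
dataset = WeightedDataset(samples, class_weight={0: 0.75, 1: 1.5})
```

The wrapped dataset would then go into a torch DataLoader and be passed to model.fit() without the class_weight argument, since the weights already travel with each sample.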
Now, I understand why it may not be technically feasible to map a custom function on top of DataLoaders. Whatever, the solution is simple enough. But to find it, I had to look through a bunch of source code, which definitely shouldn't be necessary.
So this error message is just not helpful, for two reasons:
1) It's not obvious that I simply cannot do that. Looking at the documentation of model.fit(), I have no idea that I cannot use class_weight with a torch DataLoader.
2) The error message itself does not offer any solution to the problem. A guide, a note, anything helpful, basically.
Comment From: sonali-kumari1
Hi @DLumi -
Thanks for the detailed information. I have reproduced the issue with the latest version of Keras (3.10.0) in this gist. While the class_weight argument is not directly supported with PyTorch DataLoaders, the Keras documentation provides a way to support class_weight and sample_weight by using a custom training loop and overriding the train_step method in your model.
Comment From: DLumi
@sonali-kumari1
A custom training loop is kind of overkill here. The default one works just fine if your generator outputs 3 elements: input tensor, label, class weight. What I'm saying is that it's not obvious in any way that you can even do that (and this is what needs to be fixed).
Besides, it would also be a good idea to: 1) add a more detailed description to the error raised when trying to fit the model on a PyTorch DataLoader; 2) link or mention the guide you provided in the API documentation, as I wouldn't even think of browsing guides for this particular case: https://keras.io/api/models/model_training_apis/#fit-method
Comment From: sonali-kumari1
@DLumi
I have tested modifying the __getitem__ method to return input tensor, label and class_weight, and calculating class weights using sklearn.utils.compute_class_weight(), and it works in this gist. We will look into this and update you. Thanks!
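
For reference, a rough sketch of how such class weights could be computed. It is a dependency-free stand-in, not the gist's code: the helper name balanced_class_weights is hypothetical, and it assumes the "balanced" heuristic documented for scikit-learn's compute_class_weight, n_samples / (n_classes * count of each class).

```python
from collections import Counter


def balanced_class_weights(labels):
    """Hypothetical helper: weight each class inversely to its frequency."""
    counts = Counter(labels)   # class -> number of samples
    n_samples = len(labels)
    n_classes = len(counts)
    # Rare classes get weights above 1, frequent classes below 1.
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


# Toy imbalanced labels (assumption for illustration): three of class 0, one of class 1.
labels = [0, 0, 0, 1]
weights = balanced_class_weights(labels)
```

The resulting dict maps each class index to its weight, which is the shape a per-sample lookup in __getitem__ would need.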