https://github.com/keras-team/keras/blob/c7b66fc23fbeb7197d7c4a6491d7fda6d3502c6b/keras/src/trainers/data_adapters/__init__.py#L93-L97

The class_weight argument of model.fit() doesn't work with torch DataLoaders.

Having dug into the source code, I understand what's going on internally:
1) TFDatasetAdapter receives a tf.data.Dataset;
2) on top of it, TFDatasetAdapter maps a function that:
   - extracts the respective weight for the label;
   - adds it as a third element of the data pipeline.
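To make the mapping step above concrete, here is a minimal, framework-free sketch of what such a function does to each (x, y) pair. The name make_weight_fn is hypothetical and this is not Keras's code: Keras does the equivalent on tensors inside the tf.data pipeline, while this sketch assumes plain Python labels (an integer class index or a one-hot list).

```python
from typing import Any, Callable, Dict, Sequence, Tuple, Union

Label = Union[int, Sequence[float]]


def make_weight_fn(
    class_weight: Dict[int, float],
) -> Callable[[Any, Label], Tuple[Any, Label, float]]:
    """Return a function mapping (x, y) -> (x, y, sample_weight).

    Hypothetical, framework-free sketch of the idea; Keras does the
    equivalent on tensors inside a tf.data pipeline.
    """

    def map_fn(x: Any, y: Label) -> Tuple[Any, Label, float]:
        if isinstance(y, (list, tuple)):
            # One-hot label: the class index is the position of the largest value.
            cls = max(range(len(y)), key=lambda i: y[i])
        else:
            # Sparse label: the label already is the class index.
            cls = int(y)
        # Append the weight of this sample's class as a third element.
        return x, y, class_weight[cls]

    return map_fn


map_fn = make_weight_fn({0: 1.0, 1: 5.0})
print(map_fn([0.2, 0.4], 1))           # ([0.2, 0.4], 1, 5.0)
print(map_fn([0.2, 0.4], [1.0, 0.0]))  # ([0.2, 0.4], [1.0, 0.0], 1.0)
```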

Having looked into the inner workings of the training step, I can tell that it simply unpacks the data into three elements: input tensor, label, and sample weight (here, the class weight of each sample). So to work around this issue, I had to redefine __getitem__ of the torch Dataset so that it returns the corresponding class weight as a third element.
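A minimal sketch of the __getitem__ workaround described above, assuming the wrapped dataset returns (x, y) with an integer class label y. The class name ClassWeightedDataset is hypothetical. It only implements __len__ and __getitem__, which is all a map-style dataset needs for torch's DataLoader; in real code you would typically subclass torch.utils.data.Dataset and return tensors.

```python
class ClassWeightedDataset:
    """Map-style dataset that appends each sample's class weight.

    Wraps any dataset whose __getitem__ returns (x, y) with an integer
    label y, and returns (x, y, weight) instead.
    """

    def __init__(self, base, class_weight):
        self.base = base
        self.class_weight = class_weight  # dict: class index -> weight

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        # Look up the weight of this sample's class; the training step
        # treats the third element as the sample weight.
        return x, y, float(self.class_weight[int(y)])


base = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 1)]
ds = ClassWeightedDataset(base, {0: 2.0, 1: 0.5})
print(len(ds))  # 3
print(ds[0])    # ([0.1, 0.2], 0, 2.0)
```

With this wrapper, fit() is called on a DataLoader over the wrapped dataset without passing class_weight, since the weights already arrive as the third element of each batch.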

Now, I understand why it may not be technically feasible to map a custom function on top of DataLoaders. Either way, the workaround is simple enough. But to find it, I had to look through a lot of source code, which definitely shouldn't be necessary.

So this error message is just not helpful, for two reasons:
1) It's not obvious that this is unsupported. From the documentation of model.fit(), I have no way of knowing that I cannot use class_weight with a torch DataLoader.
2) The error message itself does not offer any solution to the problem: a guide, a note, anything helpful, basically.

Comment From: sonali-kumari1

Hi @DLumi -

Thanks for the detailed information. I have reproduced the issue with the latest version of Keras (3.10.0) in this gist. While the class_weight argument is not directly supported with PyTorch DataLoaders, the Keras documentation describes a way to support class_weight and sample_weight by using a custom training loop and overriding the train_step method in your model.

Comment From: DLumi

@sonali-kumari1

A custom training loop is overkill here. The default one works just fine if your generator outputs 3 elements: input tensor, label, and class weight. My point is that it's not at all obvious that you can do this (and this is what needs to be fixed).
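Three-element outputs work with the default training step because each batch is split by its length. Keras exposes this rule as keras.utils.unpack_x_y_sample_weight; below is a simplified, stdlib-only re-implementation for illustration (the name unpack_batch is hypothetical, and this is not the library's code).

```python
from typing import Any, Optional, Tuple


def unpack_batch(data: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Split a batch into (x, y, sample_weight) based on its length."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        # A bare input: no labels, no weights.
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        # (x, y): no weights.
        return data[0], data[1], None
    if len(data) == 3:
        # (x, y, sample_weight): the third element becomes the sample weight.
        return data[0], data[1], data[2]
    raise ValueError(f"Expected 1 to 3 elements, got {len(data)}.")


print(unpack_batch(("x", "y")))       # ('x', 'y', None)
print(unpack_batch(("x", "y", "w")))  # ('x', 'y', 'w')
```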

Besides, it would also be a good idea to:
1) add a more detailed description to the error raised when trying to fit the model on a PyTorch DataLoader with class_weight;
2) link or mention the guide you provided in the API documentation (https://keras.io/api/models/model_training_apis/#fit-method), as I wouldn't even think of browsing the guides for this particular case.
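As an illustration of point 1), a more descriptive check could name the workaround directly in the message. This is a hypothetical sketch: the function name check_class_weight_supported and the exact wording are assumptions, not Keras's actual code or message.

```python
def check_class_weight_supported(is_torch_dataloader: bool, class_weight) -> None:
    """Hypothetical check: reject class_weight for DataLoaders, but say what to do instead."""
    if is_torch_dataloader and class_weight is not None:
        raise ValueError(
            "Argument `class_weight` is not supported for torch DataLoader "
            f"inputs. Received: class_weight={class_weight}. "
            "Workaround: have your Dataset's __getitem__ return "
            "(x, y, sample_weight), where sample_weight is the weight of the "
            "sample's class, and call fit() without class_weight. "
            "Alternatively, override train_step to apply the weights yourself."
        )


try:
    check_class_weight_supported(True, {0: 1.0, 1: 5.0})
except ValueError as e:
    print(e)
```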

Comment From: sonali-kumari1

@DLumi I tested modifying the __getitem__ method to return the input tensor, label, and class weight, with the class weights calculated using sklearn.utils.compute_class_weight(), and it works in this gist. We will look into this and update you. Thanks!
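For reference, scikit-learn's compute_class_weight with the "balanced" option weights each class by n_samples / (n_classes * count(class)). The sketch below reproduces that heuristic without scikit-learn (the name balanced_class_weight is hypothetical). Unlike scikit-learn, which takes an explicit classes argument, it only considers classes that actually appear in the labels.

```python
from collections import Counter
from typing import Dict, Iterable


def balanced_class_weight(labels: Iterable[int]) -> Dict[int, float]:
    """Weight each class by n_samples / (n_classes * count(class))."""
    labels = list(labels)
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    # Rarer classes get larger weights; a perfectly balanced set gets 1.0 everywhere.
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}


print(balanced_class_weight([0, 0, 0, 1]))  # {0: 0.6666666666666666, 1: 2.0}
```

The resulting dict can then be used inside __getitem__ to look up each sample's weight by its label.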