When using multiple GPUs (n > 1) with a batch size that is not divisible by the number of GPUs, you get a ZeroDivisionError. I understand that the library needs to split each batch evenly across devices, and the tf.data API offers a drop_remainder option for exactly this. My concern is evaluation and inference: there we can't drop any samples without distorting the measured performance, so some sort of padding (or a similar workaround) would be needed.
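A minimal sketch of the padding idea, assuming the evaluation data fit in NumPy arrays: pad the dataset up to a multiple of the device count by repeating the last sample, and return per-sample weights of 0.0 for the padded rows so a weighted loss or metric ignores them. The helper name `pad_to_device_multiple` is hypothetical, not a Keras or tf.data API.

```python
import numpy as np


def pad_to_device_multiple(x, y, num_devices):
    """Pad x and y along axis 0 so their length is a multiple of num_devices.

    Hypothetical helper: padded rows repeat the last real sample and receive a
    sample weight of 0.0, so weighted losses/metrics ignore them.
    """
    n = len(x)
    pad = (-n) % num_devices  # rows needed to reach the next multiple
    # Indices of the rows to append (copies of the last real sample).
    pad_idx = np.full(pad, n - 1, dtype=np.int64)
    x_padded = np.concatenate([x, x[pad_idx]], axis=0)
    y_padded = np.concatenate([y, y[pad_idx]], axis=0)
    # 1.0 for real samples, 0.0 for padding.
    weights = np.concatenate(
        [np.ones(n, dtype=np.float32), np.zeros(pad, dtype=np.float32)]
    )
    return x_padded, y_padded, weights


x = np.arange(5, dtype=np.float32).reshape(5, 1)
y = np.array([0, 1, 0, 1, 1], dtype=np.float32)
xp, yp, w = pad_to_device_multiple(x, y, num_devices=2)
print(xp.shape, w)  # (6, 1) [1. 1. 1. 1. 1. 0.]
```

This only keeps every batch divisible if the batch size itself is also a multiple of the device count. The weights would then go in as the third element of each (x, y, sample_weight) tuple, and metrics that should honour them belong under weighted_metrics in compile.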

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
devices = jax.devices("gpu")
# [CudaDevice(id=0), CudaDevice(id=1)]

import keras 
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)

import tensorflow as tf
from keras import layers, models
import numpy as np

def create_2d_model(input_shape=(100, 100, 3)):
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')
    ])

    return model

def create_dummy_dataset(batch_size=32, image_shape=(100, 100, 3)):
    def generator():
        while True:
            images = np.random.rand(*image_shape).astype(np.float32)
            label = np.random.randint(0, 2, size=(1,)).astype(np.float32)
            yield images, label

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=image_shape, dtype=tf.float32),
            tf.TensorSpec(shape=(1,), dtype=tf.float32)
        )
    )
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

model = create_2d_model((100, 100, 3))
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

train_dataset = create_dummy_dataset(
    batch_size=32, 
    image_shape=(100, 100, 3)
)
val_dataset = create_dummy_dataset(
    batch_size=1, 
    image_shape=(100, 100, 3)
)

history = model.fit(
    train_dataset,
    steps_per_epoch=100,
    epochs=2,
    validation_data=val_dataset,
    validation_steps=10,  # val_dataset is an infinite generator, so bound the validation loop
)
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
/tmp/ipykernel_31/1415789567.py in <cell line: 0>()
     41 val_dataset = create_dummy_dataset(batch_size=1, image_shape=(100, 100, 3))
     42 
---> 43 history = model.fit(
     44     train_dataset,
     45     steps_per_epoch=100,

/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.11/dist-packages/optree/ops.py in tree_map(func, tree, is_leaf, none_is_leaf, namespace, *rests)
    764     leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    765     flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 766     return treespec.unflatten(map(func, *flat_args))
    767 
    768 

ZeroDivisionError: integer modulo by zero
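The message itself is plain Python integer arithmetic. One plausible reading, which the filtered traceback does not show directly, is that the per-device batch size is computed with floor division before a divisibility check: with a global batch of 1 and 2 devices that size is 0, and the following modulo fails. The function below is a hypothetical reconstruction for illustration, not Keras source.

```python
def check_even_split(batch_size, num_devices):
    # Hypothetical reconstruction of a per-device split check.
    per_device = batch_size // num_devices  # 1 // 2 == 0
    # With per_device == 0 the modulo raises ZeroDivisionError.
    return batch_size % per_device == 0


try:
    check_even_split(batch_size=1, num_devices=2)
except ZeroDivisionError as err:
    print(err)  # integer modulo by zero
```

Under that reading, the training batches (32 over 2 GPUs) split cleanly and the failure comes from the validation dataset built with batch_size=1.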

In short

inputs = [1, 2, 3, 4, 5]
replicas = num_available_gpus  # i.e. 2
batch_size = 1 * replicas
# desired batches: [1, 2], [3, 4], [5]

Shouldn't the model be able to handle this case? The remaining samples (here, [5]) could simply be allocated to one of the available GPUs.
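A plain-Python sketch of the behaviour being asked for: batch the inputs, then split each batch across devices as evenly as possible, so a short final batch leaves some devices with fewer or no samples. The function shard_batches is hypothetical; Keras' DataParallel does not currently work this way.

```python
def shard_batches(inputs, batch_size, num_devices):
    """Yield, for each batch, a list of per-device shards (possibly uneven)."""
    for start in range(0, len(inputs), batch_size):
        batch = inputs[start:start + batch_size]
        # Spread samples as evenly as possible; earlier devices get the extras.
        base, extra = divmod(len(batch), num_devices)
        shards, pos = [], 0
        for d in range(num_devices):
            size = base + (1 if d < extra else 0)
            shards.append(batch[pos:pos + size])
            pos += size
        yield shards


for shards in shard_batches([1, 2, 3, 4, 5], batch_size=2, num_devices=2):
    print(shards)
# [[1], [2]]
# [[3], [4]]
# [[5], []]
```

Under SPMD execution each device generally runs the same compiled program on equally shaped shards, so a real implementation would likely need padding plus masking rather than an empty shard like the last one above.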

Comment From: vulkomilev

Let me look into that.

Comment From: vulkomilev

I have reproduced the bug; I'm working on it.

Comment From: vulkomilev

It looks like the treespec.unflatten function needs the correct number; otherwise it throws this error. I am not sure a fix is needed.

Comment From: innat

Do you mean that it's an expected error?

Comment From: divyashreepathihalli

It looks like there is no way to rebatch to the given number of GPUs either. I think raising a meaningful error for this case would make more sense.
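A sketch of what a clearer up-front check could look like; the function name and the message wording are hypothetical, not an existing Keras API.

```python
def validate_batch_for_devices(batch_size, num_devices):
    """Raise a descriptive ValueError instead of a bare ZeroDivisionError.

    Returns the per-device batch size when the split is even.
    """
    if num_devices <= 0:
        raise ValueError(f"num_devices must be positive, got {num_devices}.")
    if batch_size % num_devices != 0:
        raise ValueError(
            f"Batch size {batch_size} is not divisible by the number of "
            f"devices ({num_devices}). Use a batch size that is a multiple of "
            f"{num_devices}; for a final partial batch, either drop it "
            f"(e.g. drop_remainder=True in tf.data) or pad it."
        )
    return batch_size // num_devices


try:
    validate_batch_for_devices(batch_size=1, num_devices=2)
except ValueError as err:
    print(err)
```

Because a dataset whose batch size divides evenly can still end with a partial batch, such a check would have to run on every incoming batch, not only once on the configured batch size.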

Comment From: innat

@divyashreepathihalli That would be useful.

(However, inference in such cases would remain an issue unless we unpack the data loader and pass samples one by one, which may limit the usefulness of multi-GPU inference. For example, with samples [1, 2, 3, 4, 5], 2 GPUs, and a batch size of 2, it would be nice if sample no. 5 were assigned to any available GPU. Just a thought; this would require some work, btw.)
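For the inference case, one workaround that keeps multi-GPU execution is to pad each short batch up to a multiple of the device count, run the model on the padded batch, and slice the padded rows off the outputs. A minimal NumPy sketch, assuming a generic predict_fn callable that returns one output row per input row (a hypothetical stand-in for something like model.predict_on_batch):

```python
import numpy as np


def predict_padded(predict_fn, x, batch_size, num_devices):
    """Run predict_fn on every batch, padding short batches to a device multiple.

    Assumes batch_size is a multiple of num_devices and predict_fn returns one
    output row per input row. Padded rows are discarded, so no sample is lost.
    """
    outputs = []
    for start in range(0, len(x), batch_size):
        batch = x[start:start + batch_size]
        n = len(batch)
        pad = (-n) % num_devices  # rows needed to reach a multiple of num_devices
        if pad:
            # Repeat the last real row; its predictions are sliced off below.
            batch = np.concatenate([batch, np.repeat(batch[-1:], pad, axis=0)])
        outputs.append(predict_fn(batch)[:n])
    return np.concatenate(outputs, axis=0)


def fake_predict(batch):
    # Stand-in model that enforces the 2-device divisibility constraint.
    assert len(batch) % 2 == 0
    return batch * 10


x = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)
print(predict_padded(fake_predict, x, batch_size=2, num_devices=2).ravel())
# [10. 20. 30. 40. 50.]
```

The cost is a little wasted compute on the repeated rows of the last batch, in exchange for never dropping a sample.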