In TensorFlow, we can do:


a = tf.ones((1, 96, 96, 96, 4))  # layout: (B, D, H, W, C)
p = 4
# ksizes and strides are identical, so the patches tile the volume without overlap.
patches = tf.extract_volume_patches(
    input=a,
    ksizes=[1] + [p] * 3 + [1],
    strides=[1] + [p] * 3 + [1],
    padding="VALID",
)

print(patches.shape)
# -> (1, 24, 24, 24, 256)

This functionality is missing from `keras.ops.image.extract_patches`.


Here is a workaround, for anyone looking for one (non-overlapping patches only, i.e. strides equal to the patch size):

def extract_volume_patches_simple(x, patch_size):
    """Split a 5D volume into non-overlapping 3D patches.

    Keras-ops counterpart of::

        tf.extract_volume_patches(x, ksizes=[1, pd, ph, pw, 1],
                                  strides=[1, pd, ph, pw, 1], padding='VALID')

    Args:
        x: Tensor laid out as (batch, depth, height, width, channels).
        patch_size: ``(pd, ph, pw)`` patch extent along depth, height and
            width, or a single int used for all three axes.

    Returns:
        Tensor of shape
        ``(batch, depth // pd, height // ph, width // pw, pd * ph * pw * channels)``.
        Each patch is flattened in (pd, ph, pw, channels) order. Voxels past
        the last whole patch along an axis are dropped, as with
        ``padding='VALID'``.
    """
    if isinstance(patch_size, int):
        patch_size = (patch_size,) * 3
    pd, ph, pw = patch_size[0], patch_size[1], patch_size[2]

    batch_size, depth, height, width, channels = ops.shape(x)

    # Number of whole patches along each spatial axis.
    d_patches = depth // pd
    h_patches = height // ph
    w_patches = width // pw

    # Crop trailing voxels that do not fill a whole patch (VALID padding);
    # otherwise the reshape below fails for non-divisible sizes.
    if depth % pd or height % ph or width % pw:
        x = x[:, : d_patches * pd, : h_patches * ph, : w_patches * pw, :]

    # A symbolic (unknown) batch size is reported as None; let reshape infer it.
    batch = -1 if batch_size is None else batch_size

    # Split every spatial axis into (num_patches, patch_extent).
    patches = ops.reshape(
        x,
        (batch, d_patches, pd, h_patches, ph, w_patches, pw, channels),
    )

    # Move the patch-grid axes to the front and the intra-patch axes to the
    # back, so each patch's voxels are contiguous before flattening. Without
    # this transpose the final reshape mixes voxels from neighbouring patches.
    patches = ops.transpose(patches, axes=(0, 1, 3, 5, 2, 4, 6, 7))

    return ops.reshape(
        patches,
        (batch, d_patches, h_patches, w_patches, pd * ph * pw * channels),
    )

# Smoke test: shape check only. An all-ones input cannot reveal whether the
# voxels inside each patch are ordered correctly.
p = 4
a = ops.ones((1, 96, 96, 96, 4))
patches = extract_volume_patches_simple(a, (p,) * 3)
print(patches.shape)
# -> (1, 24, 24, 24, 256)