In TensorFlow, non-overlapping 3D patches can be extracted from a volume with `tf.extract_volume_patches`:
a = tf.ones((1, 96, 96, 96, 4))  # 5-D input laid out as (B, D, H, W, C), with B = 1
p = 4
# ksizes == strides -> non-overlapping p x p x p patches; each patch is
# flattened into p * p * p * C = 4 * 4 * 4 * 4 = 256 values.
patches = tf.extract_volume_patches(
    input=a,
    ksizes=[1, p, p, p, 1],
    strides=[1, p, p, p, 1],
    padding='VALID'
)
print(patches.shape)
# (1, 24, 24, 24, 256)
This feature is missing from `keras.ops.image.extract_patches`, which does not support volumetric (5-D) inputs.
Here is a workaround for anyone who needs it (non-overlapping patches only, i.e. strides equal to the patch size):
def extract_volume_patches_simple(x, patch_size):
    """Split a 5-D volume into non-overlapping 3-D patches.

    Intended as a ``keras.ops`` stand-in for ``tf.extract_volume_patches``
    with ``ksizes == strides`` and ``padding='VALID'``. Trailing voxels that
    do not fill a whole patch are dropped. Each patch is flattened in
    (depth, height, width, channel) order.

    Args:
        x: Tensor of shape ``(B, D, H, W, C)``.
        patch_size: Three positive ints ``(pd, ph, pw)``, or a single int
            used for all three spatial axes.

    Returns:
        Tensor of shape ``(B, D // pd, H // ph, W // pw, pd * ph * pw * C)``.

    Raises:
        ValueError: If ``patch_size`` does not describe three positive sizes.
    """
    if isinstance(patch_size, int):
        patch_size = (patch_size,) * 3
    patch_size = tuple(patch_size)
    if len(patch_size) != 3 or any(s <= 0 for s in patch_size):
        raise ValueError(
            f"patch_size must be three positive ints, got {patch_size!r}"
        )
    pd, ph, pw = patch_size

    _, depth, height, width, channels = ops.shape(x)
    d_patches = depth // pd
    h_patches = height // ph
    w_patches = width // pw

    # Drop the remainder, as padding='VALID' does. Without this, the reshape
    # below fails whenever a spatial size is not a multiple of the patch size.
    x = x[:, : d_patches * pd, : h_patches * ph, : w_patches * pw, :]

    # Split each spatial axis into (num_patches, patch_extent).
    # -1 lets the batch size be inferred, so it need not be known statically.
    patches = ops.reshape(
        x, [-1, d_patches, pd, h_patches, ph, w_patches, pw, channels]
    )
    # (B, d, pd, h, ph, w, pw, C) -> (B, d, h, w, pd, ph, pw, C).
    # Without this transpose the final reshape still yields the right shape,
    # but each "patch" would mix voxels from neighbouring patches.
    patches = ops.transpose(patches, [0, 1, 3, 5, 2, 4, 6, 7])

    # Flatten the within-patch axes and channels into one feature axis.
    return ops.reshape(
        patches,
        [-1, d_patches, h_patches, w_patches, pd * ph * pw * channels],
    )
# Test
# NOTE: an all-ones input only verifies the output shape; it cannot reveal
# whether voxels end up in the correct patch. Use e.g. an arange-filled input
# and compare a patch against the matching slice of `a` to check contents.
a = ops.ones((1, 96, 96, 96, 4))
p = 4
patches = extract_volume_patches_simple(a, (p, p, p))
print(patches.shape)
# (1, 24, 24, 24, 256)