I am encountering a bug with `ops.pad` and the TensorFlow backend (I am using Keras 3.11.1 and TensorFlow 2.19.0, currently running on CPU).

Below is an MWE in which I define a layer that applies zero padding to an input. Afterwards I compare the sum of the input with the sum of the output using `ops.sum`. Since the layer pads with zeros, the sums should be exactly the same. With the TensorFlow and JAX backends, however, I regularly get a deviation of up to ~1e-3.

In my experience, this is only reproducible when using `ops.pad` within a `layers.Layer`, so it may have something to do with graph execution?

With the following code, I am able to reproduce the problem (you may have to execute it multiple times to trigger the misbehavior):
```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

from keras import ops
from keras import layers
import numpy as np


class PadArray(layers.Layer):
    def __init__(self, pad_width, **kwargs):
        super().__init__(**kwargs)
        self.pad_width = pad_width

    def call(self, inputs):
        return ops.pad(inputs, pad_width=self.pad_width, mode="constant", constant_values=0)

    def build(self, input_shape):
        super().build(input_shape=input_shape)


# get layer, get some random inputs and call layer
layer = PadArray(pad_width=((32, 11), (44, 16), (11, 51)))
x = np.random.random(size=(64, 64, 64))
x = ops.cast(x, dtype="float32")
layer.build(x.shape)
y = layer(x)

# compare sum of x and y
deviation = ops.sum(x) - ops.sum(y)
assert ops.sum(x) == ops.sum(y), f"Deviation is {deviation}"
```
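The size of the deviation is consistent with non-associative float32 addition: if the compiled graph reduces the padded tensor in a different order than the reduction over the input, the two sums can legitimately differ in the last digits. A backend-free sketch of the effect (pure NumPy, no Keras involved):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(64 * 64 * 64).astype(np.float32)

# np.sum uses pairwise summation; a plain loop adds left-to-right.
# float32 addition is not associative, so the two orders generally
# disagree in the last few digits even though both are "correct".
pairwise = float(x.sum())
naive = np.float32(0.0)
for v in x:
    naive += v

print(abs(pairwise - float(naive)))  # small, but usually not 0.0
```

This would suggest the padded values themselves are fine and the mismatch comes from the order-sensitive reduction, not from `ops.pad`.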
Comment From: MathiesW
I just saw that Keras officially implements `ZeroPadding(n)D` layers. Here I encounter the same problem: the sum of the tensor changes when applying the zero padding, suggesting that the layer either does not pad with zeros or modifies nonzero entries of the initial tensor.
```python
import numpy as np
from keras import layers, ops

pad3d = layers.ZeroPadding3D(padding=((32, 11), (44, 16), (11, 51)))
x = np.random.random(size=(1, 64, 64, 64, 1))
x = ops.cast(x, dtype="float32")
pad3d.build(x.shape)
y = pad3d(x)

deviation = ops.sum(x) - ops.sum(y)
assert deviation == 0.0, f"Deviation is {deviation}"
```
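The "does not pad with zeros / modifies entries" hypothesis can be ruled out by comparing elementwise instead of via sums. As a sketch with `np.pad` (whose constant-mode semantics `ops.pad` mirrors), the interior of the padded array is bit-exact and the border is exactly zero; only the order-sensitive sum differs:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.random((64, 64, 64)).astype(np.float32)
pad_width = ((32, 11), (44, 16), (11, 51))

y = np.pad(x, pad_width, mode="constant", constant_values=0)

# the interior of the padded array is bit-exactly the original data
interior = y[32:32 + 64, 44:44 + 64, 11:11 + 64]
assert np.array_equal(interior, x)

# everything outside the interior is exactly 0.0
mask = np.ones(y.shape, dtype=bool)
mask[32:32 + 64, 44:44 + 64, 11:11 + 64] = False
assert np.all(y[mask] == 0.0)
```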
Comment From: mehtamansi29
Hi @MathiesW - here you are comparing the values directly for exact equality, without any tolerance. To compare the sums of x and y, you can use `np.testing.assert_allclose`, which passes the assertion on the TensorFlow and JAX backends even with a deviation of up to 1e-3.

Attached is a gist where the code runs fine with `np.testing.assert_allclose` for the TensorFlow and JAX backends.
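A minimal, Keras-free illustration of the difference (the two sums here are hypothetical float32 values that differ only by an accumulation error of the magnitude reported above):

```python
import numpy as np

# hypothetical float32 sums that differ only by accumulation error
sum_x = np.float32(1000.0011)
sum_y = np.float32(1000.0019)

# an exact comparison trips on the last bits...
assert sum_x != sum_y
# ...while a tolerance-based comparison passes
np.testing.assert_allclose(sum_x, sum_y, atol=1e-3)
```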
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Comment From: github-actions[bot]
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.