Hello
I am encountering a problem with `keras.layers.Reshape`. I am currently running Keras 3.9, but since the related code did not change in 3.10, the problem should be the same there.
Contrary to my expectation, and to the behavior I was accustomed to from Keras 2, the Keras 3 version of the reshape operation statically fixes all dimensions except the batch dimension. The current version of the code is here: https://github.com/keras-team/keras/blob/3bedb9a970394879360fcb1c0264f3ffdc634a77/keras/src/layers/reshaping/reshape.py#L56
```python
def build(self, input_shape):
    sample_output_shape = operation_utils.compute_reshape_output_shape(
        input_shape[1:], self.target_shape, "target_shape"
    )
    self._resolved_target_shape = tuple(
        -1 if d is None else d for d in sample_output_shape
    )

def call(self, inputs):
    return ops.reshape(
        inputs, (ops.shape(inputs)[0],) + self._resolved_target_shape
    )
```
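To make the consequence of this code concrete, here is a sketch in plain NumPy, assuming `np.reshape` treats -1 the same way `ops.reshape` does. The hypothetical helper `reshape_with_resolved` mimics the shape that `call` passes to the reshape (the batch size followed by the resolved target shape), and the arrays stand in for the `(batch, length, 4)` output of a 4-filter convolution.

```python
import numpy as np


def reshape_with_resolved(x, resolved_target_shape):
    # Hypothetical helper mirroring Reshape.call: keep the batch dimension,
    # then append the (possibly partially resolved) target shape.
    return np.reshape(x, (x.shape[0],) + tuple(resolved_target_shape))


x6 = np.ones((1, 6, 4), dtype="float32")
x10 = np.ones((1, 10, 4), dtype="float32")

# Target shape resolved statically from a length-6 input.
static = (3, 8)
print(reshape_with_resolved(x6, static).shape)  # (1, 3, 8)
try:
    reshape_with_resolved(x10, static)
except ValueError as err:
    print("static target shape fails for length 10:", err)

# Keeping -1 lets the reshape resolve the length at run time.
dynamic = (-1, 8)
print(reshape_with_resolved(x10, dynamic).shape)  # (1, 5, 8)
```

Once the -1 has been replaced by a concrete 3, any input with a different length can no longer be reshaped.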
Clearly, this reduces the possible uses of the `Reshape` layer compared to the `ops.reshape` operator, which can handle a dimension of -1 without problem. It appears straightforward to extend the `Reshape` layer so that it preserves the flexibility of `ops.reshape`. Is there any reason why the shape is kept static? If not, I can provide a suggestion for a fix.
Thanks for your comments.
Comment From: hertschuh
@roebel
Thank you for the report.
Can you provide an example of the issue you're encountering? I would have expected the PR to include a unit test that fails before the change and passes after it.
> Clearly, this will reduce the possible uses of the reshape layer compared to the ops.reshape operator, which can handle a dimension with -1 without problem
The `Reshape` layer can handle -1 with both static and dynamic dimensions (and a dynamic batch size). Instead of redoing the work of calculating the dynamically resolved shape, it just lets `ops.reshape` handle it; there's no need to duplicate that work. It does so by having `self._resolved_target_shape` contain a -1 when that dimension could not be statically resolved.
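The resolution described above can be sketched in plain Python. `resolve_target_shape` is a hypothetical, simplified stand-in, not Keras's `compute_reshape_output_shape`: it returns -1 directly where the Keras code maps an unresolved `None` to -1, and it assumes the batch dimension has already been stripped, as `input_shape[1:]` does in `build`.

```python
import math
from typing import Optional, Sequence, Tuple


def resolve_target_shape(
    sample_input_shape: Sequence[Optional[int]],
    target_shape: Sequence[int],
) -> Tuple[int, ...]:
    """Hypothetical, simplified build-time resolution of a reshape target.

    Replaces the -1 in target_shape with a concrete size when every
    (non-batch) input dimension is known; otherwise keeps the -1 so the
    run-time reshape can resolve it.
    """
    target = tuple(target_shape)
    if target.count(-1) > 1:
        raise ValueError("target_shape may contain at most one -1")
    if -1 not in target or None in sample_input_shape:
        # Nothing to resolve, or not resolvable statically: keep the -1.
        return target
    total = math.prod(sample_input_shape)
    known = math.prod(d for d in target if d != -1)
    if total % known != 0:
        raise ValueError(f"cannot reshape {tuple(sample_input_shape)} into {target}")
    return tuple(total // known if d == -1 else d for d in target)


print(resolve_target_shape((6, 4), (-1, 8)))     # (3, 8)
print(resolve_target_shape((None, 4), (-1, 8)))  # (-1, 8)
```

A build with a concrete length bakes in the resolved size, while a build with `None` leaves the -1 for `ops.reshape`.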
Comment From: roebel
It appears I misinterpreted the current implementation, notably the name `self._resolved_target_shape`, which I understood to mean that the -1 is resolved in all cases. I now see that the -1 is only resolved when the input shape is not dynamic. This is indeed exactly my use case. Below is a small test script that shows the problem.
The script runs fine with Keras 2.14:
```bash
KERAS_INSTALL=system ./shell/manual_reshape_test.py
```
It fails with all backends under Keras 3.9:
```bash
for backend in torch tensorflow jax; do
    KERAS_INSTALL=system KERAS_BACKEND=$backend ./shell/manual_reshape_test.py
done
```
and it runs fine with the proposed fix for all backends:
```bash
for backend in torch tensorflow jax; do
    KERAS_INSTALL=local KERAS_BACKEND=$backend ./shell/manual_reshape_test.py
done
```
Here is the test script:
```python
import os
import sys
from pathlib import Path

import numpy as np

if os.environ.get("KERAS_INSTALL", "LOCAL").upper() == "SYSTEM":
    from keras import Model, layers, backend

    print(f"system wide keras layers implementation {layers.__file__}, backend {backend.backend()}")
else:
    # The path to your local 'src' directory
    local_src_path = str(Path(__file__).parents[1])
    sys.path.insert(0, local_src_path)
    from keras.src import Model, layers, backend

    print(f"local keras layers implementation {layers.__file__}, backend {backend.backend()}")


class MM(Model):
    def __init__(self):
        super().__init__()
        self.conv = layers.Conv1D(4, 3, padding="same")
        self.reshape = layers.Reshape((-1, 8))

    def call(self, inputs):
        return self.reshape(self.conv(inputs))


m = MM()
res = m(np.ones((1, 6, 2), dtype="float32"))
assert res.shape == (1, 3, 8), f"Expected shape (1, 3, 8), got {res.shape}"
res = m(np.ones((1, 10, 2), dtype="float32"))
assert res.shape == (1, 5, 8), f"Expected shape (1, 5, 8), got {res.shape}"
print("Custom reshape model test passed successfully.")
```
Comment From: hertschuh
@roebel
```python
m = MM()
res = m(np.ones((1, 6, 2), dtype="float32"))
assert res.shape == (1, 3, 8), f"Expected shape (1, 3, 8), got {res.shape}"
res = m(np.ones((1, 10, 2), dtype="float32"))
assert res.shape == (1, 5, 8), f"Expected shape (1, 5, 8), got {res.shape}"
```
What happens in this case, and why it fails without your fix, is the following:
- when `m(np.ones((1, 6, 2), dtype="float32"))` is called, it calls `build` with shape `(1, 6, 2)`;
- when `m(np.ones((1, 10, 2), dtype="float32"))` is called, it fails because the shape of the input doesn't match `(1, 6, 2)`, which was passed to `build`.
So, from that point of view, what's missing in `call` is a check that the input conforms to what was passed to `build`.
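Such a check could look roughly like the sketch below. `conforms_to_build_shape` is a hypothetical name, not an existing Keras API, and it assumes that a `None` in the build shape accepts any size. The shapes used are the ones the `Reshape` layer itself sees in this example, i.e. the `Conv1D` output with 4 channels rather than the model input with 2.

```python
from typing import Optional, Sequence


def conforms_to_build_shape(
    build_shape: Sequence[Optional[int]],
    input_shape: Sequence[Optional[int]],
) -> bool:
    """Hypothetical check that a call-time shape matches the build-time shape.

    A None in build_shape accepts any size; every other entry must match.
    """
    if len(build_shape) != len(input_shape):
        return False
    return all(b is None or b == i for b, i in zip(build_shape, input_shape))


# Built from the Conv1D output (1, 6, 4): a length-10 input does not conform.
print(conforms_to_build_shape((1, 6, 4), (1, 10, 4)))        # False
# Built with None for batch and length: any length conforms.
print(conforms_to_build_shape((None, None, 4), (1, 10, 4)))  # True
```

With such a check, the second call would fail with an explicit shape mismatch instead of an opaque reshape error.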
But then, what if you do want a dynamic dimension? In theory, you just need to build with some `None`s. But to make this work, `MM` needs to implement `build`:
```python
class MM(Model):
    def __init__(self):
        super().__init__()
        self.conv = layers.Conv1D(4, 3, padding="same")
        self.reshape = layers.Reshape((-1, 8))

    def build(self, input_shape):
        self.conv.build(input_shape)
        # Conv1D(4, ...) with "same" padding keeps the length and outputs 4 channels.
        self.reshape.build(input_shape[0:-1] + (4,))

    def call(self, inputs):
        return self.reshape(self.conv(inputs))


m = MM()
m.build((None, None, 2))
res = m(np.ones((1, 6, 2), dtype="float32"))
assert res.shape == (1, 3, 8), f"Expected shape (1, 3, 8), got {res.shape}"
res = m(np.ones((1, 10, 2), dtype="float32"))
assert res.shape == (1, 5, 8), f"Expected shape (1, 5, 8), got {res.shape}"
print("Custom reshape model test passed successfully.")
```
Now the test passes without your change. Two things though:
- In this example, there's a baked-in assumption that the second `None` represents an even number; otherwise it fails, which you can verify with `m(np.ones((1, 3, 2), dtype="float32"))`. But there's no way to represent that constraint.
- All these tests run eagerly, which is not how we really want to use this. That leads me to the next point.
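The even-length assumption in the first point comes down to simple arithmetic, sketched below. `reshape_rows` is a hypothetical helper; it assumes the default stride of 1, so that `Conv1D(4, 3, padding="same")` keeps the sequence length.

```python
def reshape_rows(length: int, channels: int = 4, last_dim: int = 8) -> int:
    """Rows produced by Reshape((-1, last_dim)) after a "same"-padded Conv1D.

    Conv1D(channels, 3, padding="same") with stride 1 keeps the sequence
    length, so each sample holds length * channels values, and the -1 can
    only be resolved if last_dim divides that total.
    """
    total = length * channels
    if total % last_dim != 0:
        raise ValueError(
            f"{total} values per sample cannot be split into rows of {last_dim}"
        )
    return total // last_dim


for length in (6, 10, 3):
    try:
        print(length, "->", reshape_rows(length))
    except ValueError as err:
        print(length, "->", err)
```

With 4 channels and a last dimension of 8, `length * 4` is divisible by 8 only for even lengths, which is why a length of 3 fails.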
Moving the logic from `build` to `call` seems reasonable. But keep in mind that things get compiled: `call` gets called with a particular shape, and from that point on the Python code is no longer called.
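The compilation caveat can be illustrated with a deliberately simplified toy model. This is not how TensorFlow, JAX or PyTorch actually implement tracing, and `toy_compile` and `python_body` are hypothetical names. The Python body runs once for a given traced input shape, and later calls with that shape replay the recorded result without executing the Python again. If the length is `None` at trace time, shape-dependent Python logic cannot resolve the -1 and has to leave it to the run-time reshape.

```python
from typing import Callable, Dict, List, Optional, Tuple

Shape = Tuple[Optional[int], ...]


def toy_compile(body: Callable[[Shape], Shape]) -> Callable[[Shape], Shape]:
    """Toy model of tracing: run `body` once per new input shape, then replay."""
    cache: Dict[Shape, Shape] = {}

    def compiled(shape: Shape) -> Shape:
        if shape not in cache:
            cache[shape] = body(shape)  # the Python code only runs here (the "trace")
        return cache[shape]

    return compiled


python_runs: List[Shape] = []


def python_body(shape: Shape) -> Shape:
    python_runs.append(shape)
    # Shape-dependent Python logic: resolve the -1 of (-1, 8) only if the length is known.
    length, channels = shape[1], shape[2]
    rows = -1 if length is None else length * channels // 8
    return (shape[0], rows, 8)


step = toy_compile(python_body)
print(step((None, None, 4)))  # (None, -1, 8): the length is unknown at trace time
print(step((None, None, 4)))  # same result, replayed without running python_body
print(len(python_runs))       # 1
```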
I'll comment on the PR.