A simple model like this:
inputs = keras.Input(shape=(sequence_length, raw_data.shape[-1]))
x = layers.Flatten()(inputs)
x = layers.Dense(16, activation="relu")(x)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
history = model.fit(
train_dataset,
epochs=10,
validation_data=val_dataset,
)
Produces this error with the TensorFlow backend:
Traceback (most recent call last):
File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 113, in error_handler
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 377, in fit
logs = self.train_function(iterator)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 220, in function
opt_outputs = multi_step_on_iterator(iterator)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
except TypeError as e:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node functional_1_1/flatten_1_1/Reshape defined at (most recent call last):
<stack traces unavailable>
only one input size may be -1, not both 0 and 1
Stack trace for op definition:
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/backend/tensorflow/trainer.py", line 377, in fit
File "keras/src/backend/tensorflow/trainer.py", line 220, in function
File "keras/src/backend/tensorflow/trainer.py", line 133, in multi_step_on_iterator
File "keras/src/backend/tensorflow/trainer.py", line 114, in one_step_on_data
File "keras/src/backend/tensorflow/trainer.py", line 58, in train_step
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/layers/layer.py", line 936, in __call__
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/ops/operation.py", line 76, in __call__
File "keras/src/models/functional.py", line 183, in call
File "keras/src/ops/function.py", line 177, in _run_through_graph
File "keras/src/models/functional.py", line 648, in call
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/layers/layer.py", line 936, in __call__
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/ops/operation.py", line 76, in __call__
File "keras/src/layers/reshaping/flatten.py", line 54, in call
File "keras/src/ops/numpy.py", line 5074, in reshape
File "keras/src/backend/tensorflow/numpy.py", line 2068, in reshape
[[{{node functional_1_1/flatten_1_1/Reshape}}]]
tf2xla conversion failed while converting __inference_one_step_on_data_2408[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
[[StatefulPartitionedCall]] [Op:__inference_multi_step_on_iterator_2449]
There is no such error if I replace the `Flatten` layer with `Reshape([-1])`.
There is also no such error with the JAX backend.