A simple model like this:

import keras
from keras import layers

# sequence_length, raw_data, train_dataset and val_dataset are defined elsewhere.
inputs = keras.Input(shape=(sequence_length, raw_data.shape[-1]))
x = layers.Flatten()(inputs)
x = layers.Dense(16, activation="relu")(x)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer="adam", loss="mse", metrics=["mae"])
history = model.fit(
    train_dataset,
    epochs=10,
    validation_data=val_dataset,
)

Produces this error with the TensorFlow backend:

Traceback (most recent call last):
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 113, in error_handler
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 377, in fit
    logs = self.train_function(iterator)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 220, in function
    opt_outputs = multi_step_on_iterator(iterator)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/EYMKzeu1A1WR-04db7vYi/lib/python3.11/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    except TypeError as e:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node functional_1_1/flatten_1_1/Reshape defined at (most recent call last):
<stack traces unavailable>
only one input size may be -1, not both 0 and 1

Stack trace for op definition: 
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/backend/tensorflow/trainer.py", line 377, in fit
File "keras/src/backend/tensorflow/trainer.py", line 220, in function
File "keras/src/backend/tensorflow/trainer.py", line 133, in multi_step_on_iterator
File "keras/src/backend/tensorflow/trainer.py", line 114, in one_step_on_data
File "keras/src/backend/tensorflow/trainer.py", line 58, in train_step
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/layers/layer.py", line 936, in __call__
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/ops/operation.py", line 76, in __call__
File "keras/src/models/functional.py", line 183, in call
File "keras/src/ops/function.py", line 177, in _run_through_graph
File "keras/src/models/functional.py", line 648, in call
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/layers/layer.py", line 936, in __call__
File "keras/src/utils/traceback_utils.py", line 113, in error_handler
File "keras/src/ops/operation.py", line 76, in __call__
File "keras/src/layers/reshaping/flatten.py", line 54, in call
File "keras/src/ops/numpy.py", line 5074, in reshape
File "keras/src/backend/tensorflow/numpy.py", line 2068, in reshape

     [[{{node functional_1_1/flatten_1_1/Reshape}}]]
    tf2xla conversion failed while converting __inference_one_step_on_data_2408[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
     [[StatefulPartitionedCall]] [Op:__inference_multi_step_on_iterator_2449]

There is no such error if I replace the Flatten with a Reshape([-1]), and no such error with the JAX backend either.
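A minimal sketch of that workaround, using synthetic data in the style of the reproducible example further down this thread. The only change from the failing model is layers.Reshape((-1,)) in place of layers.Flatten(). It assumes Keras 3 with TensorFlow installed; the sample count and the single epoch are arbitrary choices to keep the run short.

```python
import os

# Select the backend before importing keras; the bug was reported on TensorFlow.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import numpy as np
import keras
from keras import layers

sequence_length = 120
ncol_input_data = 14
num_samples = 300  # arbitrary, kept small for a quick run

rng = np.random.default_rng(0)
data = rng.normal(size=(num_samples, ncol_input_data)).astype("float32")
targets = rng.normal(size=(num_samples,)).astype("float32")

# Per the report, batches from this dataset have an unknown sequence axis.
train_dataset = keras.utils.timeseries_dataset_from_array(
    data=data,
    targets=targets,
    sequence_length=sequence_length,
    batch_size=16,
)

inputs = keras.Input(shape=(sequence_length, ncol_input_data))
# Workaround: Reshape((-1,)) instead of Flatten().
x = layers.Reshape((-1,))(inputs)
x = layers.Dense(16, activation="relu")(x)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer="adam", loss="mse", metrics=["mae"])
history = model.fit(train_dataset, epochs=1, verbose=0)
print(model.outputs[0].shape, history.history["loss"])
```

Reshape's target_shape excludes the batch axis, so -1 is requested for the feature axis only.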

Comment From: sonali-kumari1

Hi @t-kalinowski - The error traceback indicates that TensorFlow is encountering conflicting dimensions during the reshape operation. It seems Flatten() builds a target shape in which both dimension 0 and dimension 1 are -1 (left to be inferred), which TensorFlow rejects with this graph execution error. As you mentioned, replacing the Flatten() layer with a Reshape([-1]) works because it ensures that only one dimension is inferred. To help reproduce this issue, could you please provide the shape of your input data (raw_data) or share a sample dataset? Thanks!
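The constraint behind the message is general to reshape operations: at most one target dimension can be left as -1 for the library to infer. NumPy is used here only as an illustration, since it needs no graph tracing, but it enforces the same rule, so the failure mode can be reproduced in isolation. The array sizes mirror the repro's batch_size, sequence_length and column count.

```python
import numpy as np

# A batch shaped like the repro's inputs: (batch, sequence_length, ncol_input_data).
x = np.zeros((16, 120, 14), dtype="float32")

# One inferred dimension is fine: -1 resolves to 120 * 14 = 1680.
flat = x.reshape(16, -1)
print(flat.shape)  # (16, 1680)

# Two inferred dimensions are ambiguous, so the reshape is rejected.
rejected = False
try:
    x.reshape(-1, -1)
except ValueError as err:
    rejected = True
    print("ValueError:", err)
print("two -1 entries rejected:", rejected)
```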

Comment From: t-kalinowski

Here is a minimal reproducible example. The key requirement is that the second dataset axis (sequence_length) must be unknown at model-definition time. This only happens with the TensorFlow backend:

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers

sequence_length = 120
ncol_input_data = 14

num_samples = 1000
data = tf.random.normal((num_samples, ncol_input_data))
targets = tf.random.normal((num_samples,))

def make_dataset():
    return keras.utils.timeseries_dataset_from_array(
        data=data,
        targets=targets,
        sequence_length=sequence_length,
        batch_size=16,
    )

train_dataset = make_dataset()
val_dataset = make_dataset()
test_dataset = make_dataset()

inputs = keras.Input(shape=(sequence_length, ncol_input_data))
x = layers.Flatten()(inputs)
x = layers.Dense(16, activation="relu")(x)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer="adam", loss="mse", metrics=["mae"])
model.fit(train_dataset, epochs=10, validation_data=val_dataset)

print(f"Test MAE: {model.evaluate(test_dataset)[1]:.2f}")

This is the traceback I see:

$ /home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/bin/python3 -i  keras-flatten-bug.py 
I0000 00:00:1750105703.068555   37017 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9615 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:01:00.0, compute capability: 7.5
I0000 00:00:1750105703.069800   37017 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 668 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:03:00.0, compute capability: 7.5
Epoch 1/10
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1750105704.192654   37074 service.cc:152] XLA service 0x7051e8005040 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1750105704.192669   37074 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
I0000 00:00:1750105704.192674   37074 service.cc:160]   StreamExecutor device (1): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2025-06-16 16:28:24.207224: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-06-16 16:28:24.226577: W tensorflow/core/framework/op_kernel.cc:1857] OP_REQUIRES failed at xla_ops.cc:591 : INVALID_ARGUMENT: only one input size may be -1, not both 0 and 1

Stack trace for op definition: 
File "keras-flatten-bug.py", line 37, in <module>
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 377, in fit
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 220, in function
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 133, in multi_step_on_iterator
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 114, in one_step_on_data
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 58, in train_step
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/layers/layer.py", line 936, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/operation.py", line 58, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/models/functional.py", line 183, in call
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/function.py", line 177, in _run_through_graph
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/models/functional.py", line 648, in call
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/layers/layer.py", line 936, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/operation.py", line 58, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/layers/reshaping/flatten.py", line 54, in call
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/numpy.py", line 5074, in reshape
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/numpy.py", line 2068, in reshape

     [[{{node functional_1/flatten_1/Reshape}}]]
    tf2xla conversion failed while converting __inference_one_step_on_data_1227[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
2025-06-16 16:28:24.226621: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: INVALID_ARGUMENT: only one input size may be -1, not both 0 and 1

Traceback (most recent call last):
  File "/home/tomasz/keras-flatten-bug.py", line 37, in <module>
    model.fit(train_dataset, epochs=10, validation_data=val_dataset)
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node functional_1/flatten_1/Reshape defined at (most recent call last):
<stack traces unavailable>
only one input size may be -1, not both 0 and 1

Stack trace for op definition: 
File "keras-flatten-bug.py", line 37, in <module>
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 377, in fit
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 220, in function
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 133, in multi_step_on_iterator
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 114, in one_step_on_data
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py", line 58, in train_step
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/layers/layer.py", line 936, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/operation.py", line 58, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/models/functional.py", line 183, in call
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/function.py", line 177, in _run_through_graph
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/models/functional.py", line 648, in call
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/layers/layer.py", line 936, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/operation.py", line 58, in __call__
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/layers/reshaping/flatten.py", line 54, in call
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/ops/numpy.py", line 5074, in reshape
File ".cache/R/reticulate/uv/cache/archive-v0/UKrZHVgkU9dkAwai6kXUs/lib/python3.11/site-packages/keras/src/backend/tensorflow/numpy.py", line 2068, in reshape

     [[{{node functional_1/flatten_1/Reshape}}]]
    tf2xla conversion failed while converting __inference_one_step_on_data_1227[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
     [[StatefulPartitionedCall]] [Op:__inference_multi_step_on_iterator_1268]
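The key requirement above, an unknown second dataset axis, can be checked by printing the static shape that tf.data reports for each batch before the model is built. This sketch assumes TensorFlow and Keras 3 are installed and reuses the repro's sizes with NumPy data; element_spec is the standard tf.data.Dataset attribute, and an unknown axis is shown as None.

```python
import numpy as np
import keras

sequence_length = 120
ncol_input_data = 14
num_samples = 1000

rng = np.random.default_rng(0)
data = rng.normal(size=(num_samples, ncol_input_data)).astype("float32")
targets = rng.normal(size=(num_samples,)).astype("float32")

dataset = keras.utils.timeseries_dataset_from_array(
    data=data,
    targets=targets,
    sequence_length=sequence_length,
    batch_size=16,
)

# element_spec holds the static (trace-time) shapes of each (inputs, targets) batch.
inputs_spec, targets_spec = dataset.element_spec
print("inputs:", inputs_spec.shape)
print("targets:", targets_spec.shape)
```

If the printed input shape has None in position 1, the sequence axis is unknown when fit() traces the training step, which is the condition the report identifies as the trigger.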

This is the list of packages used to create the venv with uv:

"numpy","keras","pydot","scipy","pandas","Pillow","ipython","tensorflow[and-cuda]"

Comment From: sonali-kumari1

Hi @t-kalinowski - I have tested your code with the latest version of Keras (3.10.0) across all three backends (TensorFlow, JAX and Torch) in this gist file. Your code works fine with the JAX and Torch backends, but fails on the TensorFlow backend with the error:

InvalidArgumentError: Graph execution error: only one input size may be -1, not both 0 and 1

I verified the shapes explicitly, and they are consistent across all backends, so the error seems specific to how TensorFlow handles the reshape operation inside the Flatten() layer. We will look into this and update you. Thanks!
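To make the backend difference concrete, here is a pure-Python sketch of one hypothesis consistent with the error message. The helper flatten_target_shape is hypothetical and is not Keras's actual Flatten implementation: it writes -1 for every unknown entry of a static shape. The trace-time shapes are also assumptions: that the TensorFlow backend traces fit() with the dataset's partially unknown shape, while JAX traces with concrete array shapes. Under those assumptions, only the TensorFlow case yields a target with -1 at both index 0 and index 1.

```python
import math
from typing import List, Optional, Sequence, Tuple


def flatten_target_shape(static_shape: Sequence[Optional[int]]) -> Tuple[int, int]:
    """Hypothetical: build a (batch, features) reshape target from a static
    shape, writing -1 for each entry that is unknown (None)."""
    batch, *rest = static_shape
    features = -1 if any(d is None for d in rest) else math.prod(rest)
    return (-1 if batch is None else batch, features)


def inferred_axes(target: Sequence[int]) -> List[int]:
    """Indices of the target entries that reshape would have to infer."""
    return [i for i, d in enumerate(target) if d == -1]


# Assumed trace-time shapes for the repro (batch_size=16, 120 x 14 windows).
shapes = {
    "TensorFlow trace (unknown batch and sequence axes)": (None, None, 14),
    "JAX trace (concrete shapes)": (16, 120, 14),
    "Known sequence axis, unknown batch": (None, 120, 14),
}

for label, shape in shapes.items():
    target = flatten_target_shape(shape)
    axes = inferred_axes(target)
    status = "valid" if len(axes) <= 1 else f"invalid: -1 at {axes}"
    print(f"{label}: {shape} -> {target} ({status})")
```

The first case reproduces the "not both 0 and 1" pattern from the traceback, while pinning the sequence axis (third case) leaves a single inferred dimension.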