As per the title, compiling a Keras EfficientNetV2 model with the torch backend raises an InternalTorchDynamoError; see below.

Environment (running on Colab, with GPU): Keras 3.11.3, torch 2.8.0+cu126

Minimal reproducible example:

import os
os.environ["KERAS_BACKEND"] = 'torch'

import numpy as np
import torch
import keras
from keras import layers, models
from keras.applications import EfficientNetV2B2
from keras.optimizers import Adam
from keras.losses import CategoricalCrossentropy

print(f"Backend: {keras.config.backend()}")

num_classes = 10
batch_size = 16
steps_per_epoch = 5
epochs = 2

# Generate random data
data_shape = (224, 224, 3)
x_train = np.random.rand(batch_size * steps_per_epoch, *data_shape).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=(batch_size * steps_per_epoch,))
y_train = np.eye(num_classes)[y_train] 

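# EfficientNetV2-B2 backbone (no classification head), with a softmax head on top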
base_model = EfficientNetV2B2(include_top=False, input_shape=(None, None, 3), pooling='avg', include_preprocessing=True)
x = base_model.output
output = layers.Dense(num_classes, activation='softmax')(x)
model = models.Model(inputs=base_model.input, outputs=output)

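# jit_compile=True makes the torch backend run the train step through torch.compile (Dynamo);
# this is what triggers the error below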
model.compile(optimizer=Adam(learning_rate=0.001), loss=CategoricalCrossentropy(), metrics=['accuracy'], jit_compile=True)

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

Output:

Backend: torch
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/efficientnetv2-b2_notop.h5
35839040/35839040 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step
Epoch 1/2
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.is_leaf.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.flatten.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
W0909 14:47:48.027000 291 torch/_inductor/utils.py:1436] [102/0] Not enough SMs to use max_autotune_gemm mode
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8] torch._dynamo hit config.recompile_limit (8)
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8]    function: '__call__' (/usr/local/lib/python3.12/dist-packages/keras/src/layers/layer.py:816)
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8]    last reason: 3/7: expected type of 'args[0]' to be a tensor type, ' but found <class 'list'>
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0909 14:48:07.288000 291 torch/_dynamo/convert_frame.py:1016] [3/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
---------------------------------------------------------------------------
InternalTorchDynamoError                  Traceback (most recent call last)
/tmp/ipython-input-2007246960.py in <cell line: 0>()
     32 print("Model compiled!")
     33 
---> 34 model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

23 frames
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py in compute_exception_table(instructions)
    894 
    895     # Sort keys by increasing start, then decreasing end
--> 896     keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1]))
    897     # smallest byte that the next exception table entry can start at
    898     nexti = 0

InternalTorchDynamoError: TypeError: '<' not supported between instances of 'NoneType' and 'int'

from user code:
   File "/usr/local/lib/python3.12/dist-packages/keras/src/trainers/compile_utils.py", line 693, in call
    if not tree.is_nested(y_true) and not tree.is_nested(y_pred):

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Interestingly enough, switching to the CPU runtime (same library versions) changes the error:

Backend: torch
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/efficientnetv2-b2_notop.h5
35839040/35839040 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Epoch 1/2
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.is_leaf.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1464: UserWarning: Dynamo cannot trace optree C/C++ function optree._C.pybind11_detail_function_record_v1_system_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_1.flatten.
 Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
W0909 14:55:17.591000 270 torch/utils/cpp_extension.py:118] [93/0] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-360142954.py in <cell line: 0>()
     30 model.compile(optimizer=Adam(learning_rate=0.001), loss=CategoricalCrossentropy(), metrics=['accuracy'], jit_compile=True)
     31 
---> 32 model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

38 frames
/usr/local/lib/python3.12/dist-packages/sympy/core/relational.py in __bool__(self)
    514 
    515     def __bool__(self):
--> 516         raise TypeError("cannot determine truth value of Relational")
    517 
    518     def _eval_as_set(self):

RuntimeError: Exception encountered when calling Conv2D.call().

TypeError: cannot determine truth value of Relational

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


Arguments received by Conv2D.call():
  • inputs=torch.Tensor(shape=torch.Size([16, 112, 112, 32]), dtype=float32)

What I tried (and it didn't work):
- A fixed input_shape (instead of Nones) on the model (see the sketch below).
- Different versions of EfficientNetV2.
- Using torch.nn losses.
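For completeness, here is a minimal sketch of the fixed-shape variant I tried (assuming the rest of the repro above stays unchanged); it still fails with the same errors:

base_model = EfficientNetV2B2(
    include_top=False,
    input_shape=(224, 224, 3),  # fully specified shape instead of (None, None, 3)
    pooling='avg',
    include_preprocessing=True,
)
x = base_model.output
output = layers.Dense(num_classes, activation='softmax')(x)
model = models.Model(inputs=base_model.input, outputs=output)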

Comment From: monicadsong

Hi Davide--

The Keras team would like to learn more about your Keras usage. Please thumbs-up this comment if you are open to speaking with us, and we'll reach out to you via the LinkedIn account in your GitHub profile.

Thanks! Monica (monicadsong@google.com)

Comment From: dhantule

Hi @Doch88, Thanks for reporting this.

I've tested your code and it works fine when jit_compile is disabled; when it's enabled, I was able to reproduce the above errors in this gist. We'll look into this issue and update you.
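In the meantime, a minimal sketch of the workaround (same repro as above, just without JIT compilation):

model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=CategoricalCrossentropy(),
    metrics=['accuracy'],
    jit_compile=False,  # eager torch execution, no torch.compile / Dynamo tracing
)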

Comment From: Doch88

Thank you @dhantule 🙏

Comment From: SamanehSaadat

Have you tried using other backends, like JAX? Would the JAX backend work for your use case?

Comment From: Doch88

Hi @SamanehSaadat, no unfortunately that's not a possibility for us 🙁

Comment From: Doch88

Hi @james77777778, I see you made a PR related to this issue. Did it fix the problem?

Comment From: james77777778

@Doch88

I would say the issue is partially fixed, especially for your use case: https://colab.research.google.com/drive/1ZROnvDmLnUbg2NeIdlktskuEVcJtAhPb?usp=sharing

However, torch 2.8 doesn't seem to be fully compatible with Keras. Many unit tests failed when I upgraded the test suite in that PR.