Describe the bug
The arithmetic operation A+B, where both A and B are int64
Keras Variables (tested on TensorFlow), unexpectedly results in a tf.int32
output. This is a consequence of an explicit dtype promotion policy in:
https://github.com/keras-team/keras/blob/5dbdf60fab00c9b4592b16ec6dbc31c9c3ebf394/keras/src/backend/common/dtypes.py#L280-L283
This policy downcasts the result dtype from int64 to int32 in order to align Keras's integer precision with the default float32 floating-point precision, which leads to surprising dtype behavior.
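As a minimal, hypothetical model of that policy (the real logic lives in the linked dtypes.py; the names below are illustrations, not Keras APIs), the promoted dtype is computed first and 64-bit integer results are then remapped to 32 bits whenever the default float dtype is float32:

```python
import numpy as np

# Hypothetical sketch of the downcasting step, for illustration only;
# result_type_sketch and DOWNCAST are not actual Keras symbols.
DOWNCAST = {"int64": "int32", "uint64": "uint32"}

def result_type_sketch(dtype1, dtype2, floatx="float32"):
    # NumPy-style promotion first
    promoted = np.result_type(dtype1, dtype2).name
    if floatx == "float32":
        # align integer width with the 32-bit default float precision
        return DOWNCAST.get(promoted, promoted)
    return promoted

print(result_type_sketch("int64", "int64"))                    # int32
print(result_type_sketch("int64", "int64", floatx="float64"))  # int64
```

Under this model, the downcast disappears as soon as the default float dtype is raised to float64, which matches the stated motivation of the policy.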
This behavior was introduced in: https://github.com/keras-team/keras/pull/21604
To Reproduce
import numpy as np
import tensorflow as tf
import keras

rng = np.random.default_rng()

# 1. A 4-D NumPy array; only its shape matters for this reproduction
x = rng.standard_normal((2, 3, 4, 5))

# 2. An int64 Keras Variable of shape (3,)
indices = keras.Variable(
    initializer="ones",
    shape=(3,),
    dtype="int64",
    trainable=False,
)

# 3. Print the data types of 'x' and 'indices'
print(f"Data type of x: {x.dtype}")              # float64
print(f"Data type of indices: {indices.dtype}")  # int64

# 4. tf.shape returns an int32 tensor, so cast the dimension up to the
# variable's dtype (int64) before adding
dim2 = tf.cast(tf.shape(x)[2], indices.dtype)
print(f"tf.cast(tf.shape(x)[2], indices.dtype).dtype: {dim2.dtype.name}")  # int64

# 5. int64 + int64, yet Keras's dtype promotion downcasts the result to int32
final_result = indices + dim2
print(f"\n(indices + tf.cast(tf.shape(x)[2], indices.dtype)).dtype: "
      f"{final_result.dtype.name}")  # int32 on affected versions; int64 expected
Expected behavior
Adding two 64-bit integers should result in a 64-bit integer.
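For comparison, NumPy's promotion rules, which this report takes as the expected semantics, preserve the 64-bit width:

```python
import numpy as np

# NumPy keeps int64 + int64 at 64 bits
print(np.result_type(np.int64, np.int64))          # int64
print((np.ones(3, np.int64) + np.int64(4)).dtype)  # int64
```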
Additional context
Would you like to help us fix it?