Hi, I ran into an issue with BatchNormalization layers applied to concatenated embeddings:

When using an Embedding layer with the mask_zero=True keyword argument, the created masks lead to an exception downstream when the concatenated embeddings are passed into a BatchNormalization layer.

Please see the following MWE:

import keras
import numpy as np

rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20)) + 1
data_2 = rng.choice(2, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))

# mark the trailing timesteps as padding (0 is the masked value)
data_1[:, 15:] = 0
data_2[:, 15:] = 0

input_1 = keras.layers.Input((20,))
input_2 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=True)(input_2)

x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)

model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit([data_1, data_2], y, epochs=10)

After execution, the summary already shows an uncharacteristically large number of operations (significantly more than when the BatchNormalization layer is excluded), and running the script throws the following exception:

Epoch 1/10
Traceback (most recent call last):
  File "/home/stud2019/tharren/mwe.py", line 24, in <module>
    model.fit([data_1, data_2], y, epochs=10)
  File "/scratch/stud2019/tharren/anaconda/envs/naomiml_docking/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/scratch/stud2019/tharren/anaconda/envs/naomiml_docking/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
        ^^^^^^^^^^^
TypeError: Exception encountered when calling BroadcastTo.call().

Failed to convert elements of (None, 20, 3) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

Arguments received by BroadcastTo.call():
  • x=tf.Tensor(shape=(None, 20, 1), dtype=bool)

However, if the BatchNormalization layer is not included, training proceeds as expected.

Package information: keras 3.11.2, tensorflow 2.19.0, numpy 1.26.4

Comment From: avish006

Hello @tobiasharren, BatchNormalization does not actually support masking.

After this line: x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])

try adding: x = keras.layers.Lambda(lambda t: tf.stop_gradient(t), name="strip_mask")(x) (note that this requires import tensorflow as tf)

then keep the rest the same. This line essentially removes the mask before it reaches the BatchNormalization layer.