Hi, I ran into an issue with BatchNormalization layers applied to concatenated embeddings:

When using Embedding layers with the mask_zero=True keyword argument, the resulting masks cause an exception downstream when the concatenated embeddings are passed into a BatchNormalization layer.

Please see the following MWE:

import keras
import numpy as np

rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20)) + 1
data_2 = rng.choice(2, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))

# Zero out the last 5 of the 20 time steps so they act as padding.
data_1[:, 15:] = 0
data_2[:, 15:] = 0

input_1 = keras.layers.Input((20,))
input_2 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=True)(input_2)

x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)

model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit([data_1, data_2], y, epochs=10)

After execution, the summary already lists an unusually large number of operations (significantly more than when the BatchNormalization layer is excluded), and running the script raises the following exception:

Epoch 1/10
Traceback (most recent call last):
  File "/home/stud2019/tharren/mwe.py", line 24, in <module>
    model.fit([data_1, data_2], y, epochs=10)
  File "/scratch/stud2019/tharren/anaconda/envs/naomiml_docking/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/scratch/stud2019/tharren/anaconda/envs/naomiml_docking/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
        ^^^^^^^^^^^
TypeError: Exception encountered when calling BroadcastTo.call().

Failed to convert elements of (None, 20, 3) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

Arguments received by BroadcastTo.call():
  • x=tf.Tensor(shape=(None, 20, 1), dtype=bool)

However, if the BatchNormalization layer is not included, the training can proceed as expected.

Package versions: keras 3.11.2, tensorflow 2.19.0, numpy 1.26.4
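One plausible reading of the TypeError above (an assumption, not a confirmed diagnosis) is that a shape recorded while the functional graph was built, with None for the unknown batch dimension, is later used as a concrete broadcast target. The NumPy snippet below is only an analogy for that failure mode: broadcasting to a shape that contains None fails, while broadcasting to the shape of an actual input with the same layout works.

```python
import numpy as np

# Shape as recorded at graph-construction time: the batch size is unknown.
symbolic_shape = (None, 20, 3)

# Boolean mask with a trailing axis, like the (None, 20, 1) tensor in the
# traceback, but here with a concrete batch of 50.
mask = np.ones((50, 20, 1), dtype=bool)

try:
    np.broadcast_to(mask, symbolic_shape)
    error = None
except (TypeError, ValueError) as exc:
    # NumPy cannot use None as a dimension size.
    error = type(exc).__name__

# Broadcasting to the shape of a real input with the same layout succeeds.
embedded = np.zeros((50, 20, 3), dtype="float32")
ok = np.broadcast_to(mask, embedded.shape)
print(error, ok.shape)
```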

Comment From: avish006

Hello @tobiasharren, actually BatchNormalization does not support masking.

After this line: x = keras.layers.Concatenate(axis=-1)([embedding_1,embedding_2])

try adding: x = keras.layers.Lambda(lambda t: t, name="strip_mask")(x)

and keep the rest the same. This line essentially removes the mask before the BatchNormalization layer.

Comment From: tobiasharren

Dear @avish006 , thank you for your reply!

I believe BatchNormalization has in fact been updated to support masking. In https://github.com/keras-team/keras/blob/v3.11.2/keras/src/layers/normalization/batch_normalization.py#L11 the layer sets self.supports_masking = True and uses the mask in its call method.

This is also confirmed by the fact that the following code runs as expected:

import keras
import numpy as np

rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))

data_1[:, 15:] = 0  # the last 5 of the 20 steps become padding

input_1 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
x = keras.layers.BatchNormalization()(embedding_1)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit(data_1, y, epochs=10) 
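For reference, a mask-aware BatchNormalization should compute its batch statistics only over unmasked time steps, as in the single-embedding example above. The NumPy function below is a simplified sketch of that idea, not the Keras source; `masked_moments` is a hypothetical name, and the small `eps` is the usual guard against division by zero.

```python
import numpy as np


def masked_moments(x, mask, eps=1e-7):
    """Per-feature mean/variance over batch and time, ignoring masked steps.

    x:    float array of shape (batch, steps, features)
    mask: bool array of shape (batch, steps), True = keep
    """
    w = mask.astype(x.dtype)[..., None]              # (batch, steps, 1)
    total = w.sum(axis=(0, 1))                       # count of unmasked steps
    mean = (w * x).sum(axis=(0, 1)) / (total + eps)  # (features,)
    var = (w * (x - mean) ** 2).sum(axis=(0, 1)) / (total + eps)
    return mean, var


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 20, 3))
mask = np.ones((4, 20), dtype=bool)
mask[:, 15:] = False          # last 5 steps are padding
x[~mask] = 1e3                # padded values must not affect the statistics

mean, var = masked_moments(x, mask)
kept = x[mask]                # the 60 unmasked (step, feature) rows
print(np.allclose(mean, kept.mean(axis=0)), np.allclose(var, kept.var(axis=0)))
# True True
```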

On further investigation, the problem also occurs with masks created by the Masking layer:

import keras
import numpy as np

rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20))
data_1 = np.eye(9)[data_1]
data_2 = rng.choice(2, size=(50, 20))
data_2 = np.eye(2)[data_2]
y = rng.normal(0, 1, size=(50,))

# All-zero steps (the last 5 of 20) are masked by Masking(mask_value=0).
data_1[:, 15:, :] = 0
data_2[:, 15:, :] = 0

input_1 = keras.layers.Input((20, 9))
input_2 = keras.layers.Input((20, 2))
m_1 = keras.layers.Masking(mask_value=0)(input_1)
m_2 = keras.layers.Masking(mask_value=0)(input_2)

x = keras.layers.Concatenate(axis=-1)([m_1, m_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)

model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit([data_1, data_2], y, epochs=10)

However, the Concatenate layer also implements masking logic: https://github.com/keras-team/keras/blob/v3.11.2/keras/src/layers/merging/concatenate.py#L8
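As far as I can tell from concatenate.py, the merged mask is built by giving each input mask the full shape of its input (all True for unmasked inputs, expand plus broadcast for lower-rank masks), concatenating along the layer's axis, and reducing with `any` over the last axis. The NumPy sketch below approximates that reading; `merge_masks` is a hypothetical name, and the broadcast step is the one that appears as BroadcastTo in the traceback.

```python
import numpy as np


def merge_masks(inputs, masks, axis=-1):
    """Approximate Concatenate's mask merging for lists of inputs and masks."""
    if all(m is None for m in masks):
        return None
    expanded = []
    for x, m in zip(inputs, masks):
        if m is None:
            # Unmasked input: every position counts as valid.
            expanded.append(np.ones_like(x, dtype=bool))
        elif m.ndim < x.ndim:
            # (batch, steps) -> (batch, steps, 1) -> shape of x.
            # With concrete arrays, x.shape has no unknown dimensions.
            expanded.append(np.broadcast_to(m[..., None], x.shape))
        else:
            expanded.append(m)
    merged = np.concatenate(expanded, axis=axis)
    # Collapse the feature axis so downstream layers see a (batch, steps) mask.
    return merged.any(axis=-1)


emb_1 = np.zeros((2, 5, 3))
emb_2 = np.zeros((2, 5, 2))
mask_1 = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], dtype=bool)
mask_2 = np.array([[1, 1, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=bool)

merged = merge_masks([emb_1, emb_2], [mask_1, mask_2])
print(merged.shape)  # (2, 5)
```

For concatenation along the feature axis this amounts to a logical OR of the input masks: a step is masked only if it is padding in every input.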

Comment From: mattdangerw

Slightly more minimal repro (without compilation); this happens on any backend:

import keras
import numpy as np

rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20)) + 1
data_2 = rng.choice(2, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))

# Zero out the last 5 of the 20 time steps so they act as padding.
data_1[:, 15:] = 0
data_2[:, 15:] = 0

input_1 = keras.layers.Input((20,))
input_2 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=True)(input_2)

x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)

model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model([data_1, data_2])

It looks like the issue starts with commit https://github.com/keras-team/keras/commit/603db8003caee499b3f0e98fc08716dd93076734, which adds the broadcast_to call to Concatenate. @hertschuh, any ideas?
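One direction that might avoid passing an explicit shape (a hypothetical sketch, not the change actually made in Keras) is to expand the mask by combining it elementwise with an all-True array shaped like the input. Broadcasting then happens inside the operation, so no shape tuple with a possibly unknown batch dimension is needed. The NumPy helper `expand_mask_like` is hypothetical; in keras.ops the analogue would presumably use `ops.logical_and`, `ops.expand_dims` and `ops.ones_like`, but whether that sidesteps the failure in the functional graph is an untested assumption.

```python
import numpy as np


def expand_mask_like(mask, x):
    """Expand a (batch, steps) mask to x's shape without spelling out a shape.

    np.ones_like takes its shape from x itself, and `&` broadcasts
    (batch, steps, 1) against (batch, steps, features).
    """
    return np.expand_dims(mask, axis=-1) & np.ones_like(x, dtype=bool)


x = np.zeros((50, 20, 3), dtype="float32")
mask = np.ones((50, 20), dtype=bool)
mask[:, 15:] = False

expanded = expand_mask_like(mask, x)
reference = np.broadcast_to(mask[..., None], x.shape)
print(expanded.shape, np.array_equal(expanded, reference))  # (50, 20, 3) True
```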