Hi, I ran into an issue with BatchNormalization layers applied to concatenated embeddings:
When using an Embedding layer with the mask_zero=True keyword argument, the created masks lead to an exception downstream when the concatenated embeddings are passed into a BatchNormalization layer.
Please see the following MWE:
import keras
import numpy as np
rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20)) + 1
data_2 = rng.choice(2, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))
data_1[:, 15:] = 0  # zero the tail so those positions are masked (sequence length is 20)
data_2[:, 15:] = 0
input_1 = keras.layers.Input((20,))
input_2 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=True)(input_2)
x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit([data_1, data_2], y, epochs=10)
After execution, the summary already shows uncharacteristically many operations (significantly more than when the BatchNormalization layer is excluded), and running the script throws the following exception:
Epoch 1/10
Traceback (most recent call last):
  File "/home/stud2019/tharren/mwe.py", line 24, in <module>
    model.fit([data_1, data_2], y, epochs=10)
  File "/scratch/stud2019/tharren/anaconda/envs/naomiml_docking/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
TypeError: Exception encountered when calling BroadcastTo.call().

Failed to convert elements of (None, 20, 3) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

Arguments received by BroadcastTo.call():
  • x=tf.Tensor(shape=(None, 20, 1), dtype=bool)
However, if the BatchNormalization layer is not included, training proceeds as expected.
Package information: keras 3.11.2, tensorflow 2.19.0, numpy 1.26.4
Comment From: avish006
Hello @tobiasharren, actually BatchNormalization does not support masking.
After this line: x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
try adding: x = keras.layers.Lambda(lambda t: tf.stop_gradient(t), name="strip_mask")(x) (this requires import tensorflow as tf),
then keep the rest the same. This line essentially removes the mask before the BatchNormalization layer.
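For reference, a minimal sketch of this suggested workaround applied to the MWE above (assuming the TensorFlow backend, since tf.stop_gradient is used; "strip_mask" is just a label):
import keras
import tensorflow as tf  # tf.stop_gradient is used inside the Lambda
input_1 = keras.layers.Input((20,))
input_2 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=True)(input_2)
x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
# Lambda does not propagate masks by default, so the mask is dropped here and
# BatchNormalization below no longer receives one; tf.stop_gradient is an
# identity in the forward pass.
x = keras.layers.Lambda(lambda t: tf.stop_gradient(t), name="strip_mask")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])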
Comment From: tobiasharren
Dear @avish006, thank you for your reply!
I believe BatchNormalization has actually been updated to support masking.
Looking at https://github.com/keras-team/keras/blob/v3.11.2/keras/src/layers/normalization/batch_normalization.py#L11,
the layer sets self.supports_masking = True and implements mask handling in its call() method.
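This can be checked directly from the public attribute (a quick sanity check):
import keras
bn = keras.layers.BatchNormalization()
print(bn.supports_masking)  # True in Keras 3.11.2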
This is also confirmed by the fact that the following code does execute as expected:
import keras
import numpy as np
rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))
data_1[:, 15:] = 0  # zero the tail so those positions are masked (sequence length is 20)
input_1 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
x = keras.layers.BatchNormalization()(embedding_1)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit(data_1, y, epochs=10)
On further inspection, the problem also occurs with masks created by the Masking layer:
import keras
import numpy as np
rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20))
data_1 = np.eye(9)[data_1]
data_2 = rng.choice(2, size=(50, 20))
data_2 = np.eye(2)[data_2]
y = rng.normal(0, 1, size=(50,))
data_1[:, 15:, :] = 0  # zero the tail so the Masking layer masks those timesteps
data_2[:, 15:, :] = 0
input_1 = keras.layers.Input((20, 9))
input_2 = keras.layers.Input((20, 2))
m_1 = keras.layers.Masking(mask_value=0)(input_1)
m_2 = keras.layers.Masking(mask_value=0)(input_2)
x = keras.layers.Concatenate(axis=-1)([m_1, m_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit([data_1, data_2], y, epochs=10)
However, the Concatenate layer also implements masking logic:
https://github.com/keras-team/keras/blob/v3.11.2/keras/src/layers/merging/concatenate.py#L8
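For reference, the masks coming out of the Embedding layers are rank-2 boolean tensors of shape (batch, timesteps); the (None, 20, 1) bool tensor in the traceback looks like such a mask after an expand_dims, which is then broadcast to the (None, 20, 3) shape of the first embedding. A quick eager check of the mask shapes, using only the public compute_mask API:
import keras
import numpy as np
x1 = np.array([[1, 2, 0]])
x2 = np.array([[3, 0, 0]])
# Embedding.compute_mask returns a boolean (batch, timesteps) tensor
m1 = keras.layers.Embedding(10, 3, mask_zero=True).compute_mask(x1)
m2 = keras.layers.Embedding(10, 2, mask_zero=True).compute_mask(x2)
print(m1.shape, m2.shape)  # (1, 3) (1, 3)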
Comment From: mattdangerw
Slightly more minimal repro (without compilation); happens on any backend:
import keras
import numpy as np
rng = np.random.default_rng(123)
data_1 = rng.choice(9, size=(50, 20)) + 1
data_2 = rng.choice(2, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))
data_1[:, 15:] = 0  # zero the tail so those positions are masked (sequence length is 20)
data_2[:, 15:] = 0
input_1 = keras.layers.Input((20,))
input_2 = keras.layers.Input((20,))
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=True)(input_2)
x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model([data_1, data_2])
It looks like the issue starts with this commit: https://github.com/keras-team/keras/commit/603db8003caee499b3f0e98fc08716dd93076734, which adds the broadcast_to call to Concatenate. @hertschuh, any ideas?
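For what it's worth, the low-level failure is easy to reproduce in isolation with the TensorFlow backend: broadcasting to a target shape that contains a symbolic None dimension raises exactly this TypeError (a minimal illustration of the suspected failure mode, not the exact Keras code path):
import tensorflow as tf
mask = tf.ones((2, 20, 1), dtype=tf.bool)
tf.broadcast_to(mask, (2, 20, 3))  # a fully static target shape works
# A target shape containing None (as taken from a symbolic tensor's shape)
# raises: TypeError: Failed to convert elements of (None, 20, 3) to Tensor.
tf.broadcast_to(mask, (None, 20, 3))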