Hi, I ran into an issue with BatchNormalization layers applied to concatenated embeddings: when an Embedding layer is used with the mask_zero=True keyword argument, the created masks lead to an exception downstream once the concatenated embeddings are passed into a BatchNormalization layer.
Please see the following MWE:
import keras
import numpy as np
rng = np.random.default_rng(123)
# Token ids in [1, 9] and [1, 2]; index 0 is reserved as the padding value.
data_1 = rng.choice(9, size=(50, 20)) + 1
data_2 = rng.choice(2, size=(50, 20)) + 1
y = rng.normal(0, 1, size=(50,))
# Zero out the last five positions of each length-20 sequence so the mask
# actually has padding to mask.
data_1[:, 15:] = 0
data_2[:, 15:] = 0
input_1 = keras.layers.Input((20,))
input_2 = keras.layers.Input((20,))
# mask_zero=True makes each Embedding produce a mask that Keras propagates
# to downstream layers.
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=True)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=True)(input_2)
x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model.compile(loss="mse")
model.summary()
model.fit([data_1, data_2], y, epochs=10)
After execution, the summary already shows an uncharacteristically large number of operations (significantly more than when the BatchNormalization layer is excluded), and running the script throws the following exception:
Epoch 1/10
Traceback (most recent call last):
File "/home/stud2019/tharren/mwe.py", line 24, in <module>
model.fit([data_1, data_2], y, epochs=10)
File "/scratch/stud2019/tharren/anaconda/envs/naomiml_docking/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/scratch/stud2019/tharren/anaconda/envs/naomiml_docking/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
^^^^^^^^^^^
TypeError: Exception encountered when calling BroadcastTo.call().
Failed to convert elements of (None, 20, 3) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
Arguments received by BroadcastTo.call():
• x=tf.Tensor(shape=(None, 20, 1), dtype=bool)
However, if the BatchNormalization layer is not included, training proceeds as expected.
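The mask appears to be the trigger: I would expect (though I have not exhaustively checked) that the same architecture trains fine with masking disabled, since no mask then reaches the BatchNormalization layer:
embedding_1 = keras.layers.Embedding(10, 3, mask_zero=False)(input_1)
embedding_2 = keras.layers.Embedding(10, 2, mask_zero=False)(input_2)
x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)
# With mask_zero=False no mask is created, so BatchNormalization never
# receives one and the BroadcastTo error should not occur.
model_no_mask = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model_no_mask.compile(loss="mse")
model_no_mask.fit([data_1, data_2], y, epochs=1)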
Package information: keras 3.11.2, tensorflow 2.19.0, numpy 1.26.4
Comment From: avish006
Hello @tobiasharren, BatchNormalization cannot handle the mask it receives here, so one workaround is to strip the mask before that layer. After this line:
x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
try adding:
x = keras.layers.Lambda(lambda t: t, name="strip_mask")(x)
then keep the rest the same. A Lambda layer does not propagate masks by default, so this line essentially removes the mask before it reaches the BatchNormalization layer. (An identity function is enough here; wrapping tf.stop_gradient instead would also strip the mask, but it would additionally stop gradients from flowing back into the embedding weights.)
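As a sketch of how this slots into the MWE above (untested; strip_mask is just an arbitrary layer name):
x = keras.layers.Concatenate(axis=-1)([embedding_1, embedding_2])
# Lambda does not implement mask propagation unless its `mask` argument is
# set, so the concatenated mask is dropped here and BatchNormalization
# receives a plain tensor.
x = keras.layers.Lambda(lambda t: t, name="strip_mask")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(1)(x)
model = keras.models.Model(inputs=[input_1, input_2], outputs=[x])
model.compile(loss="mse")
model.fit([data_1, data_2], y, epochs=10)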