Hi everyone,
I am trying to better understand how the Keras MultiHeadAttention
layer handles the dimensions internally.
Suppose I input a tensor of shape (32, 24, 21), meaning (batch_size, time_steps, features), into the MultiHeadAttention
layer, and I set the number of heads to 8.
Keras correctly outputs a tensor of shape (32, 24, 21), matching my input dimensions, but I'm confused about the internal dimension handling.
My understanding is:
- If the input is (batch_size=2, time_steps=3, features=4) and we use num_heads=2,
- Then after the linear projection, the queries (Q) have shape (2, 3, 4),
- Then they are split into heads: (2, 3, 2, 2), i.e. (batch, time_steps, heads, head_dim),
- After transposing: (2, 3, 2, 2) → (2, 2, 3, 2), i.e. (batch, heads, time_steps, head_dim),
- Then attention scores (QKᵀ) will be (2, 2, 3, 3),
- After applying softmax and multiplying by V, the output per head is (2, 2, 3, 2),
- After merging the heads (concatenating them along the last axis), we get back to (2, 3, 4), as sketched below.
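Here is a quick NumPy sketch of those reshape/transpose steps, just to make the shapes concrete (illustrative only, not how any framework implements it internally):

```python
import numpy as np

batch, time_steps, features, num_heads = 2, 3, 4, 2
head_dim = features // num_heads  # 2

# Q and K after the linear projection: (2, 3, 4)
q = np.random.normal(size=(batch, time_steps, features))
k = np.random.normal(size=(batch, time_steps, features))

def split_heads(x):
    # (batch, time, features) -> (batch, heads, time, head_dim)
    x = x.reshape(batch, time_steps, num_heads, head_dim)  # (2, 3, 2, 2)
    return x.transpose(0, 2, 1, 3)                         # (2, 2, 3, 2)

qh, kh = split_heads(q), split_heads(k)

# Attention scores QKᵀ per head: (2, 2, 3, 3)
scores = qh @ kh.transpose(0, 1, 3, 2)
print(scores.shape)  # (2, 2, 3, 3)

# After softmax and multiplying by V, each head's output stays (2, 2, 3, 2);
# transposing back and merging the heads recovers (2, 3, 4).
```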
My confusion:
In my case, features=21, and heads=8.
Since 21 is not divisible by 8, how is Keras handling this? Normally, the feature dimension must be divisible by the number of heads (i.e., key_dim * num_heads = features).
So how does Keras map the 21 features into multiple heads internally, and how does it correctly recover the (32, 24, 21) output shape?
Would love a clarification on this!
Comment From: heydaari
Keras, like other frameworks, runs an assertion check before splitting the embedding, something like assert input.shape[-1] % self.num_heads == 0. If the embedding dimension is not divisible by num_heads, it raises an assertion error before any chunking happens.
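For reference, the kind of divisibility check being described typically looks like this (hypothetical snippet, not Keras source code):

```python
# Hypothetical check: split the embedding dimension evenly across heads.
embed_dim, num_heads = 21, 8
assert embed_dim % num_heads == 0, (
    f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
)
head_dim = embed_dim // num_heads
```

With embed_dim=21 and num_heads=8 this assertion fails, which is the behaviour described here; the next comment explains why Keras does not actually need this constraint.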
Comment From: divyashreepathihalli
- Input shape: (batch_size, time_steps, features) = (32, 24, 21)
- Number of heads: 8
- Key dimension (per head): 16
- Value dimension (per head): 16
- Total projection dimension for query/key/value: num_heads * key_dim = 8 * 16 = 128
- Output shape: (None, 24, 21)
Explanation:
1. The input tensor with shape (32, 24, 21) is passed to the MultiHeadAttention layer.
2. The layer projects the input into a higher-dimensional space. The new dimension is num_heads * key_dim = 8 * 16 = 128.
3. The projected tensor is split into 8 heads, each with a dimension of 16.
4. The attention mechanism is applied.
5. The heads are concatenated, resulting in a tensor of shape (32, 24, 128).
6. A final dense projection maps this tensor back to the original feature dimension of 21, resulting in the final output shape of (32, 24, 21).
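A small, runnable sketch of the shapes described above (key_dim=16 is an assumption here; key_dim is set explicitly when constructing the layer and is independent of the feature dimension):

```python
import numpy as np
from keras import layers

# (batch_size, time_steps, features) = (32, 24, 21)
x = np.random.normal(size=(32, 24, 21)).astype("float32")

# 8 heads of size 16 -> internal Q/K/V projections of size 8 * 16 = 128
mha = layers.MultiHeadAttention(num_heads=8, key_dim=16)

# Self-attention: query, value and key are all the same input tensor
y = mha(query=x, value=x, key=x)

# The final output projection maps the concatenated 128-dim heads back to
# the query's last dimension (21), so the output matches the input shape.
print(y.shape)  # (32, 24, 21)
```

So the constraint key_dim * num_heads == features does not apply in Keras: key_dim is an independent hyperparameter, and the output dense layer maps the concatenated heads back to the input feature dimension.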