Hi everyone,

I am trying to better understand how the Keras MultiHeadAttention layer handles the dimensions internally.

Suppose I feed a tensor of shape (32, 24, 21), meaning (batch_size, time_steps, features), into the MultiHeadAttention layer, and I set the number of heads to 8.
Keras correctly outputs a tensor of shape (32, 24, 21), matching my input dimensions, but I'm confused about the internal dimension handling.
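
For reference, a minimal snippet that reproduces this (I picked key_dim=16 arbitrarily for illustration; the layer requires some value for it):

```python
import numpy as np
import keras

# (batch_size, time_steps, features)
x = np.random.rand(32, 24, 21).astype("float32")

# key_dim=16 is an arbitrary illustrative choice; it is a required argument
mha = keras.layers.MultiHeadAttention(num_heads=8, key_dim=16)

# Self-attention: query and value are the same tensor
out = mha(query=x, value=x)
print(out.shape)  # (32, 24, 21) -- last dim matches the input features
```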

My understanding is:
- If the input is (batch_size=2, time_steps=3, features=4) and we use num_heads=2,
- Then after the linear projection, the queries (Q) will have shape (2, 3, 4),
- Then separated into heads by reshaping: (2, 3, 4) → (2, 3, 2, 2), i.e. (batch, time, heads, head_dim),
- After transposing the head axis forward: (2, 3, 2, 2) → (2, 2, 3, 2), i.e. (batch, heads, time, head_dim),
- Then the attention scores (QKᵀ) will be (2, 2, 3, 3),
- After applying softmax and multiplying by V, the output per head is (2, 2, 3, 2),
- After merging heads, we get back to (2, 3, 4) by transposing and reshaping (see the sketch below).
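
To check my understanding, here is a plain NumPy sketch of that shape arithmetic (my own reimplementation for the toy example, not Keras's actual code, which uses einsum internally):

```python
import numpy as np

batch, time, features, heads = 2, 3, 4, 2
head_dim = features // heads  # 2

# Pretend these are the already-projected Q, K, V tensors
q = np.random.rand(batch, time, features)
k = np.random.rand(batch, time, features)
v = np.random.rand(batch, time, features)

def split_heads(x):
    # (batch, time, features) -> (batch, heads, time, head_dim)
    return x.reshape(batch, time, heads, head_dim).transpose(0, 2, 1, 3)

qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)  # (2, 2, 3, 2)

scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(head_dim)   # (2, 2, 3, 3)
scores -= scores.max(axis=-1, keepdims=True)                 # stable softmax
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
per_head = weights @ vh                                      # (2, 2, 3, 2)

merged = per_head.transpose(0, 2, 1, 3).reshape(batch, time, features)
print(merged.shape)  # (2, 3, 4)
```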

My confusion:
In my case, features=21, and heads=8.
Since 21 is not divisible by 8, how does Keras handle this? Normally, the feature dimension must be divisible by the number of heads (i.e., key_dim * num_heads = features).
So how does Keras map the 21 features into multiple heads internally, and how does it correctly recover the (32, 24, 21) output shape?

Would love a clarification on this!

Comment From: heydaari

Keras, like other frameworks, uses an assertion check before splitting the tensor into heads, something like: assert input.shape[-1] % self.num_heads == 0. If the embedding dimension is not divisible by num_heads, it raises an AssertionError before the chunking step.
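
For illustration, a hypothetical split-heads helper with such a check (not Keras's actual source):

```python
import numpy as np

def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """Split the last axis into (num_heads, head_dim); hypothetical helper."""
    batch, time, features = x.shape
    assert features % num_heads == 0, (
        f"feature dim {features} is not divisible by num_heads={num_heads}"
    )
    head_dim = features // num_heads
    return x.reshape(batch, time, num_heads, head_dim).transpose(0, 2, 1, 3)

print(split_heads(np.zeros((2, 3, 4)), num_heads=2).shape)  # (2, 2, 3, 2)

try:
    split_heads(np.zeros((32, 24, 21)), num_heads=8)
except AssertionError as e:
    print(e)  # feature dim 21 is not divisible by num_heads=8
```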

Comment From: divyashreepathihalli

- Input shape: (batch_size, time_steps, features) = (32, 24, 21)
- Number of heads: 8
- Key dimension (per head): 16
- Value dimension (per head): 16
- Total projection dimension for query/key/value: num_heads * key_dim = 8 * 16 = 128
- Output shape: (None, 24, 21)

Explanation:
1. The input tensor with shape (32, 24, 21) is passed to the MultiHeadAttention layer.
2. The layer projects the input into a higher-dimensional space; the new last dimension is num_heads * key_dim = 8 * 16 = 128.
3. The projected tensor is split into 8 heads, each with a dimension of 16.
4. The attention mechanism is applied per head.
5. The heads are concatenated, resulting in a tensor of shape (32, 24, 128).
6. A final dense projection maps this tensor back to the original feature dimension of 21, giving the final output shape of (32, 24, 21).

So the feature dimension never needs to be divisible by the number of heads: key_dim sets the per-head size independently of the input width, and the final projection restores it.
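
You can verify this by inspecting the layer's weights; a quick sketch (exact variable names and ordering may differ across Keras versions):

```python
import numpy as np
import keras

mha = keras.layers.MultiHeadAttention(num_heads=8, key_dim=16)
x = np.random.rand(32, 24, 21).astype("float32")
out = mha(query=x, value=x)
print(out.shape)  # (32, 24, 21)

# Expect query/key/value kernels of shape (21, 8, 16) -- features in,
# (heads, key_dim) out -- and an output kernel of shape (8, 16, 21)
# that maps the concatenated heads back to 21 features.
for w in mha.weights:
    print(w.name, tuple(w.shape))
```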