The GitHub source code defines a function _can_use_flash_attention(...)
that attempts to verify whether Flash Attention is available.
However, starting with JAX 0.6.2 (the version recommended by the requirements), the signature of the internal helper check_layout changed to:
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):
In the current implementation, _can_use_flash_attention still calls check_layout with the old argument list:
check_layout(
    query,
    key,
    value,
    bias,
    q_seqlen=None,
    kv_seqlen=None,
    layout=_normalize_layout("BTNH"),
)
Because the required positional arguments q_offsets, kv_offsets, page_table_k, and page_table_v are missing, this call always raises a TypeError.
As a result, _can_use_flash_attention catches the exception and always returns False, which effectively prevents the JAX backend from ever using Flash Attention.
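To make the failure mode concrete, here is a minimal sketch that runs without JAX. The check_layout below is a hypothetical stand-in that only copies the 0.6.2 parameter list quoted above (its body does no real validation), and can_use_flash_attention_sketch is a simplified, hypothetical model of the try/except pattern, not the actual implementation.

```python
# Hypothetical stand-in that mirrors the JAX 0.6.2 parameter list of
# check_layout; the body is a placeholder, not JAX's validation logic.
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):
    return None


def can_use_flash_attention_sketch(query, key, value, bias):
    """Simplified model of the availability check (hypothetical name)."""
    try:
        # Old-style call: q_offsets, kv_offsets, page_table_k and
        # page_table_v are never supplied, so Python raises TypeError
        # before the function body runs.
        check_layout(
            query,
            key,
            value,
            bias,
            q_seqlen=None,
            kv_seqlen=None,
            layout="BTNH",
        )
        return True
    except TypeError:
        # The signature mismatch is swallowed here; the caller only sees False.
        return False


print(can_use_flash_attention_sketch("q", "k", "v", None))  # False
```

Calling check_layout directly with the old argument list raises a TypeError that names the four missing arguments, while the wrapper quietly returns False, which is why this kind of regression is easy to miss.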
Comment From: sonali-kumari1
Hi @pass-lin -
Thanks for reporting this issue. The current _can_use_flash_attention includes q_offsets=None and kv_offsets=None in the call to check_layout, but it is missing the page_table_k and page_table_v arguments introduced in JAX 0.6.2.
That said, JAX is currently pinned to version 0.5.0 on CPU, which doesn't include these new arguments. Hopefully, this will be resolved once support for newer JAX versions is added.
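If a call that tolerates several JAX versions is wanted, one option (an assumption, not the fix adopted upstream) is to inspect the installed check_layout signature and pass None only for the optional-input parameters it actually declares. In the sketch below, call_check_layout, check_layout_old and check_layout_new are hypothetical names; the "old" variant (offsets but no page tables) is an illustrative guess at an intermediate signature, and only the parameter lists matter.

```python
import inspect


# Hypothetical stand-ins for two versions of check_layout; the bodies just
# echo some of the values they received so the call can be checked.
def check_layout_old(query, key, value, bias, q_seqlen, kv_seqlen,
                     q_offsets, kv_offsets, layout):
    return {"layout": layout, "q_offsets": q_offsets}


def check_layout_new(query, key, value, bias, q_seqlen, kv_seqlen,
                     q_offsets, kv_offsets, page_table_k, page_table_v, layout):
    return {"layout": layout, "page_table_k": page_table_k}


# Optional inputs that may or may not exist depending on the JAX version.
OPTIONAL_NONE_PARAMS = (
    "q_seqlen", "kv_seqlen", "q_offsets", "kv_offsets",
    "page_table_k", "page_table_v",
)


def call_check_layout(check_layout, query, key, value, bias, layout):
    """Call check_layout, passing None for every optional parameter
    that the given function's signature declares."""
    params = inspect.signature(check_layout).parameters
    kwargs = {"layout": layout}
    for name in OPTIONAL_NONE_PARAMS:
        if name in params:
            kwargs[name] = None
    return check_layout(query, key, value, bias, **kwargs)


print(call_check_layout(check_layout_old, "q", "k", "v", None, "BTNH"))
print(call_check_layout(check_layout_new, "q", "k", "v", None, "BTNH"))
```

Inspecting the signature avoids hardcoding version numbers, but check_layout lives in a private JAX module and could change again, so a test that exercises the real helper on each supported JAX version is still worth having.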