The GitHub source code defines a function _can_use_flash_attention(...) that attempts to verify whether Flash Attention is available.
However, starting with JAX 0.6.2 (the version recommended by the requirements), the signature of the internal helper check_layout was changed to:

def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):

In the current implementation, _can_use_flash_attention still calls check_layout with the old argument list:

check_layout(
    query,
    key,
    value,
    bias,
    q_seqlen=None,
    kv_seqlen=None,
    layout=_normalize_layout("BTNH"),
)

Because the required positional arguments q_offsets, kv_offsets, page_table_k, and page_table_v are missing, this call always raises a TypeError.
As a result, _can_use_flash_attention catches the exception and always returns False, effectively preventing JAX from using Flash Attention in the backend.
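
The failure mode described above can be reproduced without JAX installed. The following is a minimal sketch, not the actual library code: check_layout here is a stand-in that only copies the parameter list quoted above, and can_use_flash_attention_sketch is a hypothetical name for a simplified version of the availability check.

```python
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):
    """Stand-in with the parameter list quoted above (assumption: the real
    helper validates shapes; this one does nothing)."""
    return None


def can_use_flash_attention_sketch(query, key, value, bias=None):
    """Hypothetical, simplified availability check."""
    try:
        # Old-style call: q_offsets, kv_offsets, page_table_k and
        # page_table_v are not supplied.
        check_layout(
            query,
            key,
            value,
            bias,
            q_seqlen=None,
            kv_seqlen=None,
            layout="BTNH",
        )
        return True
    except Exception:
        # The TypeError from the missing arguments is swallowed here,
        # so the caller only ever sees False.
        return False


# Placeholder inputs are enough: the call fails before any argument is inspected.
print(can_use_flash_attention_sketch("q", "k", "v"))
```

Because the broad except clause hides the TypeError, the signature mismatch looks the same to the caller as "Flash Attention is unsupported on this device".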

Comment From: sonali-kumari1

Hi @pass-lin -

Thanks for reporting this issue. The current _can_use_flash_attention includes q_offsets=None and kv_offsets=None in the call to check_layout, but it is missing the page_table_k and page_table_v arguments introduced in JAX 0.6.2.

That said, JAX is currently pinned to version 0.5.0 on CPU, which doesn't include these newly added arguments. Hopefully, this will be resolved once support for newer JAX versions is added.
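
One possible way to tolerate both signatures, shown only as a sketch and not as the fix adopted in the repository, is to inspect the helper's parameters at call time and pass just the keyword arguments it accepts. The helper name call_with_supported_kwargs and the two stand-in functions are hypothetical; they only mirror the older and newer parameter lists discussed in this thread.

```python
import inspect


def call_with_supported_kwargs(fn, *args, **optional_kwargs):
    """Call fn, dropping any keyword argument that fn does not declare."""
    params = inspect.signature(fn).parameters
    accepted = {k: v for k, v in optional_kwargs.items() if k in params}
    return fn(*args, **accepted)


# Stand-ins for the two parameter lists (assumption: bodies are placeholders).
def old_check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
    return "old signature accepted"


def new_check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                     q_offsets, kv_offsets, page_table_k, page_table_v, layout):
    return "new signature accepted"


# The same keyword set is offered to both versions of the helper.
kwargs = dict(
    q_seqlen=None,
    kv_seqlen=None,
    q_offsets=None,
    kv_offsets=None,
    page_table_k=None,
    page_table_v=None,
    layout="BTNH",
)

for fn in (old_check_layout, new_check_layout):
    print(call_with_supported_kwargs(fn, "q", "k", "v", None, **kwargs))
```

This keeps a single call site working across versions, at the cost of silently dropping arguments the installed helper does not know about, so it would only be appropriate where None is the intended default for the newer parameters.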