From 902bc1d91f9667ba12455c7182709dcb3989b1f7 Mon Sep 17 00:00:00 2001
From: dorbit
Date: Sat, 29 Nov 2025 08:17:26 +0200
Subject: [PATCH] Add FlexAttention examples to SDPA tutorial
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a new section demonstrating flex_attention from PyTorch 2.5:
- Custom score_mod functions (relative position bias, ALiBi)
- block_mask for sparse attention patterns (causal masking)
- Combining score_mod and block_mask
- Performance comparison with standard SDPA

This extends the existing SDPA tutorial with practical examples of the
flexible attention API for custom attention patterns.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude
---
 .../scaled_dot_product_attention_tutorial.py | 166 ++++++++++++++++++
 1 file changed, 166 insertions(+)

diff --git a/intermediate_source/scaled_dot_product_attention_tutorial.py b/intermediate_source/scaled_dot_product_attention_tutorial.py
index 35b1ba7be4..cbe81799fb 100644
--- a/intermediate_source/scaled_dot_product_attention_tutorial.py
+++ b/intermediate_source/scaled_dot_product_attention_tutorial.py
@@ -392,6 +392,166 @@ def generate_rand_batch(
 compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
 out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
 
+######################################################################
+# FlexAttention: Custom Attention Score Modifications
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# PyTorch 2.5 introduced ``flex_attention``, a powerful extension that allows
+# arbitrary modifications to attention scores before the softmax operation.
+# This enables implementing custom attention patterns such as relative position
+# biases, ALiBi, and sliding window attention, all while maintaining the
+# performance benefits of fused attention kernels.
+#
+# ``flex_attention`` uses ``torch.compile`` to generate optimized kernels
+# for your custom score modification functions.
+#
+
+from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+
+# Define custom score modification functions.
+# These functions receive the attention score and the position indices.
+
+
+def relative_position_bias(score, batch, head, q_idx, kv_idx):
+    """Apply a simple relative position bias that penalizes distant tokens."""
+    return score - (q_idx - kv_idx).abs() * 0.1
+
+
+def alibi_bias(score, batch, head, q_idx, kv_idx):
+    """
+    Implement ALiBi (Attention with Linear Biases).
+
+    ALiBi adds a linear bias based on the distance between query and key
+    positions, with a different slope per head, allowing the model to
+    extrapolate to longer sequences.
+    """
+    slope = torch.exp2(-(head + 1).float())  # head-specific slopes: 1/2, 1/4, ...
+    return score + slope * (kv_idx - q_idx)
+
+
+######################################################################
+# Using flex_attention with score_mod
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# The ``score_mod`` parameter accepts a function that modifies attention
+# scores. The function signature is:
+# ``score_mod(score, batch, head, q_idx, kv_idx) -> modified_score``
+#
+
+# flex_attention should be wrapped in torch.compile for best performance
+flex_attention_compiled = torch.compile(flex_attention)
+
+# Create sample tensors
+flex_batch_size = 4
+flex_num_heads = 8
+flex_seq_len_q = 64
+flex_seq_len_kv = 64
+flex_head_dim = 64
+
+flex_query = torch.randn(
+    flex_batch_size, flex_num_heads, flex_seq_len_q, flex_head_dim,
+    device=device, dtype=torch.float16
+)
+flex_key = torch.randn(
+    flex_batch_size, flex_num_heads, flex_seq_len_kv, flex_head_dim,
+    device=device, dtype=torch.float16
+)
+flex_value = torch.randn(
+    flex_batch_size, flex_num_heads, flex_seq_len_kv, flex_head_dim,
+    device=device, dtype=torch.float16
+)
+
+# Apply flex_attention with relative position bias
+out_relative = flex_attention_compiled(
+    flex_query, flex_key, flex_value,
+    score_mod=relative_position_bias
+)
+print(f"Output with relative position bias: {out_relative.shape}")
+
+# Apply flex_attention with ALiBi
+out_alibi = flex_attention_compiled(
+    flex_query, flex_key, flex_value,
+    score_mod=alibi_bias
+)
+print(f"Output with ALiBi bias: {out_alibi.shape}")
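+######################################################################
+# A ``score_mod`` can also express masking by returning ``-inf`` for
+# disallowed positions, for example with ``torch.where``. The short sketch
+# below (an illustrative addition, not used by the benchmarks later in
+# this section) applies a causal pattern this way. Note that a score_mod
+# still computes every query/key block; the ``block_mask`` API shown next
+# lets fully masked blocks be skipped entirely, which is the preferred
+# way to express sparsity.
+#
+
+
+def causal_score_mod(score, batch, head, q_idx, kv_idx):
+    """Illustrative only: causal masking expressed as a score_mod."""
+    return torch.where(q_idx >= kv_idx, score, float("-inf"))
+
+
+out_causal_score_mod = flex_attention_compiled(
+    flex_query, flex_key, flex_value,
+    score_mod=causal_score_mod
+)
+print(f"Output with causal score_mod: {out_causal_score_mod.shape}")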
+######################################################################
+# Using block_mask for Sparse Attention Patterns
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# ``flex_attention`` also supports ``block_mask`` for implementing sparse
+# attention patterns efficiently. The mask is computed at the block level,
+# allowing entire blocks to be skipped during computation.
+#
+
+
+def causal_mask_fn(batch, head, q_idx, kv_idx):
+    """Causal mask: each query attends only to keys at or before its position."""
+    return q_idx >= kv_idx
+
+
+# Create block mask for causal attention
+flex_block_mask = create_block_mask(
+    causal_mask_fn,
+    B=flex_batch_size,
+    H=flex_num_heads,
+    Q_LEN=flex_seq_len_q,
+    KV_LEN=flex_seq_len_kv,
+    device=device
+)
+
+# Apply flex_attention with block mask
+out_causal_flex = flex_attention_compiled(
+    flex_query, flex_key, flex_value,
+    block_mask=flex_block_mask
+)
+print(f"Output with causal block mask: {out_causal_flex.shape}")
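+######################################################################
+# The same mechanism covers other sparse patterns. As a sketch of the
+# sliding window attention mentioned in the introduction above (the
+# window size of 16 is an arbitrary value chosen for illustration), the
+# mask function can combine a causal check with a distance check:
+#
+
+SLIDING_WINDOW = 16  # arbitrary window size, for illustration only
+
+
+def sliding_window_causal_fn(batch, head, q_idx, kv_idx):
+    """Attend only to keys within a fixed window behind each query."""
+    causal = q_idx >= kv_idx
+    in_window = (q_idx - kv_idx) <= SLIDING_WINDOW
+    return causal & in_window
+
+
+sliding_window_mask = create_block_mask(
+    sliding_window_causal_fn,
+    B=flex_batch_size,
+    H=flex_num_heads,
+    Q_LEN=flex_seq_len_q,
+    KV_LEN=flex_seq_len_kv,
+    device=device
+)
+
+out_sliding_window = flex_attention_compiled(
+    flex_query, flex_key, flex_value,
+    block_mask=sliding_window_mask
+)
+print(f"Output with sliding window mask: {out_sliding_window.shape}")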
+######################################################################
+# Combining score_mod and block_mask
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# You can combine ``score_mod`` and ``block_mask`` for maximum
+# flexibility. The block mask determines which blocks are computed,
+# while score_mod modifies the scores within the computed blocks.
+#
+
+out_combined = flex_attention_compiled(
+    flex_query, flex_key, flex_value,
+    score_mod=relative_position_bias,
+    block_mask=flex_block_mask
+)
+print(f"Output with combined score_mod and block_mask: {out_combined.shape}")
+
+######################################################################
+# Performance Comparison
+# ~~~~~~~~~~~~~~~~~~~~~~
+#
+# Finally, let's compare compiled ``flex_attention`` against the standard
+# ``scaled_dot_product_attention`` fused kernels on the same inputs. The
+# numbers below depend on your hardware and PyTorch build:
+#
+
+# Benchmark standard SDPA with its built-in causal masking
+sdpa_time = benchmark_torch_function_in_microseconds(
+    F.scaled_dot_product_attention, flex_query, flex_key, flex_value,
+    is_causal=True
+)
+print(f"Standard SDPA (causal): {sdpa_time:.3f} microseconds")
+
+# Benchmark flex_attention with the causal block mask
+flex_time = benchmark_torch_function_in_microseconds(
+    flex_attention_compiled, flex_query, flex_key, flex_value,
+    block_mask=flex_block_mask
+)
+print(f"FlexAttention (causal block_mask): {flex_time:.3f} microseconds")
+
+# Benchmark flex_attention with score_mod only (full attention, so not
+# directly comparable to the causal numbers above)
+flex_score_mod_time = benchmark_torch_function_in_microseconds(
+    flex_attention_compiled, flex_query, flex_key, flex_value,
+    score_mod=relative_position_bias
+)
+print(f"FlexAttention (relative_position_bias): {flex_score_mod_time:.3f} microseconds")
+
+
 ######################################################################
 # Conclusion
 # ~~~~~~~~~~~
@@ -405,3 +405,9 @@ def generate_rand_batch(
 # be used to explore the performance characteristics of a user defined
 # module.
 #
+# Additionally, we explored ``flex_attention``, which extends SDPA with
+# custom score modifications and sparse attention patterns. This enables
+# implementing advanced attention mechanisms like ALiBi and relative
+# position biases while maintaining high performance through compiled
+# kernels.
+#