166 changes: 166 additions & 0 deletions intermediate_source/scaled_dot_product_attention_tutorial.py
@@ -392,6 +392,166 @@ def generate_rand_batch(
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)

######################################################################
# FlexAttention: Custom Attention Score Modifications
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# PyTorch 2.5 introduced ``flex_attention``, an extension that lets you apply
# arbitrary modifications to attention scores before the softmax operation.
# This makes it possible to implement custom attention patterns such as
# relative position biases, ALiBi, sliding window attention, and more, all
# while keeping the performance benefits of fused attention kernels.
#
# ``flex_attention`` relies on ``torch.compile`` to fuse your custom score
# modification into an optimized attention kernel.
#

from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Define custom score modification functions.
# Each function receives the attention score along with the batch, head,
# query, and key/value position indices.


def relative_position_bias(score, batch, head, q_idx, kv_idx):
"""Apply a simple relative position bias that penalizes distant tokens."""
return score - (q_idx - kv_idx).abs() * 0.1


def alibi_bias(score, batch, head, q_idx, kv_idx):
"""
    Implement a simplified ALiBi (Attention with Linear Biases) bias.

    ALiBi adds a linear bias based on the distance between query and key
    positions, allowing the model to extrapolate to longer sequences. For
    simplicity, this version uses a single fixed slope of -0.5 for every
    head; the published method assigns a different slope per head.
"""
distance = (q_idx - kv_idx).float()
return score + distance * (-0.5)
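

# The published ALiBi formulation assigns a different slope to each head.
# A minimal per-head sketch (assuming geometrically decreasing slopes of
# 2^-(head + 1), which is one common choice rather than the only one):


def alibi_bias_per_head(score, batch, head, q_idx, kv_idx):
    """ALiBi-style bias with a head-dependent slope of 2^-(head + 1)."""
    slope = torch.exp2(-(head + 1).float())
    return score + (kv_idx - q_idx) * slope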


######################################################################
# Using flex_attention with score_mod
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The ``score_mod`` parameter accepts a function that modifies attention
# scores. The function signature is:
# ``score_mod(score, batch, head, q_idx, kv_idx) -> modified_score``
#
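
# As a reference point, an identity ``score_mod`` that returns the score
# unchanged is equivalent to standard scaled dot product attention:


def noop_score_mod(score, batch, head, q_idx, kv_idx):
    """Identity score_mod: leaves the attention scores untouched."""
    return score
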

# flex_attention works best when compiled
flex_attention_compiled = torch.compile(flex_attention)

# Create sample tensors
flex_batch_size = 4
flex_num_heads = 8
flex_seq_len_q = 64
flex_seq_len_kv = 64
flex_head_dim = 64

flex_query = torch.randn(
flex_batch_size, flex_num_heads, flex_seq_len_q, flex_head_dim,
device=device, dtype=torch.float16
)
flex_key = torch.randn(
flex_batch_size, flex_num_heads, flex_seq_len_kv, flex_head_dim,
device=device, dtype=torch.float16
)
flex_value = torch.randn(
flex_batch_size, flex_num_heads, flex_seq_len_kv, flex_head_dim,
device=device, dtype=torch.float16
)

# Apply flex_attention with relative position bias
out_relative = flex_attention_compiled(
flex_query, flex_key, flex_value,
score_mod=relative_position_bias
)
print(f"Output with relative position bias: {out_relative.shape}")

# Apply flex_attention with ALiBi
out_alibi = flex_attention_compiled(
flex_query, flex_key, flex_value,
score_mod=alibi_bias
)
print(f"Output with ALiBi bias: {out_alibi.shape}")

######################################################################
# Using block_mask for Sparse Attention Patterns
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``flex_attention`` also supports ``block_mask`` for implementing sparse
# attention patterns efficiently. The mask is computed at the block level,
# allowing entire blocks to be skipped during computation.
#


def causal_mask_fn(batch, head, q_idx, kv_idx):
"""Standard causal mask: each position can only attend to previous."""
return q_idx >= kv_idx


# Create block mask for causal attention
flex_block_mask = create_block_mask(
causal_mask_fn,
B=flex_batch_size,
H=flex_num_heads,
Q_LEN=flex_seq_len_q,
KV_LEN=flex_seq_len_kv,
device=device
)
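
# As a quick sanity check, printing the BlockMask summarizes its shape and
# sparsity (the exact repr format may differ across PyTorch versions).
print(flex_block_mask)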

# Apply flex_attention with block mask
out_causal_flex = flex_attention_compiled(
flex_query, flex_key, flex_value,
block_mask=flex_block_mask
)
print(f"Output with causal block mask: {out_causal_flex.shape}")


######################################################################
# Combining score_mod and block_mask
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# You can combine ``score_mod`` and ``block_mask`` for maximum flexibility.
# The block mask determines which blocks are computed, while ``score_mod``
# modifies the scores within the blocks that are computed.
#

out_combined = flex_attention_compiled(
flex_query, flex_key, flex_value,
score_mod=relative_position_bias,
block_mask=flex_block_mask
)
print(f"Output with combined score_mod and block_mask: {out_combined.shape}")

######################################################################
# Performance Comparison
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Let's compare the performance of ``flex_attention`` with standard SDPA.
# Note that the ``score_mod``-only benchmark below computes dense
# (non-causal) attention, so it does more work than the two causal runs.
#

# Benchmark standard SDPA
sdpa_time = benchmark_torch_function_in_microseconds(
F.scaled_dot_product_attention, flex_query, flex_key, flex_value,
is_causal=True
)
print(f"Standard SDPA (causal): {sdpa_time:.3f} microseconds")

# Benchmark flex_attention with causal block mask
flex_time = benchmark_torch_function_in_microseconds(
flex_attention_compiled, flex_query, flex_key, flex_value,
block_mask=flex_block_mask
)
print(f"FlexAttention (causal block_mask): {flex_time:.3f} microseconds")

# Benchmark flex_attention with score_mod
flex_score_mod_time = benchmark_torch_function_in_microseconds(
flex_attention_compiled, flex_query, flex_key, flex_value,
score_mod=relative_position_bias
)
print(f"FlexAttention (relative_position_bias): {flex_score_mod_time:.3f} microseconds")


######################################################################
# Conclusion
# ~~~~~~~~~~~
@@ -405,3 +565,9 @@ def generate_rand_batch(
# be used to explore the performance characteristics of a user defined
# module.
#
# Additionally, we explored ``flex_attention``, which extends SDPA with
# custom score modifications and sparse attention patterns. This enables
# implementing advanced attention mechanisms like ALiBi and relative
# position biases while maintaining high performance through compiled
# kernels.
#