From 74058f589d63b01ab19807810c5ef847842a2c12 Mon Sep 17 00:00:00 2001 From: Sagar Chapara Date: Fri, 15 May 2026 19:20:54 +0000 Subject: [PATCH] Add explicit Ulysses ring attention sharding with segment-id masking --- src/maxdiffusion/common_types.py | 12 + src/maxdiffusion/configs/base_wan_14b.yml | 4 +- src/maxdiffusion/configs/base_wan_1_3b.yml | 4 +- src/maxdiffusion/configs/base_wan_27b.yml | 4 +- src/maxdiffusion/configs/base_wan_animate.yml | 4 +- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 4 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 4 +- src/maxdiffusion/max_utils.py | 2 +- src/maxdiffusion/models/attention_flax.py | 364 ++++++++++++++---- .../wan/transformers/transformer_wan.py | 28 +- .../wan/transformers/transformer_wan_vace.py | 21 +- .../pipelines/wan/wan_pipeline.py | 7 +- src/maxdiffusion/pyconfig.py | 15 +- src/maxdiffusion/tests/attention_test.py | 303 +++++++++++++++ 14 files changed, 668 insertions(+), 108 deletions(-) diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index a92d5ec3b..9cb759476 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -95,3 +95,15 @@ [CROSS_ATTN_Q_LENGTH, CONTEXT], [CROSS_ATTN_KV_LENGTH, CONTEXT], ] + +### Common axis rules for 2D Ulysses + ring attention ### +# Public configs shard sequence on `context`; attention code privately reshapes +# that axis into hidden ring and Ulysses axes for the hybrid kernel. 
+ULYSSES_RING_ATTENTION_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, CONTEXT], + [SELF_ATTN_KV_LENGTH, CONTEXT], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, CONTEXT], + [CROSS_ATTN_KV_LENGTH, None], +] diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 319bfbc72..2df0ff673 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -69,9 +69,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 3134ed93d..883930b61 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -65,9 +65,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. 
diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index dfe300ddf..e1d0671b1 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -69,9 +69,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_animate.yml b/src/maxdiffusion/configs/base_wan_animate.yml index 7b3334c79..9eb1612a7 100644 --- a/src/maxdiffusion/configs/base_wan_animate.yml +++ b/src/maxdiffusion/configs/base_wan_animate.yml @@ -67,9 +67,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 4096 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. 
diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index f722e04e2..a855e59c2 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -69,9 +69,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. +ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 0aa533b40..7a7ae8822 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -69,9 +69,11 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring use_base2_exp: True use_experimental_scheduler: True +# For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. 
+ulysses_shards: -1 flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 8cff92a33..bb47035ec 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -617,7 +617,7 @@ def get_flash_block_sizes(config): """Create custom flash attention BlockSizes.""" flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: - attention_is_tokamax = "tokamax" in config.attention + attention_is_tokamax = "tokamax" in config.attention or config.attention == "ulysses_ring" user_block_sizes: Dict[str, int] = config.flash_block_sizes if attention_is_tokamax: max_logging.log( diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index d740fbc4f..b8e563345 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -46,6 +46,7 @@ AxisNames = common_types.AxisNames +CONTEXT = common_types.CONTEXT BATCH = common_types.BATCH LENGTH = common_types.LENGTH KV_LENGTH = common_types.KV_LENGTH @@ -61,6 +62,9 @@ CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH +INTERNAL_RING_AXIS = "ring" +INTERNAL_ULYSSES_AXIS = "ulysses" + def _coerce_tokamax_block_sizes(block_sizes): # Tokamax requires fused bwd; convert if needed. 
@@ -159,6 +163,42 @@ def _unflatten_heads(tensor, heads): return tensor +def _replace_mesh_axis(axis_spec, old_axis: str, new_axes: tuple[str, ...]): + if axis_spec == old_axis: + return new_axes + if isinstance(axis_spec, tuple): + replacement = [] + for axis in axis_spec: + if axis == old_axis: + replacement.extend(new_axes) + else: + replacement.append(axis) + return tuple(replacement) + return axis_spec + + +def _replace_mesh_axis_names(axis_names, old_axis: str, new_axes: tuple[str, ...]): + return jax.sharding.PartitionSpec(*(_replace_mesh_axis(axis_name, old_axis, new_axes) for axis_name in axis_names)) + + +def _create_internal_ulysses_ring_mesh( + mesh: Mesh, + ring_shards: int, + ulysses_shards: int, + ring_axis: str = INTERNAL_RING_AXIS, + ulysses_axis: str = INTERNAL_ULYSSES_AXIS, +) -> Mesh: + """Split the public context mesh axis into private ring and Ulysses axes.""" + mesh_axis_names = tuple(mesh.axis_names) + context_axis_index = mesh_axis_names.index(CONTEXT) + devices = mesh.devices + new_shape = devices.shape[:context_axis_index] + (ring_shards, ulysses_shards) + devices.shape[context_axis_index + 1 :] + new_axis_names = ( + mesh_axis_names[:context_axis_index] + (ring_axis, ulysses_axis) + mesh_axis_names[context_axis_index + 1 :] + ) + return Mesh(devices.reshape(new_shape), new_axis_names) + + def _reshape_data_for_flash(tensor, heads, num_context_shards=1): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. @@ -306,6 +346,41 @@ def convert_to_tokamax_splash_config( ) +def _build_padding_segment_ids( + query_seq_len: int, + q_padded_len: int, + key_seq_len: int, + kv_padded_len: int, + attention_mask: jax.Array | None, + segment_ids_cls=splash_attention_kernel.SegmentIds, +): + """Build splash segment ids that mask q/kv padding and the attention mask. + + Padding tokens get segment id 0, valid tokens 1. 
An optional attention_mask + (batch, kv_len) is folded into the kv segment ids (batch row 0 only); positions beyond the mask + but within key_seq_len default to valid, and positions beyond key_seq_len are + padding. Shared by flash, ulysses, and ulysses+ring kernels. + """ + q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) + q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) + + kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) + kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + + if attention_mask is not None: + mask_len = min(key_seq_len, attention_mask.shape[1]) + kv_mask_for_batch = attention_mask[0, :mask_len] + # Tokens past the mask but within key_seq_len are assumed valid. + if key_seq_len > mask_len: + kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, jnp.ones((key_seq_len - mask_len,), jnp.int32)], axis=0) + # Tokens past key_seq_len are padding. + if kv_padded_len > key_seq_len: + kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, jnp.zeros((kv_padded_len - key_seq_len,), jnp.int32)], axis=0) + kv_segment_ids = (kv_segment_ids * kv_mask_for_batch).astype(jnp.int32) + + return segment_ids_cls(q=q_segment_ids, kv=kv_segment_ids) + + def _tpu_flash_attention( query: jax.Array, key: jax.Array, @@ -325,7 +400,7 @@ def _tpu_flash_attention( ) -> jax.Array: """TPU Flash Attention""" - num_context_shards = mesh.shape["context"] + num_context_shards = mesh.shape[CONTEXT] query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) key, _ = _reshape_data_for_flash(key, heads, num_context_shards) value, _ = _reshape_data_for_flash(value, heads, num_context_shards) @@ -368,35 +443,12 @@ def wrap_flash_attention(query, key, value): mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) - q_padded_len = query.shape[2] - q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) 
- q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) - - kv_padded_len = key.shape[2] - kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) - kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) - - # If attention_mask is provided, apply it to kv_segment_ids - if attention_mask is not None: - mask_len = min(key_seq_len, attention_mask.shape[1]) - kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,) - # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid) - if key_seq_len > mask_len: - extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) - kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) # (key_seq_len,) - # Pad to kv_padded_len - if kv_padded_len > key_seq_len: - padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) - kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) # (kv_padded_len,) - else: - kv_mask_padded = kv_mask_for_batch - # Both are (kv_padded_len,) - element-wise multiplication - kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) - - if attention_kernel == "tokamax_ring": - segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) - else: - segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + segment_ids_cls = ( + tokamax_splash_base.SegmentIds if attention_kernel == "tokamax_ring" else splash_attention_kernel.SegmentIds + ) + segment_ids = _build_padding_segment_ids( + query_seq_len, query.shape[2], key_seq_len, key.shape[2], attention_mask, segment_ids_cls + ) # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. 
@@ -429,7 +481,7 @@ def wrap_flash_attention(query, key, value): use_experimental_scheduler=use_experimental_scheduler, ), save_residuals=False, - ring_axis="context", + ring_axis=CONTEXT, rotate_segment_ids=False, # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids ) else: @@ -457,13 +509,13 @@ def wrap_flash_attention(query, key, value): perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)] - k1 = jax.lax.ppermute(key, axis_name="context", perm=perm) - v1 = jax.lax.ppermute(value, axis_name="context", perm=perm) + k1 = jax.lax.ppermute(key, axis_name=CONTEXT, perm=perm) + v1 = jax.lax.ppermute(value, axis_name=CONTEXT, perm=perm) def ring_scan_body(carry, _): m, l, o, k_current, v_current = carry - k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm) - v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm) + k_next = jax.lax.ppermute(k_current, axis_name=CONTEXT, perm=perm) + v_next = jax.lax.ppermute(v_current, axis_name=CONTEXT, perm=perm) out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) @@ -533,13 +585,12 @@ def _ulysses_attention( Tensors arrive sequence-sharded on the context axis. Inside a shard_map the all-to-all collectives trade sequence shards for head shards, run local - splash attention on the full sequence with a subset of heads, then all-to-all - back. + splash attention on the full sequence with a subset of heads, then + all-to-all back. """ - axis_name = "context" + axis_name = CONTEXT num_shards = mesh.shape[axis_name] - # Reshape to [b, h, s, d] and pad sequence for even context-axis splitting. 
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards) key, _ = _reshape_data_for_flash(key, heads, num_shards) value, _ = _reshape_data_for_flash(value, heads, num_shards) @@ -551,7 +602,6 @@ def _ulysses_attention( "Ulysses attention requires the number of heads to be divisible by the context shard count, " f"got heads={num_heads} and context_shards={num_shards}." ) - if not use_custom_kernel: block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") @@ -566,8 +616,8 @@ def _ulysses_attention( check_vma=False, ) def wrap_ulysses_attention(query, key, value): - # Swap sharding modes: each device gives up a slice of sequence and gathers - # a slice of heads, so the local splash kernel sees the full sequence. + # Swap sharding: each device gives up a slice of heads and gathers + # a slice of sequence, so the local kernel sees the full sequence. query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) @@ -638,30 +688,7 @@ def wrap_ulysses_attention(query, key, value): mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) - q_padded_len = query.shape[2] - q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) - q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) - - kv_padded_len = key.shape[2] - kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) - kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) - - # Reuse the standard flash-attention masking convention by zeroing invalid - # KV positions in the segment ids passed down to splash. 
- if attention_mask is not None: - mask_len = min(key_seq_len, attention_mask.shape[1]) - kv_mask_for_batch = attention_mask[0, :mask_len] - if key_seq_len > mask_len: - extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) - kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) - if kv_padded_len > key_seq_len: - padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) - kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) - else: - kv_mask_padded = kv_mask_for_batch - kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) - - segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + segment_ids = _build_padding_segment_ids(query_seq_len, query.shape[2], key_seq_len, key.shape[2], attention_mask) if not mask_padding_tokens: segment_ids = None @@ -677,11 +704,158 @@ def wrap_ulysses_attention(query, key, value): attention_output = vmapped_splash(query, key, value, segment_ids) attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) - # Restore the original layout expected by the rest of the model: - # head-sharded / full-sequence -> sequence-sharded / full-heads. + # Restore original layout: head-sharded/full-sequence -> sequence-sharded/full-heads. 
+ attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + return attention_output + + devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) + if not (query.shape[0] / devices_in_batch_sharding).is_integer(): + max_logging.log( + "Warning, batch dimension should be shardable among the devices in data and fsdp" + f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" + ) + x = wrap_ulysses_attention(query, key, value) + x = x[:, :, :orig_q_seq_len, :] + x = _reshape_heads_to_head_dim(x) + + return x + + +def _ulysses_ring_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, + mask_padding_tokens: bool = True, + residual_checkpoint_name: str | None = None, + attention_mask: jax.Array = None, + ulysses_axis: str = INTERNAL_ULYSSES_AXIS, + ring_axis: str = INTERNAL_RING_AXIS, + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, + ulysses_shards: int = -1, +) -> jax.Array: + """2D context-parallel attention using a private Ulysses x ring mesh. + + Public configs only shard sequence on the context axis. Internally this + reshapes that same device axis into hidden ring and Ulysses axes, runs the + Ulysses all-to-all over the hidden Ulysses axis, and rotates K/V over the + hidden ring axis. 
+ """ + + context_axis = CONTEXT + if context_axis not in mesh.shape: + raise ValueError(f"Ulysses ring attention requires mesh axis {context_axis!r}, got mesh axes {mesh.shape}.") + + num_context_shards = mesh.shape[context_axis] + num_ulysses_shards = ulysses_shards + if num_ulysses_shards <= 0: + raise ValueError("Ulysses ring attention requires ulysses_shards to be set from config or command line.") + if num_context_shards % num_ulysses_shards != 0: + raise ValueError( + "Ulysses ring attention requires the requested Ulysses shard count to divide the context shard count, " + f"got context_shards={num_context_shards} and ulysses_shards={num_ulysses_shards}." + ) + if heads % num_ulysses_shards != 0: + raise ValueError( + "Ulysses ring attention requires the number of heads to be divisible by the requested Ulysses shard count, " + f"got heads={heads} and ulysses_shards={num_ulysses_shards}." + ) + num_ring_shards = num_context_shards // num_ulysses_shards + internal_mesh = _create_internal_ulysses_ring_mesh( + mesh, + ring_shards=num_ring_shards, + ulysses_shards=num_ulysses_shards, + ring_axis=ring_axis, + ulysses_axis=ulysses_axis, + ) + internal_sequence_axes = (ring_axis, ulysses_axis) + num_sequence_shards = num_context_shards + + query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_sequence_shards) + key, _ = _reshape_data_for_flash(key, heads, num_sequence_shards) + value, _ = _reshape_data_for_flash(value, heads, num_sequence_shards) + + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "tokamax_ring") + + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) + internal_q_axis_names = _replace_mesh_axis_names(q_axis_names, context_axis, internal_sequence_axes) + internal_kv_axis_names = _replace_mesh_axis_names(kv_axis_names, context_axis, internal_sequence_axes) + + @functools.partial( + jax.shard_map, + mesh=internal_mesh, + in_specs=(internal_q_axis_names, 
internal_kv_axis_names, internal_kv_axis_names), + out_specs=internal_q_axis_names, + check_vma=False, + ) + def wrap_ulysses_ring_attention(query, key, value): + # Swap sharding over the Ulysses axis: each device keeps a slice of heads and gathers + # the sequence of its ring shard; K/V of the other ring shards arrive via ring rotation. + query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True) + key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True) + value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=1, concat_axis=2, tiled=True) + + uses_fused_kernel = block_sizes.use_fused_bwd_kernel + block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv) + block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv) + if uses_fused_kernel: + block_q_sizes += (block_sizes.block_q_dkv,) + block_kv_sizes += (block_sizes.block_kv_dkv,) + else: + block_q_sizes += (block_sizes.block_q_dq,) + block_kv_sizes += (block_sizes.block_kv_dq,) + + block_q = max(*block_q_sizes) + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) + block_kv = max(*block_kv_sizes) + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) + value, _, _ = _pad_data_for_flash(value, heads, block_kv) + + q_padded_len = query.shape[2] + kv_padded_len = key.shape[2] + total_kv_len = kv_padded_len * num_ring_shards + + # Mask q/kv padding via segment ids, same as the tokamax_ring kernel. Each + # ring shard pads identically so every shard shares the same per-shard ids + # and rotation is unneeded. 
+ segment_ids = _build_padding_segment_ids( + query_seq_len, q_padded_len, key_seq_len, kv_padded_len, attention_mask, tokamax_splash_base.SegmentIds + ) + + if not mask_padding_tokens: + segment_ids = None + + mask = tokamax_splash_attention_mask.FullMask(_shape=(q_padded_len, total_kv_len)) + + splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( + mask=mask, + is_mqa=False, + config=convert_to_tokamax_splash_config( + block_sizes, + residual_checkpoint_name=residual_checkpoint_name, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ), + save_residuals=False, + ring_axis=ring_axis, + kv_seq_shards=num_ring_shards, + rotate_segment_ids=False, + ) + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + attention_output = vmapped_splash(query, key, value, segment_ids) + attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + + # Restore original layout: head-sharded/full-sequence -> sequence-sharded/full-heads. 
attention_output = jax.lax.all_to_all( attention_output, - axis_name=axis_name, + axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True, @@ -694,7 +868,8 @@ def wrap_ulysses_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) - x = wrap_ulysses_attention(query, key, value) + x = wrap_ulysses_ring_attention(query, key, value) + x = jax.lax.with_sharding_constraint(x, q_axis_names) x = x[:, :, :orig_q_seq_len, :] x = _reshape_heads_to_head_dim(x) @@ -865,6 +1040,27 @@ def ulysses_kernel(q, k, v, context): ) +@register_kernel("ulysses_ring") +def ulysses_ring_kernel(q, k, v, context): + return _ulysses_ring_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + use_base2_exp=context["use_base2_exp"], + use_experimental_scheduler=context["use_experimental_scheduler"], + ulysses_shards=context["ulysses_shards"], + ) + + @register_kernel("flash") def flash_kernel(q, k, v, context): return _tpu_flash_attention( @@ -953,6 +1149,7 @@ def _apply_attention( attention_mask: Array = None, use_base2_exp: bool = False, use_experimental_scheduler: bool = False, + ulysses_shards: int = -1, ): """Routes to different attention kernels using a module-level registry.""" @@ -962,7 +1159,7 @@ def _apply_attention( seq_len_idx = 2 can_use_flash_attention = True - if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]: + if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom", "ulysses_ring"]: can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and 
key.shape[seq_len_idx] >= flash_min_seq_length @@ -983,6 +1180,7 @@ def _apply_attention( "scale": scale, "use_base2_exp": use_base2_exp, "use_experimental_scheduler": use_experimental_scheduler, + "ulysses_shards": ulysses_shards, "dim_head": dim_head, "split_head_dim": split_head_dim, "float32_qk_product": float32_qk_product, @@ -1204,10 +1402,12 @@ def __init__( residual_checkpoint_name: str | None = None, use_base2_exp: bool = False, use_experimental_scheduler: bool = False, + ulysses_shards: int = -1, ): self.dpa_layer = None self.use_base2_exp = use_base2_exp self.use_experimental_scheduler = use_experimental_scheduler + self.ulysses_shards = ulysses_shards if attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error @@ -1270,6 +1470,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask attention_mask=attention_mask, use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False, use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False, + ulysses_shards=(self.ulysses_shards if hasattr(self, "ulysses_shards") else -1), ) @@ -1290,6 +1491,7 @@ class AttentionOp(nn.Module): quant: Quant = None use_base2_exp: bool = False use_experimental_scheduler: bool = False + ulysses_shards: int = -1 def setup(self): self.dpa_layer = None @@ -1337,6 +1539,7 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask attention_mask=attention_mask, use_base2_exp=self.use_base2_exp, use_experimental_scheduler=self.use_experimental_scheduler, + ulysses_shards=self.ulysses_shards, ) @@ -1373,9 +1576,15 @@ def __init__( enable_jax_named_scopes: bool = False, added_kv_proj_dim: Optional[int] = None, # New for I2V image_seq_len: Optional[int] = None, # New for I2V - use_base2_exp: bool = False, - use_experimental_scheduler: bool = False, + attention_config: Optional[dict] 
= None, ): + attention_config = { + "use_base2_exp": False, + "use_experimental_scheduler": False, + "ulysses_shards": -1, + **(attention_config or {}), + } + if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") self.dim_head = dim_head @@ -1395,8 +1604,8 @@ def __init__( else: axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) - if attention_kernel == "tokamax_ring" and not is_self_attention: - attention_kernel = "tokamax_flash" # do not use ring attention for cross attention + if attention_kernel in ("tokamax_ring", "ulysses_ring") and not is_self_attention: + attention_kernel = "tokamax_flash" self.added_kv_proj_dim = added_kv_proj_dim # New for I2V self.image_seq_len = image_seq_len # New for I2V tpu_type = get_tpu_type() @@ -1419,8 +1628,9 @@ def __init__( quant=quant, mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + use_base2_exp=attention_config["use_base2_exp"], + use_experimental_scheduler=attention_config["use_experimental_scheduler"], + ulysses_shards=attention_config["ulysses_shards"], ) # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. 
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index f5057f50d..40c6be3f7 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -353,10 +353,15 @@ def __init__( dropout: float = 0.0, mask_padding_tokens: bool = True, enable_jax_named_scopes: bool = False, - use_base2_exp: bool = False, - use_experimental_scheduler: bool = False, + attention_config: Optional[dict] = None, ): self.enable_jax_named_scopes = enable_jax_named_scopes + attention_config = { + "use_base2_exp": False, + "use_experimental_scheduler": False, + "ulysses_shards": -1, + **(attention_config or {}), + } # 1. Self-attention self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) @@ -379,8 +384,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) # 1. 
Cross-attention @@ -405,8 +409,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -570,14 +573,19 @@ def __init__( mask_padding_tokens: bool = True, scan_layers: bool = True, enable_jax_named_scopes: bool = False, - use_base2_exp: bool = False, - use_experimental_scheduler: bool = False, + attention_config: Optional[dict] = None, ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels self.num_layers = num_layers self.scan_layers = scan_layers self.enable_jax_named_scopes = enable_jax_named_scopes + attention_config = { + "use_base2_exp": False, + "use_experimental_scheduler": False, + "ulysses_shards": -1, + **(attention_config or {}), + } # 1. 
Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -637,8 +645,7 @@ def init_block(rngs): enable_jax_named_scopes=enable_jax_named_scopes, added_kv_proj_dim=added_kv_proj_dim, image_seq_len=image_seq_len, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -667,6 +674,7 @@ def init_block(rngs): precision=precision, attention=attention, enable_jax_named_scopes=enable_jax_named_scopes, + attention_config=attention_config, ) blocks.append(block) self.blocks = nnx.data(blocks) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index fcb9151f8..aa70e1e67 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -97,6 +97,10 @@ def __init__( self.enable_jax_named_scopes = enable_jax_named_scopes self.apply_input_projection = apply_input_projection self.apply_output_projection = apply_output_projection + attention_config = { + "use_base2_exp": use_base2_exp, + "use_experimental_scheduler": use_experimental_scheduler, + } # 1. Input projection self.proj_in = nnx.data([None]) @@ -132,8 +136,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) # 3. 
Cross-attention @@ -156,8 +159,7 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) assert cross_attn_norm is True, "cross_attn_norm must be True" self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -342,6 +344,10 @@ def __init__( self.num_layers = num_layers self.scan_layers = scan_layers self.enable_jax_named_scopes = enable_jax_named_scopes + attention_config = { + "use_base2_exp": use_base2_exp, + "use_experimental_scheduler": use_experimental_scheduler, + } # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -401,8 +407,7 @@ def __init__( dropout=dropout, mask_padding_tokens=mask_padding_tokens, enable_jax_named_scopes=enable_jax_named_scopes, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + attention_config=attention_config, ) blocks.append(block) self.blocks = blocks @@ -433,8 +438,8 @@ def __init__( enable_jax_named_scopes=enable_jax_named_scopes, apply_input_projection=vace_block_id == 0, apply_output_projection=True, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, + use_base2_exp=attention_config["use_base2_exp"], + use_experimental_scheduler=attention_config["use_experimental_scheduler"], ) vace_blocks.append(vace_block) self.vace_blocks = vace_blocks diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 9e92449c7..0cf4a015d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -138,8 +138,11 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["mask_padding_tokens"] = config.mask_padding_tokens wan_config["scan_layers"] = 
config.scan_layers wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes - wan_config["use_base2_exp"] = config.use_base2_exp - wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler + wan_config["attention_config"] = { + "use_base2_exp": config.use_base2_exp, + "use_experimental_scheduler": config.use_experimental_scheduler, + "ulysses_shards": getattr(config, "ulysses_shards", -1), + } # 2. eval_shape - will not use flops or create weights on device # thus not using HBM memory. diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 19f3ec306..46d0c4280 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -38,6 +38,7 @@ RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES, ULYSSES_ATTENTION_AXIS_RULES, + ULYSSES_RING_ATTENTION_AXIS_RULES, ) _ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO, LTX2_3} @@ -214,10 +215,11 @@ def user_init(raw_keys): raw_keys["vae_logical_axis_rules"] = _lists_to_tuples(raw_keys["vae_logical_axis_rules"]) # Verify qkv is sharded across sequence. attention = raw_keys["attention"] - uses_ring_attention = "ring" in attention - uses_ulysses_attention = "ulysses" in attention + uses_ulysses_ring_attention = attention == "ulysses_ring" + uses_ring_attention = "ring" in attention and not uses_ulysses_ring_attention + uses_ulysses_attention = "ulysses" in attention and not uses_ulysses_ring_attention uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"] - if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding: + if uses_ring_attention or uses_ulysses_attention or uses_ulysses_ring_attention or uses_uniform_sequence_sharding: max_logging.log( "Adding sequence sharding to q and kv if not already present because " f"{attention=} requires it or attention_sharding_uniform={uses_uniform_sequence_sharding} is set." 
@@ -233,7 +235,12 @@ def user_init(raw_keys): if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) max_logging.log(f"Adding key/value sequence axis rule {kv_seq_sharding}") - if uses_ring_attention: + if uses_ulysses_ring_attention: + for ulysses_ring_attention_axis_rule in ULYSSES_RING_ATTENTION_AXIS_RULES: + if ulysses_ring_attention_axis_rule not in logical_axis_rules: + max_logging.log(f"Adding ulysses ring attention axis rule {ulysses_ring_attention_axis_rule}") + new_rules.append(ulysses_ring_attention_axis_rule) + elif uses_ring_attention: for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: if ring_attention_axis_rule not in logical_axis_rules: max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 5c95dff8b..708af4066 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -43,6 +43,10 @@ def _ulysses_mesh(self): devices = np.array(jax.devices()[:2]).reshape(1, 1, 2, 1) return Mesh(devices, ("data", "fsdp", "context", "tensor")) + def _ulysses_ring_mesh(self): + devices = np.array(jax.devices()[:4]).reshape(1, 1, 4, 1) + return Mesh(devices, ("data", "fsdp", "context", "tensor")) + def _ulysses_axis_rules(self): return ( (attention_flax.BATCH, "data"), @@ -52,6 +56,15 @@ def _ulysses_axis_rules(self): (attention_flax.D_KV, None), ) + def _ulysses_ring_axis_rules(self): + return ( + (attention_flax.BATCH, "data"), + (attention_flax.SELF_ATTN_HEAD, None), + (attention_flax.SELF_ATTN_Q_LENGTH, "context"), + (attention_flax.SELF_ATTN_KV_LENGTH, "context"), + (attention_flax.D_KV, None), + ) + def _flash_axis_rules(self): return ( (attention_flax.BATCH, "data"), @@ -441,6 +454,296 @@ def fake_kernel(q, k, v, segment_ids): self.assertEqual(output.shape, query.shape) self.assertTrue(jnp.array_equal(output, expected)) + 
@unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention layout test requires at least 4 devices.") + def test_ulysses_ring_attention_round_trips_query_when_heads_are_divisible(self): + """Hybrid Ulysses+ring should preserve layout while only exposing context.""" + batch = 2 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + def fake_make_ring_attention(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, v, segment_ids + return q + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + output = attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=4, + ) + + self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, query)) + + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention mask test requires at least 4 devices.") + def test_ulysses_ring_attention_masks_global_kv_padding(self): + """Hybrid Ulysses+ring masks padding via segment ids, not a NumpyMask.""" + batch = 1 + length = 7 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + 
segment_ids_seen = [] + + def fake_make_ring_attention(**kwargs): + # Padding should be handled by segment ids, so the kernel gets a FullMask. + assert kwargs["rotate_segment_ids"] is False + assert not hasattr(kwargs["mask"], "array") + + def fake_kernel(q, k, v, segment_ids): + del k, v + # Record padding masking is segment-id based; don't leak tracers outside. + segment_ids_seen.append(segment_ids is not None and hasattr(segment_ids, "q") and hasattr(segment_ids, "kv")) + return q + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + output = attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=2, + ) + + self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, query)) + # Padding is masked via segment ids (q and kv), not a NumpyMask. 
+ self.assertEqual(segment_ids_seen, [True]) + + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention mask test requires at least 4 devices.") + def test_ulysses_ring_attention_folds_attention_mask_into_segment_ids(self): + """Hybrid Ulysses+ring folds attention_mask into segment ids without rotating them.""" + batch = 1 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + attention_mask = jnp.array([[1, 1, 0, 0, 1, 1, 1, 0]], dtype=jnp.int32) + mesh = self._ulysses_ring_mesh() + seen = [] + + def fake_make_ring_attention(**kwargs): + seen.append(kwargs["rotate_segment_ids"]) + + def fake_kernel(q, k, v, segment_ids): + del k, v, segment_ids + return q + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + output = attention_flax._ulysses_ring_attention( + query, + query, + query, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=2, + attention_mask=attention_mask, + ) + + self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, query)) + # Same convention as the tokamax_ring kernel: shards pad identically, no rotation.
+ self.assertEqual(seen, [False]) + + def test_ulysses_ring_attention_raises_when_heads_are_not_divisible_by_ulysses_shards(self): + """The hidden all-to-all head split still requires divisible heads.""" + if len(jax.devices()) < 4: + self.skipTest("Ulysses ring attention validation test requires at least 4 devices.") + batch = 2 + length = 8 + heads = 3 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()): + with self.assertRaisesRegex( + ValueError, + r"heads=3 and ulysses_shards=2", + ): + attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=2, + ) + + def test_ulysses_ring_attention_raises_when_ulysses_shards_are_not_set(self): + if len(jax.devices()) < 4: + self.skipTest("Ulysses ring attention validation test requires at least 4 devices.") + batch = 2 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()): + with self.assertRaisesRegex(ValueError, r"ulysses_shards"): + attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + 
attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ) + + def test_ulysses_ring_attention_raises_when_ulysses_shards_do_not_divide_context(self): + if len(jax.devices()) < 4: + self.skipTest("Ulysses ring attention validation test requires at least 4 devices.") + batch = 2 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()): + with self.assertRaisesRegex(ValueError, r"context_shards=4 and ulysses_shards=3"): + attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ulysses_shards=3, + ) + if __name__ == "__main__": absltest.main()