Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/ltx2_3_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ remat_policy: "NONE"
jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
text_encoder_dtype: 'bfloat16'
compile_text_encoder: False

run_name: 'ltx2_inference'
output_dir: ''
base_output_directory: ''
config_path: ''
save_config_to_gcs: False

Expand Down Expand Up @@ -69,6 +72,12 @@ logical_axis_rules: [
]
data_sharding: ['data', 'fsdp', 'context', 'tensor']

sharding:
transformer: 'default'
vae: 'default'
text_encoder: 'default'
text_connector: 'default'

dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1

Expand Down
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ remat_policy: "NONE"
jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
text_encoder_dtype: 'bfloat16'
compile_text_encoder: False

run_name: 'ltx2_inference'
output_dir: ''
base_output_directory: ''
config_path: ''
save_config_to_gcs: False

Expand Down Expand Up @@ -74,6 +77,12 @@ logical_axis_rules: [
]
data_sharding: ['data', 'fsdp', 'context', 'tensor']

sharding:
transformer: 'default'
vae: 'default'
text_encoder: 'default'
text_connector: 'default'

dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1

Expand Down
24 changes: 15 additions & 9 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
import functools
from functools import partial, reduce
from contextlib import nullcontext
from typing import Dict, Callable
from typing import (
Any,
Callable,
Dict,
Set,
Tuple,
Union,
)
import json
import yaml
import os
Expand All @@ -36,7 +43,6 @@
import optax
Comment thread
Perseus14 marked this conversation as resolved.
from maxdiffusion import max_logging
from maxdiffusion.checkpointing import checkpointing_utils
from maxdiffusion.models.attention_flax import AttentionOp
import flax.linen as nn
import flax.linen.module as module_lib
from flax.linen.summary import _process_inputs
Expand All @@ -50,13 +56,6 @@

from transformers import FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel
from flax import struct
from typing import (
Callable,
Any,
Tuple,
Union,
Set,
)
from flax import core
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

Expand Down Expand Up @@ -676,6 +675,8 @@ def get_live_arrays():
# to retrieve layer parameters and calculate
def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSequences], train, **kwargs):
"""Calculates model tflops by passing a module."""
from maxdiffusion.models.attention_flax import AttentionOp
Comment thread
Perseus14 marked this conversation as resolved.

with module_lib._tabulate_context():
_ = jax.eval_shape(module.init, rngs, **kwargs)
calls = module_lib._context.call_info_stack[-1].calls
Expand Down Expand Up @@ -769,3 +770,8 @@ def maybe_initialize_jax_distributed_system(raw_keys):
max_logging.log("Jax distributed system initialized on GPU!")
else:
jax.distributed.initialize()


def safe_getattr(obj: Any, name: str, default: Any) -> Any:
  """Look up attribute `name` on `obj`, falling back to `default`.

  Args:
    obj: Object to read the attribute from; may be None.
    name: Name of the attribute to look up.
    default: Value handed back when `obj` is None or has no such attribute.

  Returns:
    The attribute value when it can be read, otherwise `default`.
  """
  # A None container always yields the fallback, without consulting
  # any attributes that None itself happens to expose.
  if obj is None:
    return default
  return getattr(obj, name, default)
30 changes: 11 additions & 19 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import contextlib
import functools
import math
from typing import Optional, Callable, Tuple, Dict
from typing import Optional, Callable, Tuple, Any, Dict
import flax.linen as nn
from flax import nnx
import jax
Expand All @@ -31,6 +31,7 @@
from einops import rearrange
from .. import common_types, max_logging
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
from maxdiffusion.max_utils import safe_getattr


from ..kernels import custom_splash_attention as custom_splash
Expand Down Expand Up @@ -1131,24 +1132,15 @@ def __init__(
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
precision: Optional[jax.lax.Precision] = None,
sharding_specs: Optional[Any] = None,
):
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

tpu_type = get_tpu_type()
is_ironwood = tpu_type == TpuType.TPU_7X

# Hardware-aware sharding specs: Ironwood (v7x) keeps the embedding dimension (embed)
# replicated (None) to minimize cross-device communication, while other hardware (default)
# shards it to prevent OOM issues.
if is_ironwood:
net0_kernel_spec = (None, "mlp")
net2_kernel_spec = ("mlp", None)
net2_bias_spec = (None,)
else:
net0_kernel_spec = ("embed", "mlp")
net2_kernel_spec = ("mlp", "embed")
net2_bias_spec = ("embed",)
net_0_kernel = safe_getattr(sharding_specs, "net_0_kernel", ("embed", "mlp"))
net_0_bias = safe_getattr(sharding_specs, "net_0_bias", ("mlp",))
net_2_kernel = safe_getattr(sharding_specs, "net_2_kernel", ("mlp", "embed"))
net_2_bias = safe_getattr(sharding_specs, "net_2_bias", ("embed",))

self.net_0 = nnx.Linear(
dim,
Expand All @@ -1158,8 +1150,8 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net0_kernel_spec),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_0_kernel),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_0_bias),
)
self.act = get_activation(activation_fn)
self.net_2 = nnx.Linear(
Expand All @@ -1170,8 +1162,8 @@ def __init__(
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net2_kernel_spec),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, net2_bias_spec),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net_2_kernel),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, net_2_bias),
)

def __call__(self, hidden_states: Array) -> Array:
Expand Down
63 changes: 39 additions & 24 deletions src/maxdiffusion/models/embeddings_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
from typing import Optional, Any
import flax.linen as nn
from flax import nnx
import jax.numpy as jnp
Expand All @@ -22,6 +22,7 @@
from ..models.attention_flax import NNXSimpleFeedForward
from ..models.normalization_flax import FP32LayerNorm
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
from maxdiffusion.max_utils import safe_getattr


def get_sinusoidal_embeddings(
Expand Down Expand Up @@ -85,7 +86,12 @@ def __init__(
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.Precision = None,
sharding_specs: Optional[Any] = None,
):
linear_1_kernel = safe_getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
linear_1_bias = safe_getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
linear_2_kernel = safe_getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
linear_2_bias = safe_getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
self.linear_1 = nnx.Linear(
rngs=rngs,
in_features=in_channels,
Expand All @@ -96,12 +102,9 @@ def __init__(
precision=precision,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(
"embed",
"mlp",
),
linear_1_kernel,
),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
)

if cond_proj_dim is not None:
Expand All @@ -128,12 +131,9 @@ def __init__(
precision=precision,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(
"mlp",
"embed",
),
linear_2_kernel,
),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
)

if post_act_fn is None:
Expand Down Expand Up @@ -341,7 +341,12 @@ def __init__(
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.Precision = None,
sharding_specs: Optional[Any] = None,
):
linear_1_kernel = safe_getattr(sharding_specs, "emb_linear_1_kernel", ("embed", "mlp"))
linear_1_bias = safe_getattr(sharding_specs, "emb_linear_1_bias", ("mlp",))
linear_2_kernel = safe_getattr(sharding_specs, "emb_linear_2_kernel", ("mlp", "embed"))
linear_2_bias = safe_getattr(sharding_specs, "emb_linear_2_bias", ("embed",))
if out_features is None:
out_features = hidden_size

Expand All @@ -355,12 +360,9 @@ def __init__(
precision=precision,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(
"embed",
"mlp",
),
linear_1_kernel,
),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_1_bias),
)
self.act_1 = get_activation(act_fn)

Expand All @@ -374,12 +376,9 @@ def __init__(
precision=precision,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(
"mlp",
"embed",
),
linear_2_kernel,
),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
bias_init=nnx.with_partitioning(nnx.initializers.zeros, linear_2_bias),
)

def __call__(self, caption):
Expand Down Expand Up @@ -535,22 +534,38 @@ def __init__(
use_additional_conditions: bool = False,
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
sharding_specs: Optional[Any] = None,
):
self.outdim = size_emb_dim
self.use_additional_conditions = use_additional_conditions

self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = NNXTimestepEmbedding(
rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
rngs=rngs,
in_channels=256,
time_embed_dim=embedding_dim,
dtype=dtype,
weights_dtype=weights_dtype,
sharding_specs=sharding_specs,
)

if use_additional_conditions:
self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = NNXTimestepEmbedding(
rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
rngs=rngs,
in_channels=256,
time_embed_dim=size_emb_dim,
dtype=dtype,
weights_dtype=weights_dtype,
sharding_specs=sharding_specs,
)
self.aspect_ratio_embedder = NNXTimestepEmbedding(
rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
rngs=rngs,
in_channels=256,
time_embed_dim=size_emb_dim,
dtype=dtype,
weights_dtype=weights_dtype,
sharding_specs=sharding_specs,
)

def __call__(
Expand Down
37 changes: 13 additions & 24 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp
from ... import common_types
from ..attention_flax import NNXAttentionOp
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
from .logical_sharding_ltx2 import get_sharding_specs, LTX2DiTShardingSpecs

Array = common_types.Array
Mesh = common_types.Mesh
Expand Down Expand Up @@ -350,9 +350,7 @@ def __init__(
rope_type: str = "interleaved",
flash_block_sizes: BlockSizes = None,
flash_min_seq_length: int = 4096,
qkv_sharding_spec: Optional[tuple] = None,
out_sharding_spec: Optional[tuple] = None,
out_bias_sharding_spec: Optional[tuple] = None,
sharding_specs: Optional[LTX2DiTShardingSpecs] = None,
gated_attn: bool = False,
):
self.heads = heads
Expand All @@ -361,33 +359,24 @@ def __init__(
self.inner_dim = dim_head * heads
self.dropout_rate = dropout

# Auto-detect hardware for sharding specs if not overridden
tpu_type = get_tpu_type()
is_ironwood = tpu_type == TpuType.TPU_7X

# Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated)
# to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions.
# This has currently only been tested on Trillium (v6e) and Ironwood (v7x).
if qkv_sharding_spec is None:
qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads")
if out_sharding_spec is None:
out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed")
if out_bias_sharding_spec is None:
out_bias_sharding_spec = (None,) if is_ironwood else ("embed",)
if sharding_specs is None:
specs = get_sharding_specs("default", "ltx2_dit")
else:
specs = sharding_specs

# 1. Define Partitioned Initializers (Logical Axes)
# Q, K, V kernels: [in_features (embed), out_features (heads)]
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec)
qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.qkv_kernel)
# Q, K, V biases: [out_features (heads)]
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.qkv_bias)

# Out kernel: [in_features (heads), out_features (embed)]
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec)
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.out_kernel)
# Out bias: [out_features (embed)]
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec)
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), specs.out_bias)

# Norm scales
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), specs.norm_scale)

# 2. Projections
self.to_q = nnx.Linear(
Expand Down Expand Up @@ -450,8 +439,8 @@ def __init__(
query_dim,
heads,
use_bias=True,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), specs.gate_logits_kernel),
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), specs.gate_logits_bias),
rngs=rngs,
dtype=dtype,
)
Expand Down
Loading
Loading