feat(ltx2): implement centralized, configuration-driven logical sharding strategy for LTX-2 and LTX-2.3 #414

Open
Perseus14 wants to merge 1 commit into main from ltx2_sharding

@Perseus14
Collaborator

Summary

This PR introduces a centralized, configuration-driven logical sharding strategy registry for LTX-2 and LTX-2.3 in MaxDiffusion. It eliminates ad-hoc hardware checks and hardcoded sharding constraints in model layers by moving sharding specifications to a centralized, hardware-aware registry.

Key Changes

  • Centralized Specs Registry: Created logical_sharding_ltx2.py to define sharding spec profiles for Ironwood (TPU v7x, 1D heads-wise sharding) and Trillium (TPU v6e, 2D heads + embed sharding).
  • Model Weight Parameterization: Replaced hardcoded partitioning in the LTX2 transformer, attention, VAE timestep embeddings, connectors, and the new LTX-2.3 gated attention projection layers with specs read from the registry.
  • Decoupled Shared Layers: Parameterized shared FFN and text projection layers in attention_flax.py and embeddings_flax.py using generic duck-typing interfaces (getattr fallback logic) to prevent code coupling.
  • Config-Driven Pipeline Choices: The VAE replication (force_replication) and text-encoding batching (use_batched_text_encoder) pipeline decisions are now configuration-driven and read from the central spec registry.
  • Robust Configuration & CLI support: Added sharding, text_encoder_dtype, compile_text_encoder, and base_output_directory parameters to LTX-2/2.3 configs, enabling dynamic text-encoder compilation and clean overrides via the CLI.
  • Verification: Added non-brittle unit tests (test_logical_sharding_ltx2.py) to verify routing and hardware auto-detection logic.
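
To make the registry idea above concrete, here is a minimal sketch. The names LTX2DiTShardingSpecs, VAEShardingSpecs, TextEncoderShardingSpecs and get_sharding_specs appear in the review snippets on this PR. Everything else is assumed for illustration and may not match logical_sharding_ltx2.py: the `_REGISTRY` name, the "vae" and "text_encoder" keys on both profiles, and the Trillium axis tuples. Resolution of the "default" strategy and hardware auto-detection are omitted.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# A logical axis spec: one entry per tensor dimension, None = not sharded.
AxisSpec = Tuple[Optional[str], ...]


@dataclass(frozen=True)
class LTX2DiTShardingSpecs:
  qkv_kernel: AxisSpec = (None, "heads")
  out_kernel: AxisSpec = ("heads", None)
  out_bias: AxisSpec = (None,)


@dataclass(frozen=True)
class VAEShardingSpecs:
  force_replication: bool = True


@dataclass(frozen=True)
class TextEncoderShardingSpecs:
  use_batched_text_encoder: bool = False


# Hypothetical registry layout: strategy name -> component name -> specs.
_REGISTRY = {
    "ironwood": {  # 1D heads-wise sharding
        "ltx2_dit": LTX2DiTShardingSpecs(),
        "vae": VAEShardingSpecs(),
        "text_encoder": TextEncoderShardingSpecs(),
    },
    "trillium": {  # 2D heads + embed sharding (axis tuples are assumed)
        "ltx2_dit": LTX2DiTShardingSpecs(
            qkv_kernel=("embed", "heads"),
            out_kernel=("heads", "embed"),
        ),
        "vae": VAEShardingSpecs(),
        "text_encoder": TextEncoderShardingSpecs(),
    },
}


def get_sharding_specs(strategy: str, component: str):
  """Look up the specs for one component under one strategy."""
  return _REGISTRY[strategy][component]


dit_specs = get_sharding_specs("trillium", "ltx2_dit")
print(dit_specs.qkv_kernel, dit_specs.out_kernel)
```

With this layout, model layers only receive a specs object and never inspect the hardware themselves, so supporting a new chip would mean adding one profile entry.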

@Perseus14 Perseus14 requested a review from entrpn as a code owner May 21, 2026 07:53
@Perseus14 Perseus14 force-pushed the ltx2_sharding branch 4 times, most recently from f6789b2 to a37a799 Compare May 21, 2026 10:12
@Perseus14 Perseus14 requested review from mbohlool and prishajain1 May 21, 2026 10:19
@Perseus14 Perseus14 force-pushed the ltx2_sharding branch 4 times, most recently from c35d3da to b45ac10 Compare May 21, 2026 16:04
@Perseus14 Perseus14 self-assigned this May 21, 2026
@github-actions

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

## 📋 Review Summary

The Pull Request successfully centralizes the sharding strategy for LTX-2 and LTX-2.3, which is a great architectural improvement. It eliminates hardware-specific logic from individual model layers and moves it to a configuration-driven registry. This significantly improves maintainability and makes it easier to support new hardware in the future.

🔍 General Feedback

  • Correctness: Identified a potential AttributeError in core model files (attention_flax.py, embeddings_flax.py) when sharding_specs is None. This needs to be addressed as it will cause crashes when these components are used with default arguments.
  • Efficiency: The sharding specs are resolved repeatedly during inference in the pipeline. Storing these specs as pipeline attributes during initialization would be a minor but worthwhile optimization.
  • Robustness: The strategy lookup logic silently defaults to a specific hardware profile on unknown input, which could hide configuration typos.
  • Tests: The inclusion of test_logical_sharding_ltx2.py and updates to existing tests provide good coverage for the new sharding logic.
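
The three code-level points above (crash on None specs, repeated resolution, silent defaulting) could be addressed along the following lines. This is a sketch under stated assumptions, not the code in this PR: FFNShardingSpecs, wi_kernel, `_PROFILES`, resolve_strategy, FeedForward and Pipeline are hypothetical stand-ins, and the axis names are invented for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class FFNShardingSpecs:  # hypothetical spec class
  wi_kernel: Tuple[Optional[str], ...] = (None, "mlp")


# Hypothetical profiles; the axis names are assumptions.
_PROFILES = {
    "ironwood": FFNShardingSpecs(),
    "trillium": FFNShardingSpecs(wi_kernel=("embed", "mlp")),
}


def resolve_strategy(name: str) -> str:
  """Return a known strategy name, or raise instead of silently defaulting."""
  key = name.strip().lower()
  if key not in _PROFILES:
    raise ValueError(
        f"Unknown sharding strategy {name!r}; expected one of {sorted(_PROFILES)}"
    )
  return key


class FeedForward:
  """Stand-in for a shared layer that accepts optional specs."""

  def __init__(self, sharding_specs: Optional[FFNShardingSpecs] = None):
    # Resolve None once, up front, so later attribute access cannot fail.
    if sharding_specs is None:
      sharding_specs = FFNShardingSpecs()
    self.wi_axes = sharding_specs.wi_kernel


class Pipeline:
  """Stand-in pipeline that resolves its specs once at construction."""

  def __init__(self, strategy: str):
    self.strategy = resolve_strategy(strategy)
    self.ffn_specs = _PROFILES[self.strategy]  # reused on every call


layer = FeedForward()  # default arguments no longer crash
pipe = Pipeline("trillium")
print(layer.wi_axes, pipe.ffn_specs.wi_kernel)
```

Raising on an unknown name turns a configuration typo such as "ironwod" into an immediate error rather than a run on the wrong profile.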

Comment thread src/maxdiffusion/models/attention_flax.py Outdated
Comment thread src/maxdiffusion/models/embeddings_flax.py Outdated
Comment thread src/maxdiffusion/models/embeddings_flax.py Outdated
Comment thread src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py Outdated
Comment thread src/maxdiffusion/models/ltx2/logical_sharding_ltx2.py Outdated
@Perseus14 Perseus14 force-pushed the ltx2_sharding branch 4 times, most recently from 5bcf2e8 to 48c2d3d Compare May 21, 2026 17:24
@github-actions

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

## 📋 Review Summary

This PR successfully implements a centralized, configuration-driven logical sharding strategy registry for LTX-2 and LTX-2.3. The refactoring significantly improves the modularity and maintainability of the sharding logic by decoupling it from individual model layers and hardware-specific checks.

🔍 General Feedback

  • Architecture: The introduction of logical_sharding_ltx2.py is a great architectural improvement, making sharding strategies explicit and easily extensible.
  • Robustness: The use of safe_getattr and fallback logic ensures that the model remains functional even with incomplete sharding specifications.
  • Performance: Moving pipeline decisions like VAE replication and text-encoder batching to the registry allows for better hardware-specific tuning.
  • Minor Issues: I've noted a likely discrepancy in the text encoder batching logic for Ironwood and a change in the default VAE replication behavior that should be confirmed.
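
The fallback behavior mentioned under Robustness can be pictured with a small helper. The name safe_getattr comes from the review text above; its signature and exact semantics here are assumptions, and the spec objects are plain SimpleNamespace stand-ins.

```python
from types import SimpleNamespace
from typing import Any


def safe_getattr(obj: Any, name: str, default: Any = None) -> Any:
  """Return obj.<name>, or default when obj is None or lacks the field."""
  if obj is None:
    return default
  return getattr(obj, name, default)


full_specs = SimpleNamespace(qkv_kernel=(None, "heads"))
partial_specs = SimpleNamespace()  # field missing: falls back
fallback = (None, None)  # fully replicated, used as the fallback here

print(safe_getattr(full_specs, "qkv_kernel", fallback))
print(safe_getattr(partial_specs, "qkv_kernel", fallback))
print(safe_getattr(None, "qkv_kernel", fallback))
```

A helper like this lets shared layers accept any object that happens to carry the expected fields, which is the duck-typing approach the PR description refers to. The trade-off is that a misspelled field name falls back quietly instead of failing.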

    "ltx2_dit": LTX2DiTShardingSpecs(
        qkv_kernel=(None, "heads"),
        out_kernel=("heads", None),
        out_bias=(None,),

🟠 The use_batched_text_encoder flag for the ironwood strategy is set to False, but the previous logic in ltx2_pipeline.py (and the accompanying comment) indicated that batching the text encoder gives better results on Ironwood (TPU v7x). This change might lead to a performance regression on that hardware.

Suggested change
        out_bias=(None,),
    "text_encoder": TextEncoderShardingSpecs(
        use_batched_text_encoder=True,
    ),


@dataclass
class VAEShardingSpecs:
  """Sharding specs for the VAE."""

🟡 The default for force_replication in VAEShardingSpecs is set to True, which effectively changes the default behavior for VAE replication from False (in the previous ltx2_pipeline.py logic) to True for all strategies. While this is likely intended given the performance benefits mentioned in the comments, it's a notable change in default behavior.

Suggested change
  """Sharding specs for the VAE."""
  force_replication: bool = True
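
To see why a dataclass default changes behavior for every strategy, consider the sketch below. VAEShardingSpecs and force_replication come from the snippet above. The `profiles` dict and the Trillium override are hypothetical and only show how a single profile could opt out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VAEShardingSpecs:
  """Sharding specs for the VAE."""
  # Applies to every profile that does not override it.
  force_replication: bool = True


profiles = {
    "ironwood": VAEShardingSpecs(),  # inherits the default
    "trillium": VAEShardingSpecs(force_replication=False),  # hypothetical override
}

for name, specs in profiles.items():
  print(name, specs.force_replication)
```

Setting the flag explicitly in each profile, rather than relying on the class default, would make such a behavior change visible in the registry itself.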

      weights_dtype: jnp.dtype = jnp.float32,
      sharding_specs: Optional[LTX2DiTShardingSpecs] = None,
  ):
    self.num_mod_params = num_mod_params

🟢 Consider moving the get_sharding_specs import to the top of the file to follow standard Python practices, unless it was specifically placed here to avoid a circular dependency that couldn't be resolved otherwise.

Suggested change
    self.num_mod_params = num_mod_params
    if sharding_specs is None:
      from .logical_sharding_ltx2 import get_sharding_specs
      sharding_specs = get_sharding_specs("default", "ltx2_dit")
