feat(ltx2): implement centralized, configuration-driven logical sharding strategy for LTX-2 and LTX-2.3 #414
Perseus14 wants to merge 1 commit into
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
The Pull Request successfully centralizes the sharding strategy for LTX-2 and LTX-2.3, which is a great architectural improvement. It eliminates hardware-specific logic from individual model layers and moves it to a configuration-driven registry. This significantly improves maintainability and makes it easier to support new hardware in the future.
🔍 General Feedback
- Correctness: Identified a potential `AttributeError` in core model files (`attention_flax.py`, `embeddings_flax.py`) when `sharding_specs` is `None`. This needs to be addressed, as it will cause crashes when these components are used with default arguments.
- Efficiency: The sharding specs are resolved repeatedly during inference in the pipeline. Storing these specs as pipeline attributes during initialization would be a minor but worthwhile optimization.
- Robustness: The strategy lookup logic silently defaults to a specific hardware profile on unknown input, which could hide configuration typos.
- Tests: The inclusion of `test_logical_sharding_ltx2.py` and updates to existing tests provide good coverage for the new sharding logic.
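As a rough illustration of the Robustness and Efficiency points, the sketch below raises on an unknown strategy name instead of falling back to a hardware profile, and resolves the specs once in a constructor. Every name here (`DiTSpecs`, `_REGISTRY`, `get_specs_strict`, `Pipeline`) and the Trillium axis tuples are hypothetical stand-ins chosen for the example, not the PR's actual code.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

Axes = Tuple[Optional[str], ...]


@dataclass(frozen=True)
class DiTSpecs:
    """Hypothetical stand-in for the PR's DiT sharding specs."""
    qkv_kernel: Axes = (None, "heads")
    out_kernel: Axes = ("heads", None)


# Hypothetical registry: strategy name -> component name -> specs.
# The trillium tuples are illustrative guesses at a 2D layout.
_REGISTRY: Dict[str, Dict[str, DiTSpecs]] = {
    "ironwood": {"ltx2_dit": DiTSpecs()},
    "trillium": {
        "ltx2_dit": DiTSpecs(
            qkv_kernel=("embed", "heads"),
            out_kernel=("heads", "embed"),
        )
    },
}


def get_specs_strict(strategy: str, component: str) -> DiTSpecs:
    """Look up specs, failing loudly on a typo instead of defaulting."""
    if strategy not in _REGISTRY:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown sharding strategy {strategy!r}; expected one of: {known}"
        )
    components = _REGISTRY[strategy]
    if component not in components:
        raise ValueError(f"Strategy {strategy!r} has no specs for {component!r}")
    return components[component]


class Pipeline:
    """Toy pipeline that resolves its specs once, at construction time."""

    def __init__(self, strategy: str) -> None:
        self.dit_specs = get_specs_strict(strategy, "ltx2_dit")

    def step(self) -> Axes:
        # Reuse the cached specs rather than resolving them on every call.
        return self.dit_specs.qkv_kernel
```

With this shape, a misspelled strategy such as `"ironwod"` fails at pipeline construction rather than silently running with another profile.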
This PR successfully implements a centralized, configuration-driven logical sharding strategy registry for LTX-2 and LTX-2.3. The refactoring significantly improves the modularity and maintainability of the sharding logic by decoupling it from individual model layers and hardware-specific checks.
🔍 General Feedback
- Architecture: The introduction of `logical_sharding_ltx2.py` is a great architectural improvement, making sharding strategies explicit and easily extensible.
- Robustness: The use of `safe_getattr` and fallback logic ensures that the model remains functional even with incomplete sharding specifications.
- Performance: Moving pipeline decisions like VAE replication and text-encoder batching to the registry allows for better hardware-specific tuning.
- Minor Issues: I've noted a likely discrepancy in the text encoder batching logic for Ironwood and a change in the default VAE replication behavior that should be confirmed.
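The fallback idea in the Robustness point can be sketched as follows. The PR's `safe_getattr` signature is not shown in this thread, so the helper below (`spec_or_default`) is a hypothetical illustration of the pattern, not a copy of it: a layer asks for a field by name and gets a default when the specs object is `None` or lacks that field.

```python
from types import SimpleNamespace
from typing import Any


def spec_or_default(specs: Any, name: str, default: Any) -> Any:
    """Return specs.<name>, or default if specs is None or lacks the field."""
    if specs is None:
        return default
    return getattr(specs, name, default)


# Any object with matching attribute names works (duck typing),
# so the layer does not need to import a concrete specs class.
partial_specs = SimpleNamespace(qkv_kernel=(None, "heads"))

qkv = spec_or_default(partial_specs, "qkv_kernel", (None, None))
out = spec_or_default(partial_specs, "out_kernel", (None, None))  # field missing
none_case = spec_or_default(None, "qkv_kernel", (None, None))     # specs is None
```

The explicit `None` check keeps the default-argument case (`sharding_specs=None`) on the same code path as an incomplete specs object.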
    "ltx2_dit": LTX2DiTShardingSpecs(
        qkv_kernel=(None, "heads"),
        out_kernel=("heads", None),
        out_bias=(None,),

🟠 The `use_batched_text_encoder` flag for the `ironwood` strategy is set to `False`, but the previous logic in `ltx2_pipeline.py` (and the accompanying comment) indicated that batching the text encoder gives better results on Ironwood (TPU v7x). This change might lead to a performance regression on that hardware.

Suggested change:

    out_bias=(None,),
    "text_encoder": TextEncoderShardingSpecs(
        use_batched_text_encoder=True,
    ),

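To make the effect of the flag concrete, here is a minimal sketch of how a pipeline might branch on it. The function `encode_prompts` and the toy encoder are hypothetical and only illustrate the batched versus per-prompt call pattern; the real text encoder and pipeline code are not shown in this thread.

```python
from typing import Callable, List, Sequence

EncodeFn = Callable[[Sequence[str]], List[int]]


def encode_prompts(
    prompts: Sequence[str], encode_batch: EncodeFn, use_batched: bool
) -> List[int]:
    """Encode all prompts in one call, or one call per prompt."""
    if use_batched:
        return list(encode_batch(prompts))
    results: List[int] = []
    for prompt in prompts:
        results.extend(encode_batch([prompt]))
    return results


calls: List[int] = []  # records the batch size of every encoder call


def toy_encoder(batch: Sequence[str]) -> List[int]:
    """Stand-in encoder: returns each prompt's length."""
    calls.append(len(batch))
    return [len(p) for p in batch]
```

Both paths return the same values; they differ only in how many encoder invocations are issued, which is where a hardware-dependent performance difference would come from.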
    @dataclass
    class VAEShardingSpecs:
        """Sharding specs for the VAE."""

🟡 The default for `force_replication` in `VAEShardingSpecs` is set to `True`, which effectively changes the default behavior for VAE replication from `False` (in the previous `ltx2_pipeline.py` logic) to `True` for all strategies. While this is likely intended given the performance benefits mentioned in the comments, it's a notable change in default behavior.

Suggested change:

        """Sharding specs for the VAE."""
        force_replication: bool = True

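The point about defaults can be seen in a small sketch: any strategy that constructs the specs without arguments inherits `True`, so a profile that wants the earlier behavior has to say so explicitly. The `frozen=True` setting and the use of `dataclasses.replace` are assumptions made for this example, not details confirmed by the diff.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class VAEShardingSpecs:
    """Sharding specs for the VAE (sketch; frozen is an assumption)."""
    force_replication: bool = True


# A strategy that does not mention the field picks up the new default.
implicit = VAEShardingSpecs()

# A strategy that wants the previous behavior must opt out explicitly.
explicit_off = replace(implicit, force_replication=False)
```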
        weights_dtype: jnp.dtype = jnp.float32,
        sharding_specs: Optional[LTX2DiTShardingSpecs] = None,
    ):
        self.num_mod_params = num_mod_params

🟢 Consider moving the `get_sharding_specs` import to the top of the file to follow standard Python practices, unless it was specifically placed here to avoid a circular dependency that couldn't be resolved otherwise.

Suggested change:

        self.num_mod_params = num_mod_params
        if sharding_specs is None:
            from .logical_sharding_ltx2 import get_sharding_specs
            sharding_specs = get_sharding_specs("default", "ltx2_dit")

Summary
This PR introduces a centralized, configuration-driven logical sharding strategy registry for LTX-2 and LTX-2.3 in MaxDiffusion. It eliminates ad-hoc hardware checks and hardcoded sharding constraints in model layers by moving sharding specifications to a centralized, hardware-aware registry.
Key Changes
- Adds `logical_sharding_ltx2.py` to define sharding spec profiles for Ironwood (TPU v7x, 1D heads-wise sharding) and Trillium (TPU v6e, 2D heads + embed sharding).
- Refactors `attention_flax.py` and `embeddings_flax.py` to use generic duck-typing interfaces (`getattr` fallback logic) to prevent code coupling.
- Moves the VAE replication (`force_replication`) and text-encoding batching (`use_batched_text_encoder`) pipeline decisions into the central spec registry so that they are configuration-driven.
- Adds `sharding`, `text_encoder_dtype`, `compile_text_encoder`, and `base_output_directory` parameters to the LTX-2/2.3 configs, enabling dynamic text-encoder compilation and clean overrides via the CLI.
- Adds tests (`test_logical_sharding_ltx2.py`) to verify routing and hardware auto-detection logic.
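For readers unfamiliar with the routing described above, the following is a minimal sketch of hardware auto-detection with an explicit override. The function name `detect_strategy`, the substring markers, and the example device strings are all assumptions for illustration; the PR's actual detection logic is not reproduced here. In a JAX program the device string would typically come from something like `jax.devices()[0].device_kind`, whose exact values for these TPU generations should be checked rather than taken from this example.

```python
from typing import Optional

# Assumed substring markers; real device-kind strings may differ.
_KIND_MARKERS = (
    ("v7", "ironwood"),
    ("v6", "trillium"),
)


def detect_strategy(device_kind: str, override: Optional[str] = None) -> str:
    """Pick a strategy name from a device-kind string, honoring an override."""
    if override is not None:
        # An explicit config or CLI value wins over auto-detection.
        return override
    kind = device_kind.lower()
    for marker, strategy in _KIND_MARKERS:
        if marker in kind:
            return strategy
    raise ValueError(
        f"Cannot infer a sharding strategy from device kind {device_kind!r}"
    )
```

Raising on an unrecognized device, rather than defaulting, mirrors the reviewer's concern that silent fallbacks can hide configuration mistakes.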