fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled #12562
base: main
Conversation
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
torchrun --nproc-per-node 2 test.py crash stack:

@yiyixuxu @sayakpaul please help review

Could you also supplement an output with the fix?
)
if ts_seq_len is not None:
    # Check if running under context parallel and split along seq_len dimension
    if hasattr(self, '_parallel_config') and self._parallel_config is not None:
Could you elaborate why this is needed?
When CP is enabled, seq_len is split. The timestep projection has shape [batch_size, seq_len, 6, inner_dim], so it should be split along dim 1 as well, since the hidden states are also split along the seq_len dim; otherwise a shape mismatch occurs.
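To make the mismatch concrete, here is a small standalone sketch (made-up shapes, not taken from the model code):

```python
# Standalone illustration only; shapes are made up, not the diffusers internals.
import torch

world_size, rank = 2, 0          # e.g. torchrun --nproc-per-node 2, rank 0
batch, seq_len, inner_dim = 1, 1024, 3072

hidden_states = torch.randn(batch, seq_len, inner_dim)
timestep_proj = torch.randn(batch, seq_len, 6, inner_dim)  # per-token modulation (5B / expand_timesteps)

# Context parallel shards hidden_states along the sequence dimension:
hidden_local = hidden_states.chunk(world_size, dim=1)[rank]      # [1, 512, 3072]

# If timestep_proj keeps the full seq_len, any per-token scale/shift against
# hidden_local fails with a shape mismatch; splitting it on dim=1 fixes that:
timestep_local = timestep_proj.chunk(world_size, dim=1)[rank]    # [1, 512, 6, 3072]
assert timestep_local.shape[1] == hidden_local.shape[1]
```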
do you think we can just change _cp_plan? https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan.py#L546
You mean split timestep in the forward? Adding
"": {
    "timestep": ContextParallelInput(split_dim=1, split_output=False)
},
to _cp_plan makes 5B work, but 14B fails, since the 5B timestep has 2 dims while the 14B timestep has 1 dim.
Hmm, this is an interesting situation. To tackle it, I think we might have to revisit the ContextParallelInput and ContextParallelOutput definitions a bit more.
If we had a way to tell the partitioner that the input might have "dynamic" dimensions depending on the model configs (like in this case), and what it should do if that's the case, it might be more flexible as a solution.
@DN6 curious to know what you think.
Another solution that comes to mind is a fix in the Wan pipeline: a clean and more generic change like
from .pipeline_output import WanPipelineOutput
-
+from ...models._modeling_parallel import ContextParallelInput, ContextParallelOutput

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
@@ -150,6 +150,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        )
        self.register_to_config(boundary_ratio=boundary_ratio)
        self.register_to_config(expand_timesteps=expand_timesteps)
+        if expand_timesteps:
+            transformer._cp_plan.update({"": {
+                "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False)
+            }})
WDYT ?
The change would be in src/diffusers/pipelines/wan/pipeline_wan.py.
That is probably a bit more intrusive. I think what you suggested earlier would work as @DN6 suggested.
Actually, I like this the best 👍🏽
The only downside to setting it via the pipeline is that it would still error out if you were trying to run inference outside the WanPipeline. So perhaps this is the best way to handle it for now.
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
    if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
        # Input does not have the dimensionality this plan entry expects:
        # warn and return it unsplit instead of sharding (or erroring out).
        logger.warning(
            f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
        )
        return x
    else:
        # Shard the tensor along the configured dimension across the CP mesh.
        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
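For reference, a minimal sketch of what a matching `_cp_plan` entry in transformer_wan.py could look like once `expected_dims` is honored; this is an assumption based on the snippets above, not necessarily the merged change:

```python
# Sketch only: possible _cp_plan entry for the Wan transformer.
# expected_dims=2 means the split is applied only when `timestep` is 2-D
# (e.g. [batch_size, seq_len] in the 5B / expand_timesteps case); the 14B
# model's 1-D timestep is left unsplit instead of erroring out.
_cp_plan = {
    "": {
        "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
    },
    # ... existing entries (hidden_states, etc.) stay unchanged ...
}
```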
Thanks! Could you please explain the changes and also provide an example output?
Fix the crash when testing CP for Wan2.2-TI2V-5B.
test script:
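A minimal sketch of what a 2-GPU repro could look like, assuming the experimental `ContextParallelConfig` / `enable_parallelism` API on current diffusers main; the model id comes from the PR title, while the prompt, frame count, and other arguments are placeholders rather than the author's actual script:

```python
# Sketch only: assumes the experimental context-parallel API on diffusers main.
# Launch with: torchrun --nproc-per-node 2 test.py
import torch
from diffusers import ContextParallelConfig, WanPipeline
from diffusers.utils import export_to_video

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers", torch_dtype=torch.bfloat16
).to(device)

# Shard attention over the sequence dimension across the two ranks.
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

video = pipe(
    prompt="A cat walks on the grass, realistic style.",
    num_frames=81,
    num_inference_steps=30,
).frames[0]

if rank == 0:
    export_to_video(video, "output.mp4", fps=16)
torch.distributed.destroy_process_group()
```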