diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 915fe453b90b..6491d17b4f46 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -203,10 +203,12 @@ def post_forward(self, module, output):
 
     def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
         if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
-            raise ValueError(
-                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
+            logger.warning_once(
+                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
             )
-        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+            return x
+        else:
+            return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
 
 
 class ContextParallelGatherHook(ModelHook):
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index dd75fb124f1a..6f3993eb3f64 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -555,6 +555,9 @@ class WanTransformer3DModel(
             "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        "": {
+            "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+        },
     }
 
     @register_to_config
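
To make the behavioral change concrete, here is a minimal, runnable sketch, not the diffusers implementation: `shard`, `WORLD_SIZE`, `RANK`, and `prepare_cp_input` are illustrative stand-ins for `EquipartitionSharder.shard`, the flattened device mesh, and the hook method, and plain `logging.warning` stands in for diffusers' `logger.warning_once`. A rank mismatch now logs a warning and returns the tensor unsplit instead of raising, which is presumably what allows the new `""` plan entry to split a 2-D `timestep` (e.g., Wan's per-frame variant) along dim 1 while letting a 1-D `timestep` pass through untouched.

```python
# Sketch of the new control flow in _prepare_cp_input, assuming a world size of 2.
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)
WORLD_SIZE = 2  # stand-in for the flattened mesh size
RANK = 0        # stand-in for this process's coordinate in the mesh


def shard(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Equipartition: keep this rank's equal slice of `x` along `dim`.
    return x.tensor_split(WORLD_SIZE, dim=dim)[RANK]


def prepare_cp_input(x: torch.Tensor, split_dim: int, expected_dims: Optional[int]) -> torch.Tensor:
    if expected_dims is not None and x.dim() != expected_dims:
        # Old behavior: raise ValueError. New behavior: warn and pass through unsplit.
        logger.warning(
            f"Expected input tensor to have {expected_dims} dimensions, "
            f"but got {x.dim()} dimensions, split will not be applied."
        )
        return x
    else:
        return shard(x, split_dim)


# A 2-D timestep of shape (batch, frames) matches the new plan entry
# (split_dim=1, expected_dims=2) and is split across ranks ...
timestep = torch.zeros(1, 8)
print(prepare_cp_input(timestep, split_dim=1, expected_dims=2).shape)  # torch.Size([1, 4])
# ... while a 1-D timestep of shape (batch,) now passes through unsplit instead of raising.
print(prepare_cp_input(torch.zeros(1), split_dim=1, expected_dims=2).shape)  # torch.Size([1])
```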