diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index a976126da7fe..92fe8558b7a7 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -14,6 +14,7 @@ import html import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import regex as re @@ -502,6 +503,8 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + control_hidden_states: Optional[torch.Tensor] = None, + control_hidden_states_scale: Optional[torch.Tensor] = None, ): r""" The call function to the pipeline for generation. @@ -559,6 +562,13 @@ def __call__( max_sequence_length (`int`, defaults to `512`): The maximum sequence length of the text encoder. If the prompt is longer than this, it will be truncated. If the prompt is shorter, it will be padded to this length. + control_hidden_states (`torch.Tensor`, *optional*): + Control tensor for the VACE control path. Shape: `(B, C, T_patch, H_patch, W_patch)`. If omitted, a + neutral zero tensor of the correct size/dtype is created automatically. **If the underlying transformer + does not support these kwargs, this argument is ignored.** + control_hidden_states_scale (`torch.Tensor`, *optional*): + 1D tensor of scaling factors for VACE layers (length = number of VACE layers). Defaults to ones. + **Ignored if unsupported.** Examples: @@ -593,6 +603,20 @@ def __call__( self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False + base_tr = ( + self.transformer.get_base_model() if hasattr(self.transformer, "get_base_model") else self.transformer + ) + _sig = inspect.signature(base_tr.forward) + _supports_control = ( + "control_hidden_states" in _sig.parameters and "control_hidden_states_scale" in _sig.parameters + ) + # Warn if user passed control kwargs but model won't consume them + if not _supports_control and (control_hidden_states is not None or control_hidden_states_scale is not None): + warnings.warn( + "control_hidden_states/control_hidden_states_scale were provided, but the underlying transformer " + "does not accept these kwargs; they will be ignored.", + stacklevel=2, + ) device = self._execution_device @@ -647,6 +671,30 @@ def __call__( latent_timestep, ) + # Precompute shapes we’ll need + B = batch_size * num_videos_per_prompt + + # Build neutral control tensors only if the base transformer supports them + if _supports_control: + cfg_tr = self.transformer.config # FrozenDict-like + C_ctrl = cfg_tr.get("vace_in_channels", cfg_tr.get("out_channels", cfg_tr.get("in_channels", 320))) + ps = cfg_tr.get("patch_size", (1, 1, 1)) + if isinstance(ps, int): + pt = ph = pw = ps + else: + pt, ph, pw = (ps[0], ps[1], ps[2]) if len(ps) == 3 else (ps[0], ps[0], ps[0]) + + # On first use, create neutral one-token control + if control_hidden_states is None: + control_hidden_states = torch.zeros( + (B, int(C_ctrl), int(pt), int(ph), int(pw)), device=device, dtype=transformer_dtype + ) + # Layer-wise scale vector (not batched) + if control_hidden_states_scale is None: + vls = cfg_tr.get("vace_layers", []) + n_layers = len(vls) if isinstance(vls, (list, tuple)) else int(vls or 0) + control_hidden_states_scale = torch.ones(max(1, n_layers), device=device, dtype=transformer_dtype) + # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -660,22 +708,35 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + # Prepare kwargs for transformer call; keep identical for cond/uncond (swap only encoder_hidden_states) + call_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + } + + # If supported, attach control tensors; ensure batch/device/dtype match latent input + if _supports_control: + if control_hidden_states.shape[0] != latent_model_input.shape[0]: + control_hidden_states = control_hidden_states.expand( + latent_model_input.shape[0], -1, -1, -1, -1 + ) + call_kwargs["control_hidden_states"] = control_hidden_states.to( + device=latent_model_input.device, dtype=transformer_dtype + ) + call_kwargs["control_hidden_states_scale"] = control_hidden_states_scale.to( + device=latent_model_input.device, dtype=transformer_dtype + ) + + # Cond pass + noise_pred = self.transformer(**call_kwargs)[0] if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + # Uncond pass: swap encoder_hidden_states; keep control kwargs identical + call_kwargs["encoder_hidden_states"] = negative_prompt_embeds + noise_uncond = self.transformer(**call_kwargs)[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py index 27ada121ca48..19713ac38e70 100644 --- a/tests/pipelines/wan/test_wan_video_to_video.py +++ b/tests/pipelines/wan/test_wan_video_to_video.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import unittest import torch @@ -50,6 +51,12 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False supports_dduf = False + def _supports_control_kwargs(self, transformer) -> bool: + """Return True if the base transformer's forward() accepts VACE control kwargs.""" + base = transformer.get_base_model() if hasattr(transformer, "get_base_model") else transformer + sig = inspect.signature(base.forward) + return "control_hidden_states" in sig.parameters and "control_hidden_states_scale" in sig.parameters + def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKLWan( @@ -147,3 +154,102 @@ def test_float16_inference(self): ) def test_save_load_float16(self): pass + + def test_neutral_control_injection_no_crash_latent(self): + device = "cpu" + + # Reuse the same tiny components + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + + # If transformer doesn't support control kwargs, this test isn't applicable. + if not self._supports_control_kwargs(pipe.transformer): + self.skipTest("Transformer doesn't accept VACE control kwargs; skipping control injection test.") + + # --- Ensure VACE fields exist for control tensor sizing --- + # Prefer real module in_channels if present + pe = getattr(pipe.transformer, "vace_patch_embedding", None) + if pe is not None and hasattr(pe, "in_channels"): + vace_in = int(pe.in_channels) + else: + # fallback to model config fields + vace_in = int(getattr(pipe.transformer.config, "vace_in_channels", pipe.transformer.config.in_channels)) + # also set it to help the pipeline code path + pipe.transformer.config.vace_in_channels = vace_in + + # vace_layers: ensure non-empty so scale vector has length >=1 + if not hasattr(pipe.transformer.config, "vace_layers"): + pipe.transformer.config.vace_layers = [0, 1] + + # Patch: we run in latent mode; skip VAE decode & video preprocessing + # Build tiny latents matching transformer.config.in_channels + C = int(pipe.transformer.config.in_channels) + # Very small T/H/W to keep speed + latents = torch.zeros((1, C, 2, 8, 8), device=device, dtype=torch.float32) + + out = pipe( + video=None, + prompt="test", + negative_prompt=None, + height=16, + width=16, + num_inference_steps=2, + guidance_scale=1.0, # disable CFG branch to keep path minimal + strength=0.5, + generator=None, + latents=latents, # <- latent path, so we don’t need real VAE/video_processor + prompt_embeds=None, + negative_prompt_embeds=None, + output_type="latent", # <- prevents decode/postprocess + return_dict=True, + max_sequence_length=16, + ).frames + + # Assert: no crash and the latent shape is preserved + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(tuple(out.shape), tuple(latents.shape)) + + def test_neutral_control_injection_with_cfg(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + + if not self._supports_control_kwargs(pipe.transformer): + self.skipTest("Transformer doesn't accept VACE control kwargs; skipping control+CFG test.") + + # Ensure VACE sizing hints exist (as above) + pe = getattr(pipe.transformer, "vace_patch_embedding", None) + if pe is not None and hasattr(pe, "in_channels"): + vace_in = int(pe.in_channels) + else: + vace_in = int(getattr(pipe.transformer.config, "vace_in_channels", pipe.transformer.config.in_channels)) + pipe.transformer.config.vace_in_channels = vace_in + if not hasattr(pipe.transformer.config, "vace_layers"): + pipe.transformer.config.vace_layers = [0, 1, 2] + + C = int(pipe.transformer.config.in_channels) + latents = torch.zeros((1, C, 2, 8, 8), device=device, dtype=torch.float32) + + out = pipe( + video=None, + prompt="test", + negative_prompt="", + height=16, + width=16, + num_inference_steps=2, + guidance_scale=3.5, # trigger CFG (uncond) path + strength=0.5, + generator=None, + latents=latents, + prompt_embeds=None, + negative_prompt_embeds=None, + output_type="latent", + return_dict=True, + max_sequence_length=16, + ).frames + + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(tuple(out.shape), tuple(latents.shape))