From f48d4d269b8da8384d9b95fd78bad995e3baa6ad Mon Sep 17 00:00:00 2001 From: Mathias Scherman Date: Sun, 2 Nov 2025 12:51:48 +0000 Subject: [PATCH 1/4] Fix WanVideoToVideoPipeline to conditionally handle VACE control inputs --- .../pipelines/wan/pipeline_wan_video2video.py | 65 +++++++++-- .../pipelines/wan/test_wan_video_to_video.py | 106 ++++++++++++++++++ 2 files changed, 162 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index a976126da7fe..4b8ffed1e2e2 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -502,6 +502,8 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + control_hidden_states: Optional[torch.Tensor] = None, + control_hidden_states_scale: Optional[torch.Tensor] = None, ): r""" The call function to the pipeline for generation. @@ -559,6 +561,14 @@ def __call__( max_sequence_length (`int`, defaults to `512`): The maximum sequence length of the text encoder. If the prompt is longer than this, it will be truncated. If the prompt is shorter, it will be padded to this length. + control_hidden_states (`torch.Tensor`, *optional*): + Control tensor for the VACE control path. Expected shape: + `(B, C, T_patch, H_patch, W_patch)`. If not provided, a neutral zero tensor of the correct + size and dtype is automatically created. + **Note:** This argument was added to prevent crashes when running without control features. + control_hidden_states_scale (`torch.Tensor`, *optional*): + A 1D tensor of scaling factors (length = number of VACE layers). Defaults to a vector of + ones if not provided. Examples: @@ -593,6 +603,11 @@ def __call__( self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False + base_tr = self.transformer.get_base_model() if hasattr(self.transformer, "get_base_model") else self.transformer + _sig = inspect.signature(base_tr.forward) + _supports_control = ( + "control_hidden_states" in _sig.parameters and "control_hidden_states_scale" in _sig.parameters + ) device = self._execution_device @@ -647,6 +662,27 @@ def __call__( latent_timestep, ) + # Precompute shapes/dtypes/devices we’ll need + B = batch_size * num_videos_per_prompt + + # Optionally build neutral control tensors if supported + if _supports_control: + cfg_tr = self.transformer.config # FrozenDict-like + C_ctrl = cfg_tr.get("vace_in_channels", cfg_tr.get("out_channels", cfg_tr.get("in_channels", 320))) + ps = cfg_tr.get("patch_size", (1, 1, 1)) + if isinstance(ps, int): + pt = ph = pw = ps + else: + pt, ph, pw = (ps[0], ps[1], ps[2]) if len(ps) == 3 else (ps[0], ps[0], ps[0]) + + if control_hidden_states is None: + control_hidden_states = torch.zeros((B, int(C_ctrl), int(pt), int(ph), int(pw)), + device=device, dtype=transformer_dtype) + if control_hidden_states_scale is None: + vls = cfg_tr.get("vace_layers", []) + n_layers = len(vls) if isinstance(vls, (list, tuple)) else int(vls or 0) + control_hidden_states_scale = torch.ones(max(1, n_layers), device=device, dtype=transformer_dtype) + # 6. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -660,22 +696,33 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( + # Build call kwargs + call_kwargs = dict( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, - )[0] + ) + # If supported, attach control tensors; ensure correct batch/device/dtype + if _supports_control: + if control_hidden_states.shape[0] != latent_model_input.shape[0]: + control_hidden_states = control_hidden_states.expand(latent_model_input.shape[0], -1, -1, -1, -1) + call_kwargs["control_hidden_states"] = control_hidden_states.to( + device=latent_model_input.device, dtype=transformer_dtype + ) + call_kwargs["control_hidden_states_scale"] = control_hidden_states_scale.to( + device=latent_model_input.device, dtype=transformer_dtype + ) + + # Cond pass + noise_pred = self.transformer(**call_kwargs)[0] + if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + # Uncond pass: swap encoder_hidden_states; keep control kwargs identical + call_kwargs["encoder_hidden_states"] = negative_prompt_embeds + noise_uncond = self.transformer(**call_kwargs)[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py index 27ada121ca48..19713ac38e70 100644 --- a/tests/pipelines/wan/test_wan_video_to_video.py +++ b/tests/pipelines/wan/test_wan_video_to_video.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import unittest import torch @@ -50,6 +51,12 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False supports_dduf = False + def _supports_control_kwargs(self, transformer) -> bool: + """Return True if the base transformer's forward() accepts VACE control kwargs.""" + base = transformer.get_base_model() if hasattr(transformer, "get_base_model") else transformer + sig = inspect.signature(base.forward) + return "control_hidden_states" in sig.parameters and "control_hidden_states_scale" in sig.parameters + def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKLWan( @@ -147,3 +154,102 @@ def test_float16_inference(self): ) def test_save_load_float16(self): pass + + def test_neutral_control_injection_no_crash_latent(self): + device = "cpu" + + # Reuse the same tiny components + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + + # If transformer doesn't support control kwargs, this test isn't applicable. 
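        # A minimal sketch of that capability check (the same pattern used by the pipeline and by the
        # `_supports_control_kwargs` helper above), assuming only the standard-library `inspect` module:
        #
        #     base = transformer.get_base_model() if hasattr(transformer, "get_base_model") else transformer
        #     params = inspect.signature(base.forward).parameters
        #     supports_vace = "control_hidden_states" in params and "control_hidden_states_scale" in params
        #
        # If either kwarg is missing from forward(), the pipeline falls back to the plain call path.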
+ if not self._supports_control_kwargs(pipe.transformer): + self.skipTest("Transformer doesn't accept VACE control kwargs; skipping control injection test.") + + # --- Ensure VACE fields exist for control tensor sizing --- + # Prefer real module in_channels if present + pe = getattr(pipe.transformer, "vace_patch_embedding", None) + if pe is not None and hasattr(pe, "in_channels"): + vace_in = int(pe.in_channels) + else: + # fallback to model config fields + vace_in = int(getattr(pipe.transformer.config, "vace_in_channels", pipe.transformer.config.in_channels)) + # also set it to help the pipeline code path + pipe.transformer.config.vace_in_channels = vace_in + + # vace_layers: ensure non-empty so scale vector has length >=1 + if not hasattr(pipe.transformer.config, "vace_layers"): + pipe.transformer.config.vace_layers = [0, 1] + + # Patch: we run in latent mode; skip VAE decode & video preprocessing + # Build tiny latents matching transformer.config.in_channels + C = int(pipe.transformer.config.in_channels) + # Very small T/H/W to keep speed + latents = torch.zeros((1, C, 2, 8, 8), device=device, dtype=torch.float32) + + out = pipe( + video=None, + prompt="test", + negative_prompt=None, + height=16, + width=16, + num_inference_steps=2, + guidance_scale=1.0, # disable CFG branch to keep path minimal + strength=0.5, + generator=None, + latents=latents, # <- latent path, so we don’t need real VAE/video_processor + prompt_embeds=None, + negative_prompt_embeds=None, + output_type="latent", # <- prevents decode/postprocess + return_dict=True, + max_sequence_length=16, + ).frames + + # Assert: no crash and the latent shape is preserved + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(tuple(out.shape), tuple(latents.shape)) + + def test_neutral_control_injection_with_cfg(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + + if not self._supports_control_kwargs(pipe.transformer): + self.skipTest("Transformer doesn't accept VACE control kwargs; skipping control+CFG test.") + + # Ensure VACE sizing hints exist (as above) + pe = getattr(pipe.transformer, "vace_patch_embedding", None) + if pe is not None and hasattr(pe, "in_channels"): + vace_in = int(pe.in_channels) + else: + vace_in = int(getattr(pipe.transformer.config, "vace_in_channels", pipe.transformer.config.in_channels)) + pipe.transformer.config.vace_in_channels = vace_in + if not hasattr(pipe.transformer.config, "vace_layers"): + pipe.transformer.config.vace_layers = [0, 1, 2] + + C = int(pipe.transformer.config.in_channels) + latents = torch.zeros((1, C, 2, 8, 8), device=device, dtype=torch.float32) + + out = pipe( + video=None, + prompt="test", + negative_prompt="", + height=16, + width=16, + num_inference_steps=2, + guidance_scale=3.5, # trigger CFG (uncond) path + strength=0.5, + generator=None, + latents=latents, + prompt_embeds=None, + negative_prompt_embeds=None, + output_type="latent", + return_dict=True, + max_sequence_length=16, + ).frames + + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(tuple(out.shape), tuple(latents.shape)) From fc5634fece7546ac0a7ac6f1a20be0fc4ca1a57c Mon Sep 17 00:00:00 2001 From: Mathias Scherman Date: Sun, 2 Nov 2025 12:59:43 +0000 Subject: [PATCH 2/4] style --- .../pipelines/wan/pipeline_wan_video2video.py | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py 
b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index 4b8ffed1e2e2..7bf8a2c5d780 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -562,13 +562,12 @@ def __call__( The maximum sequence length of the text encoder. If the prompt is longer than this, it will be truncated. If the prompt is shorter, it will be padded to this length. control_hidden_states (`torch.Tensor`, *optional*): - Control tensor for the VACE control path. Expected shape: - `(B, C, T_patch, H_patch, W_patch)`. If not provided, a neutral zero tensor of the correct - size and dtype is automatically created. - **Note:** This argument was added to prevent crashes when running without control features. + Control tensor for the VACE control path. Expected shape: `(B, C, T_patch, H_patch, W_patch)`. If not + provided, a neutral zero tensor of the correct size and dtype is automatically created. **Note:** This + argument was added to prevent crashes when running without control features. control_hidden_states_scale (`torch.Tensor`, *optional*): - A 1D tensor of scaling factors (length = number of VACE layers). Defaults to a vector of - ones if not provided. + A 1D tensor of scaling factors (length = number of VACE layers). Defaults to a vector of ones if not + provided. Examples: @@ -603,7 +602,9 @@ def __call__( self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False - base_tr = self.transformer.get_base_model() if hasattr(self.transformer, "get_base_model") else self.transformer + base_tr = ( + self.transformer.get_base_model() if hasattr(self.transformer, "get_base_model") else self.transformer + ) _sig = inspect.signature(base_tr.forward) _supports_control = ( "control_hidden_states" in _sig.parameters and "control_hidden_states_scale" in _sig.parameters @@ -676,8 +677,9 @@ def __call__( pt, ph, pw = (ps[0], ps[1], ps[2]) if len(ps) == 3 else (ps[0], ps[0], ps[0]) if control_hidden_states is None: - control_hidden_states = torch.zeros((B, int(C_ctrl), int(pt), int(ph), int(pw)), - device=device, dtype=transformer_dtype) + control_hidden_states = torch.zeros( + (B, int(C_ctrl), int(pt), int(ph), int(pw)), device=device, dtype=transformer_dtype + ) if control_hidden_states_scale is None: vls = cfg_tr.get("vace_layers", []) n_layers = len(vls) if isinstance(vls, (list, tuple)) else int(vls or 0) @@ -697,18 +699,20 @@ def __call__( timestep = t.expand(latents.shape[0]) # Build call kwargs - call_kwargs = dict( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - ) + call_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + } # If supported, attach control tensors; ensure correct batch/device/dtype if _supports_control: if control_hidden_states.shape[0] != latent_model_input.shape[0]: - control_hidden_states = control_hidden_states.expand(latent_model_input.shape[0], -1, -1, -1, -1) + control_hidden_states = control_hidden_states.expand( + latent_model_input.shape[0], -1, -1, -1, -1 + ) call_kwargs["control_hidden_states"] = control_hidden_states.to( device=latent_model_input.device, dtype=transformer_dtype ) @@ -718,7 +722,7 @@ def __call__( # Cond pass noise_pred = self.transformer(**call_kwargs)[0] - + if self.do_classifier_free_guidance: # Uncond pass: swap 
encoder_hidden_states; keep control kwargs identical call_kwargs["encoder_hidden_states"] = negative_prompt_embeds From 565650b2ae65f6f2ef6b8da9ff1e6e790a6aa019 Mon Sep 17 00:00:00 2001 From: Mathias Scherman Date: Sun, 2 Nov 2025 13:30:57 +0000 Subject: [PATCH 3/4] Add warnings, improve docstring --- .../pipelines/wan/pipeline_wan_video2video.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index 7bf8a2c5d780..e465cfed18bd 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -14,6 +14,7 @@ import html import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import regex as re @@ -562,12 +563,11 @@ def __call__( The maximum sequence length of the text encoder. If the prompt is longer than this, it will be truncated. If the prompt is shorter, it will be padded to this length. control_hidden_states (`torch.Tensor`, *optional*): - Control tensor for the VACE control path. Expected shape: `(B, C, T_patch, H_patch, W_patch)`. If not - provided, a neutral zero tensor of the correct size and dtype is automatically created. **Note:** This - argument was added to prevent crashes when running without control features. + Control tensor for the VACE control path. Shape: `(B, C, T_patch, H_patch, W_patch)`. If omitted, a neutral + zero tensor of the correct size/dtype is created automatically. **If the underlying transformer does not support + these kwargs, this argument is ignored.** control_hidden_states_scale (`torch.Tensor`, *optional*): - A 1D tensor of scaling factors (length = number of VACE layers). Defaults to a vector of ones if not - provided. + 1D tensor of scaling factors (length = number of VACE layers). Defaults to ones. **Ignored if unsupported.** Examples: @@ -609,6 +609,12 @@ def __call__( _supports_control = ( "control_hidden_states" in _sig.parameters and "control_hidden_states_scale" in _sig.parameters ) + if not _supports_control and (control_hidden_states is not None or control_hidden_states_scale is not None): + warnings.warn( + "control_hidden_states/control_hidden_states_scale were provided, but the underlying transformer " + "does not accept these kwargs; they will be ignored.", + stacklevel=2, + ) device = self._execution_device From bae6994ce32d76b6a18392d454ea83cae54c57ed Mon Sep 17 00:00:00 2001 From: Mathias Scherman Date: Sun, 2 Nov 2025 13:33:40 +0000 Subject: [PATCH 4/4] docs --- .../pipelines/wan/pipeline_wan_video2video.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index e465cfed18bd..92fe8558b7a7 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -563,11 +563,12 @@ def __call__( The maximum sequence length of the text encoder. If the prompt is longer than this, it will be truncated. If the prompt is shorter, it will be padded to this length. control_hidden_states (`torch.Tensor`, *optional*): - Control tensor for the VACE control path. Shape: `(B, C, T_patch, H_patch, W_patch)`. If omitted, a neutral - zero tensor of the correct size/dtype is created automatically. 
**If the underlying transformer does not support - these kwargs, this argument is ignored.** + Control tensor for the VACE control path. Shape: `(B, C, T_patch, H_patch, W_patch)`. If omitted, a + neutral zero tensor of the correct size/dtype is created automatically. **If the underlying transformer + does not support these kwargs, this argument is ignored.** control_hidden_states_scale (`torch.Tensor`, *optional*): - 1D tensor of scaling factors (length = number of VACE layers). Defaults to ones. **Ignored if unsupported.** + 1D tensor of scaling factors for VACE layers (length = number of VACE layers). Defaults to ones. + **Ignored if unsupported.** Examples: @@ -609,6 +610,7 @@ def __call__( _supports_control = ( "control_hidden_states" in _sig.parameters and "control_hidden_states_scale" in _sig.parameters ) + # Warn if user passed control kwargs but model won't consume them if not _supports_control and (control_hidden_states is not None or control_hidden_states_scale is not None): warnings.warn( "control_hidden_states/control_hidden_states_scale were provided, but the underlying transformer " @@ -669,10 +671,10 @@ def __call__( latent_timestep, ) - # Precompute shapes/dtypes/devices we’ll need + # Precompute shapes we’ll need B = batch_size * num_videos_per_prompt - # Optionally build neutral control tensors if supported + # Build neutral control tensors only if the base transformer supports them if _supports_control: cfg_tr = self.transformer.config # FrozenDict-like C_ctrl = cfg_tr.get("vace_in_channels", cfg_tr.get("out_channels", cfg_tr.get("in_channels", 320))) @@ -682,10 +684,12 @@ def __call__( else: pt, ph, pw = (ps[0], ps[1], ps[2]) if len(ps) == 3 else (ps[0], ps[0], ps[0]) + # On first use, create neutral one-token control if control_hidden_states is None: control_hidden_states = torch.zeros( (B, int(C_ctrl), int(pt), int(ph), int(pw)), device=device, dtype=transformer_dtype ) + # Layer-wise scale vector (not batched) if control_hidden_states_scale is None: vls = cfg_tr.get("vace_layers", []) n_layers = len(vls) if isinstance(vls, (list, tuple)) else int(vls or 0) @@ -704,7 +708,7 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - # Build call kwargs + # Prepare kwargs for transformer call; keep identical for cond/uncond (swap only encoder_hidden_states) call_kwargs = { "hidden_states": latent_model_input, "timestep": timestep, @@ -713,7 +717,7 @@ def __call__( "return_dict": False, } - # If supported, attach control tensors; ensure correct batch/device/dtype + # If supported, attach control tensors; ensure batch/device/dtype match latent input if _supports_control: if control_hidden_states.shape[0] != latent_model_input.shape[0]: control_hidden_states = control_hidden_states.expand(