From e35962b0decb95aec8de751da7768ee2d0aabaa4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 3 Nov 2025 18:42:47 +0530 Subject: [PATCH 1/4] add tests for qwenimage modular. --- .../qwenimage/before_denoise.py | 13 +-- .../modular_pipelines/qwenimage/decoders.py | 3 +- .../modular_pipelines/qwenimage/encoders.py | 2 + .../qwenimage/modular_pipeline.py | 5 +- tests/modular_pipelines/qwen/__init__.py | 0 .../qwen/test_modular_pipeline_qwenimage.py | 85 +++++++++++++++++++ 6 files changed, 97 insertions(+), 11 deletions(-) create mode 100644 tests/modular_pipelines/qwen/__init__.py create mode 100644 tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index fdec95dc506e..f10200503141 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -132,6 +132,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ + InputParam("latents"), InputParam(name="height"), InputParam(name="width"), InputParam(name="num_images_per_prompt", default=1), @@ -196,11 +197,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - - block_state.latents = randn_tensor( - shape, generator=block_state.generator, device=device, dtype=block_state.dtype - ) - block_state.latents = components.pachifier.pack_latents(block_state.latents) + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) self.set_block_state(state, block_state) return components, state @@ -549,7 +550,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - block_state.width // components.vae_scale_factor // 2, ) ] - * block_state.batch_size + for _ in range(block_state.batch_size) ] block_state.txt_seq_lens = ( block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py index 6c82fe989e55..aedb0e4018f3 100644 --- a/src/diffusers/modular_pipelines/qwenimage/decoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py @@ -74,8 +74,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - block_state = self.get_block_state(state) # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular + vae_scale_factor = 2 ** len(components.vae.temperal_downsample) block_state.latents = components.pachifier.unpack_latents( - block_state.latents, block_state.height, block_state.width + block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor ) block_state.latents = block_state.latents.to(components.vae.dtype) diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 04fb3fdc947b..b025c2dc5071 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -503,6 +503,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length] block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length] + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None if components.requires_unconditional_embeds: negative_prompt = block_state.negative_prompt or "" block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds( diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py index d9e30864f660..59e1a13a5db2 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py @@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin): config_name = "config.json" @register_to_config - def __init__( - self, - patch_size: int = 2, - ): + def __init__(self, patch_size: int = 2): super().__init__() def pack_latents(self, latents): diff --git a/tests/modular_pipelines/qwen/__init__.py b/tests/modular_pipelines/qwen/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py new file mode 100644 index 000000000000..024edf5f34b0 --- /dev/null +++ b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import ClassifierFreeGuidance +from diffusers.modular_pipelines import QwenImageAutoBlocks, QwenImageModularPipeline + +from ...testing_utils import torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class QwenImagexModularTests: + pipeline_class = QwenImageModularPipeline + pipeline_blocks_class = QwenImageAutoBlocks + repo = "hf-internal-testing/tiny-qwenimage-modular" + + params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"]) + batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"]) + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager) + pipeline.load_components(torch_dtype=torch_dtype) + pipeline.set_progress_bar_config(disable=None) + return pipeline + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "np", + } + return inputs + + +class QwenImageModularGuiderTests: + def test_guider_cfg(self): + pipe = self.get_pipeline() + pipe = pipe.to(torch_device) + + guider = ClassifierFreeGuidance(guidance_scale=1.0) + pipe.update_components(guider=guider) + + inputs = self.get_dummy_inputs(torch_device) + out_no_cfg = pipe(**inputs, output="images") + + guider = ClassifierFreeGuidance(guidance_scale=7.5) + pipe.update_components(guider=guider) + inputs = self.get_dummy_inputs(torch_device) + out_cfg = pipe(**inputs, output="images") + + assert out_cfg.shape == out_no_cfg.shape + max_diff = np.abs(out_cfg - out_no_cfg).max() + assert max_diff > 1e-2, "Output with CFG must be different from normal inference" + + +class QwenImageModularPipelineFastTests( + QwenImagexModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) From 94fa2029a43d7e6e07d208ec4e9ef6df25bcba65 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 4 Nov 2025 08:36:23 +0530 Subject: [PATCH 2/4] qwenimage edit. --- .../qwen/test_modular_pipeline_qwenimage.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py index 024edf5f34b0..adcd08e80ca5 100644 --- a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py +++ b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py @@ -16,16 +16,22 @@ import unittest import numpy as np +import PIL import torch from diffusers import ClassifierFreeGuidance -from diffusers.modular_pipelines import QwenImageAutoBlocks, QwenImageModularPipeline +from diffusers.modular_pipelines import ( + QwenImageAutoBlocks, + QwenImageEditAutoBlocks, + QwenImageEditModularPipeline, + QwenImageModularPipeline, +) from ...testing_utils import torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin -class QwenImagexModularTests: +class QwenImageModularTests: pipeline_class = QwenImageModularPipeline pipeline_blocks_class = QwenImageAutoBlocks repo = "hf-internal-testing/tiny-qwenimage-modular" @@ -79,7 +85,20 @@ def test_guider_cfg(self): class QwenImageModularPipelineFastTests( - QwenImagexModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase + QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase ): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + + +class QwenImageEditModularPipelineFastTests( + QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase +): + pipeline_class = QwenImageEditModularPipeline + pipeline_blocks_class = QwenImageEditAutoBlocks + repo = "hf-internal-testing/tiny-qwenimage-edit-modular" + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + inputs["image"] = PIL.Image.new("RGB", (32, 32), 0) + return inputs From 7d3c250722e0d9bc900e91e85217d8667204293a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 4 Nov 2025 10:05:34 +0530 Subject: [PATCH 3/4] qwenimage edit plus. --- .../modular_pipelines/qwenimage/encoders.py | 4 ++ .../qwen/test_modular_pipeline_qwenimage.py | 45 ++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index b025c2dc5071..3b56981e5290 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -629,6 +629,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): device=device, ) + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None if components.requires_unconditional_embeds: negative_prompt = block_state.negative_prompt or " " block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit( @@ -681,6 +683,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): device=device, ) + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None if components.requires_unconditional_embeds: negative_prompt = block_state.negative_prompt or " " block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = ( diff --git a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py index adcd08e80ca5..1a49fc222532 100644 --- a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py +++ b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py @@ -17,6 +17,7 @@ import numpy as np import PIL +import pytest import torch from diffusers import ClassifierFreeGuidance @@ -24,6 +25,8 @@ QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, + QwenImageEditPlusAutoBlocks, + QwenImageEditPlusModularPipeline, QwenImageModularPipeline, ) @@ -64,7 +67,7 @@ def get_dummy_inputs(self, device, seed=0): class QwenImageModularGuiderTests: - def test_guider_cfg(self): + def test_guider_cfg(self, tol=1e-2): pipe = self.get_pipeline() pipe = pipe.to(torch_device) @@ -81,7 +84,7 @@ def test_guider_cfg(self): assert out_cfg.shape == out_no_cfg.shape max_diff = np.abs(out_cfg - out_no_cfg).max() - assert max_diff > 1e-2, "Output with CFG must be different from normal inference" + assert max_diff > tol, "Output with CFG must be different from normal inference" class QwenImageModularPipelineFastTests( @@ -100,5 +103,43 @@ class QwenImageEditModularPipelineFastTests( def get_dummy_inputs(self, device, seed=0): inputs = super().get_dummy_inputs(device, seed) + inputs.pop("max_sequence_length") inputs["image"] = PIL.Image.new("RGB", (32, 32), 0) return inputs + + def test_guider_cfg(self): + super().test_guider_cfg(7e-5) + + +class QwenImageEditPlusModularPipelineFastTests( + QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase +): + pipeline_class = QwenImageEditPlusModularPipeline + pipeline_blocks_class = QwenImageEditPlusAutoBlocks + repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular" + + # No `mask_image` yet. + params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"]) + batch_params = frozenset(["prompt", "negative_prompt", "image"]) + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed) + inputs.pop("max_sequence_length") + image = PIL.Image.new("RGB", (32, 32), 0) + inputs["image"] = [image] + return inputs + + @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) + def test_num_images_per_prompt(self): + super().test_num_images_per_prompt() + + @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) + def test_inference_batch_consistent(): + super().test_inference_batch_consistent() + + @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) + def test_inference_batch_single_identical(): + super().test_inference_batch_single_identical() + + def test_guider_cfg(self): + super().test_guider_cfg(1e-3) From 7ad48f0c2755dd6f347912627cd0369affb81d64 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 4 Nov 2025 12:08:38 +0530 Subject: [PATCH 4/4] empty