-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[modular] add tests for qwen modular #12585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e35962b
94fa202
7d3c250
7334520
92d7977
7ad48f0
5821dd9
d7a887b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -132,6 +132,7 @@ def expected_components(self) -> List[ComponentSpec]: | |||||||||||||
| @property | ||||||||||||||
| def inputs(self) -> List[InputParam]: | ||||||||||||||
| return [ | ||||||||||||||
| InputParam("latents"), | ||||||||||||||
| InputParam(name="height"), | ||||||||||||||
| InputParam(name="width"), | ||||||||||||||
| InputParam(name="num_images_per_prompt", default=1), | ||||||||||||||
|
|
@@ -196,11 +197,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - | |||||||||||||
| f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" | ||||||||||||||
| f" size of {batch_size}. Make sure the batch size matches the length of the generators." | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| block_state.latents = randn_tensor( | ||||||||||||||
| shape, generator=block_state.generator, device=device, dtype=block_state.dtype | ||||||||||||||
| ) | ||||||||||||||
| block_state.latents = components.pachifier.pack_latents(block_state.latents) | ||||||||||||||
| if block_state.latents is None: | ||||||||||||||
| block_state.latents = randn_tensor( | ||||||||||||||
| shape, generator=block_state.generator, device=device, dtype=block_state.dtype | ||||||||||||||
| ) | ||||||||||||||
| block_state.latents = components.pachifier.pack_latents(block_state.latents) | ||||||||||||||
|
|
||||||||||||||
| self.set_block_state(state, block_state) | ||||||||||||||
| return components, state | ||||||||||||||
|
|
@@ -549,7 +550,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - | |||||||||||||
| block_state.width // components.vae_scale_factor // 2, | ||||||||||||||
| ) | ||||||||||||||
| ] | ||||||||||||||
| * block_state.batch_size | ||||||||||||||
| for _ in range(block_state.batch_size) | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have two options:
Regardless, the current implementation isn't exactly the same as how the standard pipeline implements it and would break for the batched input tests we have. |
||||||||||||||
| ] | ||||||||||||||
| block_state.txt_seq_lens = ( | ||||||||||||||
| block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -74,8 +74,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - | |
| block_state = self.get_block_state(state) | ||
|
|
||
| # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular | ||
| vae_scale_factor = 2 ** len(components.vae.temperal_downsample) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keeping |
||
| block_state.latents = components.pachifier.unpack_latents( | ||
| block_state.latents, block_state.height, block_state.width | ||
| block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor | ||
| ) | ||
| block_state.latents = block_state.latents.to(components.vae.dtype) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -503,6 +503,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): | |
| block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length] | ||
| block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length] | ||
|
|
||
| block_state.negative_prompt_embeds = None | ||
| block_state.negative_prompt_embeds_mask = None | ||
|
Comment on lines
+506
to
+507
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Otherwise, no CFG settings would break. |
||
| if components.requires_unconditional_embeds: | ||
| negative_prompt = block_state.negative_prompt or "" | ||
| block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds( | ||
|
|
@@ -627,6 +629,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): | |
| device=device, | ||
| ) | ||
|
|
||
| block_state.negative_prompt_embeds = None | ||
| block_state.negative_prompt_embeds_mask = None | ||
| if components.requires_unconditional_embeds: | ||
| negative_prompt = block_state.negative_prompt or " " | ||
| block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit( | ||
|
|
@@ -679,6 +683,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): | |
| device=device, | ||
| ) | ||
|
|
||
| block_state.negative_prompt_embeds = None | ||
| block_state.negative_prompt_embeds_mask = None | ||
| if components.requires_unconditional_embeds: | ||
| negative_prompt = block_state.negative_prompt or " " | ||
| block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = ( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,145 @@ | ||||||||||
| # coding=utf-8 | ||||||||||
| # Copyright 2025 HuggingFace Inc. | ||||||||||
| # | ||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||
| # You may obtain a copy of the License at | ||||||||||
| # | ||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||
| # | ||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||
| # See the License for the specific language governing permissions and | ||||||||||
| # limitations under the License. | ||||||||||
|
|
||||||||||
| import unittest | ||||||||||
|
|
||||||||||
| import numpy as np | ||||||||||
| import PIL | ||||||||||
| import pytest | ||||||||||
| import torch | ||||||||||
|
|
||||||||||
| from diffusers import ClassifierFreeGuidance | ||||||||||
| from diffusers.modular_pipelines import ( | ||||||||||
| QwenImageAutoBlocks, | ||||||||||
| QwenImageEditAutoBlocks, | ||||||||||
| QwenImageEditModularPipeline, | ||||||||||
| QwenImageEditPlusAutoBlocks, | ||||||||||
| QwenImageEditPlusModularPipeline, | ||||||||||
| QwenImageModularPipeline, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| from ...testing_utils import torch_device | ||||||||||
| from ..test_modular_pipelines_common import ModularPipelineTesterMixin | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class QwenImageModularTests: | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| pipeline_class = QwenImageModularPipeline | ||||||||||
| pipeline_blocks_class = QwenImageAutoBlocks | ||||||||||
| repo = "hf-internal-testing/tiny-qwenimage-modular" | ||||||||||
|
|
||||||||||
| params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"]) | ||||||||||
| batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"]) | ||||||||||
|
Comment on lines
+38
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's move these into the actual test object e.g |
||||||||||
|
|
||||||||||
| def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can just use the one defined in |
||||||||||
| pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager) | ||||||||||
| pipeline.load_components(torch_dtype=torch_dtype) | ||||||||||
| pipeline.set_progress_bar_config(disable=None) | ||||||||||
| return pipeline | ||||||||||
|
|
||||||||||
| def get_dummy_inputs(self, device, seed=0): | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can move into actual test object e.g. |
||||||||||
| if str(device).startswith("mps"): | ||||||||||
| generator = torch.manual_seed(seed) | ||||||||||
| else: | ||||||||||
| generator = torch.Generator(device=device).manual_seed(seed) | ||||||||||
| inputs = { | ||||||||||
| "prompt": "dance monkey", | ||||||||||
| "negative_prompt": "bad quality", | ||||||||||
| "generator": generator, | ||||||||||
| "num_inference_steps": 2, | ||||||||||
| "height": 32, | ||||||||||
| "width": 32, | ||||||||||
| "max_sequence_length": 16, | ||||||||||
| "output_type": "np", | ||||||||||
| } | ||||||||||
| return inputs | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class QwenImageModularGuiderTests: | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| def test_guider_cfg(self, tol=1e-2): | ||||||||||
| pipe = self.get_pipeline() | ||||||||||
| pipe = pipe.to(torch_device) | ||||||||||
|
|
||||||||||
| guider = ClassifierFreeGuidance(guidance_scale=1.0) | ||||||||||
| pipe.update_components(guider=guider) | ||||||||||
|
|
||||||||||
| inputs = self.get_dummy_inputs(torch_device) | ||||||||||
| out_no_cfg = pipe(**inputs, output="images") | ||||||||||
|
|
||||||||||
| guider = ClassifierFreeGuidance(guidance_scale=7.5) | ||||||||||
| pipe.update_components(guider=guider) | ||||||||||
| inputs = self.get_dummy_inputs(torch_device) | ||||||||||
| out_cfg = pipe(**inputs, output="images") | ||||||||||
|
|
||||||||||
| assert out_cfg.shape == out_no_cfg.shape | ||||||||||
| max_diff = np.abs(out_cfg - out_no_cfg).max() | ||||||||||
| assert max_diff > tol, "Output with CFG must be different from normal inference" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class QwenImageModularPipelineFastTests( | ||||||||||
| QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase | ||||||||||
|
Comment on lines
+90
to
+91
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unittest
Suggested change
|
||||||||||
| ): | ||||||||||
| def __init__(self, *args, **kwargs): | ||||||||||
| super().__init__(*args, **kwargs) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class QwenImageEditModularPipelineFastTests( | ||||||||||
| QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase | ||||||||||
| ): | ||||||||||
| pipeline_class = QwenImageEditModularPipeline | ||||||||||
| pipeline_blocks_class = QwenImageEditAutoBlocks | ||||||||||
| repo = "hf-internal-testing/tiny-qwenimage-edit-modular" | ||||||||||
|
|
||||||||||
| def get_dummy_inputs(self, device, seed=0): | ||||||||||
| inputs = super().get_dummy_inputs(device, seed) | ||||||||||
| inputs.pop("max_sequence_length") | ||||||||||
| inputs["image"] = PIL.Image.new("RGB", (32, 32), 0) | ||||||||||
| return inputs | ||||||||||
|
|
||||||||||
| def test_guider_cfg(self): | ||||||||||
| super().test_guider_cfg(7e-5) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class QwenImageEditPlusModularPipelineFastTests( | ||||||||||
| QwenImageModularTests, QwenImageModularGuiderTests, ModularPipelineTesterMixin, unittest.TestCase | ||||||||||
| ): | ||||||||||
| pipeline_class = QwenImageEditPlusModularPipeline | ||||||||||
| pipeline_blocks_class = QwenImageEditPlusAutoBlocks | ||||||||||
| repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular" | ||||||||||
|
|
||||||||||
| # No `mask_image` yet. | ||||||||||
| params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"]) | ||||||||||
| batch_params = frozenset(["prompt", "negative_prompt", "image"]) | ||||||||||
|
|
||||||||||
| def get_dummy_inputs(self, device, seed=0): | ||||||||||
| inputs = super().get_dummy_inputs(device, seed) | ||||||||||
| inputs.pop("max_sequence_length") | ||||||||||
| image = PIL.Image.new("RGB", (32, 32), 0) | ||||||||||
| inputs["image"] = [image] | ||||||||||
| return inputs | ||||||||||
|
|
||||||||||
| @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) | ||||||||||
| def test_num_images_per_prompt(self): | ||||||||||
| super().test_num_images_per_prompt() | ||||||||||
|
|
||||||||||
| @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) | ||||||||||
| def test_inference_batch_consistent(): | ||||||||||
| super().test_inference_batch_consistent() | ||||||||||
|
|
||||||||||
| @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) | ||||||||||
| def test_inference_batch_single_identical(): | ||||||||||
| super().test_inference_batch_single_identical() | ||||||||||
|
Comment on lines
+132
to
+142
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are skipped in the standard pipeline tests, too. |
||||||||||
|
|
||||||||||
| def test_guider_cfg(self): | ||||||||||
| super().test_guider_cfg(1e-3) | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to how it's done in the other pipelines.