From 760a9149a718d15032f4dff1a310c6fe16cd8b3c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 16:40:53 +0530 Subject: [PATCH 01/11] start custom block testing. --- .../test_modular_pipelines_custom_blocks.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/modular_pipelines/test_modular_pipelines_custom_blocks.py diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py new file mode 100644 index 000000000000..e0c1e8e495ea --- /dev/null +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -0,0 +1,78 @@ +from typing import List + +from diffusers import FluxTransformer2DModel +from diffusers.modular_pipelines import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState + + +class DummyCustomBlockSimple(ModularPipelineBlocks): + def __init__(self, use_dummy_model_component=False): + self.use_dummy_model_component = use_dummy_model_component + super().__init__() + + @property + def expected_components(self): + if self.use_dummy_model_component: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + else: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "output_prompt", + type_hint=str, + description="Modified prompt", + ) + ] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + old_prompt = block_state.prompt + block_state.output_prompt = "Modular diffusers + " + old_prompt + self.set_block_state(state, block_state) + + return components, state + + +class TestModularCustomBlocks: + def test_custom_block_properties(self): + custom_block = DummyCustomBlockSimple() + + assert not custom_block.expected_components + assert not custom_block.intermediate_inputs + + actual_inputs = [inp.name for inp in custom_block.inputs] + actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs] + assert actual_inputs == ["prompt"] + assert actual_intermediate_outputs == ["output_prompt"] + + def test_custom_block_output(self): + custom_block = DummyCustomBlockSimple() + pipeline = custom_block.init_pipeline() + prompt = "Diffusers is nice" + output = pipeline(prompt=prompt) + + actual_inputs = [inp.name for inp in custom_block.inputs] + actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs] + assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs) + + output_prompt = output.values["output_prompt"] + assert output_prompt.startswith("Modular diffusers + ") + + def test_custom_block_supported_components(self): + custom_block = DummyCustomBlockSimple(use_dummy_model_component=True) + pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe") + pipe.load_components() + + assert len(pipe.components) == 1 + assert pipe.component_names[0] == "transformer" From 77e50155e60ad423ac369d3bb1d7a7182e4ec273 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 16:43:39 +0530 Subject: [PATCH 02/11] simplify modular workflow ci. --- .github/workflows/pr_modular_tests.yml | 76 ++++++++++---------------- 1 file changed, 30 insertions(+), 46 deletions(-) diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml index 7081ee518d55..8b0580b81f86 100644 --- a/.github/workflows/pr_modular_tests.yml +++ b/.github/workflows/pr_modular_tests.yml @@ -77,62 +77,46 @@ jobs: run_fast_tests: needs: [check_code_quality, check_repository_consistency] - strategy: - fail-fast: false - matrix: - config: - - name: Fast PyTorch Modular Pipeline CPU tests - framework: pytorch_pipelines - runner: aws-highmemory-32-plus - image: diffusers/diffusers-pytorch-cpu - report: torch_cpu_modular_pipelines - - name: ${{ matrix.config.name }} - + name: Fast PyTorch Modular Pipeline CPU tests runs-on: - group: ${{ matrix.config.runner }} - + group: aws-highmemory-32-plus container: - image: ${{ matrix.config.image }} + image: diffusers/diffusers-pytorch-cpu options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ - defaults: run: shell: bash steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Install dependencies - run: | - uv pip install -e ".[quality]" - uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git - uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps - - - name: Environment - run: | - python utils/print_env.py + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 - - name: Run fast PyTorch Pipeline CPU tests - if: ${{ matrix.config.framework == 'pytorch_pipelines' }} - run: | - pytest -n 8 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ - --make-reports=tests_${{ matrix.config.report }} \ - tests/modular_pipelines + - name: Install dependencies + run: | + uv pip install -e ".[quality]" + uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git + uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps - - name: Failure short reports - if: ${{ failure() }} - run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt + - name: Environment + run: | + python utils/print_env.py - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v4 - with: - name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports - path: reports + - name: Run fast PyTorch Pipeline CPU tests + run: | + pytest -n 8 --max-worker-restart=0 --dist=loadfile \ + -s -v -k "not Flax and not Onnx" \ + --make-reports=tests_torch_cpu_modular_pipelines \ + tests/modular_pipelines + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports + path: reports From 1be88f036ff20ed25c72f6b7ed622e79ae1360c3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 17:03:02 +0530 Subject: [PATCH 03/11] up --- .../test_modular_pipelines_custom_blocks.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index e0c1e8e495ea..21643831699f 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -45,22 +45,24 @@ def __call__(self, components, state: PipelineState) -> PipelineState: class TestModularCustomBlocks: - def test_custom_block_properties(self): - custom_block = DummyCustomBlockSimple() - - assert not custom_block.expected_components - assert not custom_block.intermediate_inputs + def _test_block_properties(self, block): + assert not block.expected_components + assert not block.intermediate_inputs - actual_inputs = [inp.name for inp in custom_block.inputs] - actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs] + actual_inputs = [inp.name for inp in block.inputs] + actual_intermediate_outputs = [out.name for out in block.intermediate_outputs] assert actual_inputs == ["prompt"] assert actual_intermediate_outputs == ["output_prompt"] + def test_custom_block_properties(self): + custom_block = DummyCustomBlockSimple() + self._test_block_properties(custom_block) + def test_custom_block_output(self): custom_block = DummyCustomBlockSimple() - pipeline = custom_block.init_pipeline() + pipe = custom_block.init_pipeline() prompt = "Diffusers is nice" - output = pipeline(prompt=prompt) + output = pipe(prompt=prompt) actual_inputs = [inp.name for inp in custom_block.inputs] actual_intermediate_outputs = [out.name for out in custom_block.intermediate_outputs] @@ -76,3 +78,18 @@ def test_custom_block_supported_components(self): assert len(pipe.components) == 1 assert pipe.component_names[0] == "transformer" + + def test_custom_block_loads_from_hub(self): + repo_id = "hf-internal-testing/tiny-modular-diffusers-block" + block = ModularPipelineBlocks.from_pretrained( + repo_id, + trust_remote_code=True, + ) + self._test_block_properties(block) + + pipe = block.init_pipeline() + + prompt = "Diffusers is nice" + output = pipe(prompt=prompt) + output_prompt = output.values["output_prompt"] + assert output_prompt.startswith("Modular diffusers + ") From 316b71ff2b9196a3e0fda81bb148da4d236a509a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 17:03:34 +0530 Subject: [PATCH 04/11] style. --- .../modular_pipelines/test_modular_pipelines_custom_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 21643831699f..0d721b983783 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -86,9 +86,9 @@ def test_custom_block_loads_from_hub(self): trust_remote_code=True, ) self._test_block_properties(block) - + pipe = block.init_pipeline() - + prompt = "Diffusers is nice" output = pipe(prompt=prompt) output_prompt = output.values["output_prompt"] From ecdd84304464007aa97cfc08e66ef122dd2303ca Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 17:10:10 +0530 Subject: [PATCH 05/11] up --- .../test_modular_pipelines_custom_blocks.py | 34 ++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 0d721b983783..3d1bae152602 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -1,7 +1,18 @@ from typing import List +import torch + from diffusers import FluxTransformer2DModel -from diffusers.modular_pipelines import ComponentSpec, InputParam, ModularPipelineBlocks, OutputParam, PipelineState +from diffusers.modular_pipelines import ( + ComponentSpec, + InputParam, + ModularPipelineBlocks, + OutputParam, + PipelineState, + WanModularPipeline, +) + +from ..testing_utils import nightly, require_torch, slow class DummyCustomBlockSimple(ModularPipelineBlocks): @@ -81,10 +92,7 @@ def test_custom_block_supported_components(self): def test_custom_block_loads_from_hub(self): repo_id = "hf-internal-testing/tiny-modular-diffusers-block" - block = ModularPipelineBlocks.from_pretrained( - repo_id, - trust_remote_code=True, - ) + block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) self._test_block_properties(block) pipe = block.init_pipeline() @@ -93,3 +101,19 @@ def test_custom_block_loads_from_hub(self): output = pipe(prompt=prompt) output_prompt = output.values["output_prompt"] assert output_prompt.startswith("Modular diffusers + ") + + +@slow +@nightly +@require_torch +class TestModularCustomBlocksIntegration: + def test_krea_realtime_video_loading(self): + repo_id = "krea/krea-realtime-video" + blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) + + pipe = WanModularPipeline(blocks, repo_id) + pipe.load_components( + trust_remote_code=True, + device_map="cuda", + torch_dtype={"default": torch.bfloat16, "vae": torch.float16}, + ) From 5f1afc11ac8483bd08b1f0915cf11e7509209b77 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 18:19:07 +0530 Subject: [PATCH 06/11] up --- .../test_modular_pipelines_custom_blocks.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 3d1bae152602..ef4fe66c6f89 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -1,5 +1,7 @@ +from collections import deque from typing import List +import numpy as np import torch from diffusers import FluxTransformer2DModel @@ -110,6 +112,9 @@ class TestModularCustomBlocksIntegration: def test_krea_realtime_video_loading(self): repo_id = "krea/krea-realtime-video" blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) + block_names = sorted(blocks.sub_blocks) + + assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"]) pipe = WanModularPipeline(blocks, repo_id) pipe.load_components( @@ -117,3 +122,51 @@ def test_krea_realtime_video_loading(self): device_map="cuda", torch_dtype={"default": torch.bfloat16, "vae": torch.float16}, ) + assert len(pipe.components) == 7 + assert sorted(pipe.components) == sorted( + ["text_encoder", "tokenizer", "guider", "scheduler", "vae", "transformer", "video_processor"] + ) + + def test_forward(self): + repo_id = "krea/krea-realtime-video" + blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) + pipe = WanModularPipeline(blocks, repo_id) + pipe.load_components( + trust_remote_code=True, + device_map="cuda", + torch_dtype={"default": torch.bfloat16, "vae": torch.float16}, + ) + + num_frames_per_block = 2 + num_blocks = 2 + + state = PipelineState() + state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len)) + + prompt = ["a cat sitting on a boat"] + + for block in pipe.transformer.blocks: + block.self_attn.fuse_projections() + + for block_idx in range(num_blocks): + state = pipe( + state, + prompt=prompt, + num_inference_steps=2, + num_blocks=num_blocks, + num_frames_per_block=num_frames_per_block, + block_idx=block_idx, + generator=torch.manual_seed(42), + ) + current_frames = np.array(state.values["videos"][0]) + current_frames_flat = current_frames.flatten() + actual_slices = np.concatenate([current_frames_flat[:4], current_frames_flat[-4:]]).tolist() + + if block_idx == 0: + assert current_frames.shape == (5, 480, 832, 3) + expected_slices = np.array([211, 229, 238, 208, 195, 180, 188, 193]) + else: + assert current_frames.shape == (8, 480, 832, 3) + expected_slices = np.array([179, 203, 214, 176, 194, 181, 187, 191]) + + assert np.allclose(actual_slices, expected_slices) From ddb5ba734df21ee4af6838333b5f0460f5ba055b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 18:27:31 +0530 Subject: [PATCH 07/11] up --- .../test_modular_pipelines_custom_blocks.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index ef4fe66c6f89..66d243421cca 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -108,15 +108,16 @@ def test_custom_block_loads_from_hub(self): @slow @nightly @require_torch -class TestModularCustomBlocksIntegration: - def test_krea_realtime_video_loading(self): - repo_id = "krea/krea-realtime-video" - blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) +class TestKreaCustomBlocksIntegration: + repo_id = "krea/krea-realtime-video" + + def test_loading(self): + blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True) block_names = sorted(blocks.sub_blocks) assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"]) - pipe = WanModularPipeline(blocks, repo_id) + pipe = WanModularPipeline(blocks, self.repo_id) pipe.load_components( trust_remote_code=True, device_map="cuda", @@ -128,9 +129,8 @@ def test_krea_realtime_video_loading(self): ) def test_forward(self): - repo_id = "krea/krea-realtime-video" - blocks = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) - pipe = WanModularPipeline(blocks, repo_id) + blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True) + pipe = WanModularPipeline(blocks, self.repo_id) pipe.load_components( trust_remote_code=True, device_map="cuda", From b5f13d9b59747b3494b4314ba43ae1ae9d8820ec Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 18:28:06 +0530 Subject: [PATCH 08/11] up --- tests/modular_pipelines/test_modular_pipelines_custom_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 66d243421cca..04d4ec5772f9 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -111,7 +111,7 @@ def test_custom_block_loads_from_hub(self): class TestKreaCustomBlocksIntegration: repo_id = "krea/krea-realtime-video" - def test_loading(self): + def test_loading_from_hub(self): blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True) block_names = sorted(blocks.sub_blocks) From 9f113f8138e617fd33a99e33adf3887eebc7c12a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 29 Oct 2025 21:25:21 +0530 Subject: [PATCH 09/11] up --- .github/workflows/pr_modular_tests.yml | 2 +- tests/conftest.py | 2 ++ tests/testing_utils.py | 5 +++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml index 8b0580b81f86..afaca1b90d3b 100644 --- a/.github/workflows/pr_modular_tests.yml +++ b/.github/workflows/pr_modular_tests.yml @@ -106,7 +106,7 @@ jobs: - name: Run fast PyTorch Pipeline CPU tests run: | pytest -n 8 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "not Flax and not Onnx" \ + -s -v \ --make-reports=tests_torch_cpu_modular_pipelines \ tests/modular_pipelines diff --git a/tests/conftest.py b/tests/conftest.py index fd76d1c84ee7..9558c23d3062 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,8 @@ def pytest_configure(config): config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources") + config.addinivalue_line("markers", "slow: mark test as slow") + config.addinivalue_line("markers", "nightly: mark test as nightly") def pytest_addoption(parser): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 951ba4128033..8f5df4887d88 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -24,6 +24,7 @@ import numpy as np import PIL.Image import PIL.ImageOps +import pytest import requests from numpy.linalg import norm from packaging import version @@ -267,7 +268,7 @@ def slow(test_case): Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. """ - return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(pytest.mark.slow(test_case)) def nightly(test_case): @@ -277,7 +278,7 @@ def nightly(test_case): Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. """ - return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) + return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(pytest.mark.nightly(test_case)) def is_torch_compile(test_case): From b8809f76d51c890507887b59b81723c4b90e1a2a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 31 Oct 2025 15:52:19 +0530 Subject: [PATCH 10/11] up --- src/diffusers/utils/dynamic_modules_utils.py | 3 +- .../test_modular_pipelines_custom_blocks.py | 100 ++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 627b1e0604dc..1a496e89dcaf 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -298,9 +298,10 @@ def get_cached_module_file( """ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. pretrained_model_name_or_path = str(pretrained_model_name_or_path) + print(f"{pretrained_model_name_or_path=}") module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) - + print(f"{module_file_or_url=}") if os.path.isfile(module_file_or_url): resolved_module_file = module_file_or_url submodule = "local" diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 04d4ec5772f9..9c5fd5be326d 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -1,3 +1,20 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile from collections import deque from typing import List @@ -57,6 +74,58 @@ def __call__(self, components, state: PipelineState) -> PipelineState: return components, state +CODE_STR = """ +from diffusers.modular_pipelines import ( + ComponentSpec, + InputParam, + ModularPipelineBlocks, + OutputParam, + PipelineState, + WanModularPipeline, +) +from typing import List + +class DummyCustomBlockSimple(ModularPipelineBlocks): + def __init__(self, use_dummy_model_component=False): + self.use_dummy_model_component = use_dummy_model_component + super().__init__() + + @property + def expected_components(self): + if self.use_dummy_model_component: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + else: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "output_prompt", + type_hint=str, + description="Modified prompt", + ) + ] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + old_prompt = block_state.prompt + block_state.output_prompt = "Modular diffusers + " + old_prompt + self.set_block_state(state, block_state) + + return components, state +""" + + class TestModularCustomBlocks: def _test_block_properties(self, block): assert not block.expected_components @@ -84,6 +153,37 @@ def test_custom_block_output(self): output_prompt = output.values["output_prompt"] assert output_prompt.startswith("Modular diffusers + ") + def test_custom_block_saving_loading(self): + custom_block = DummyCustomBlockSimple() + + with tempfile.TemporaryDirectory() as tmpdir: + custom_block.save_pretrained(tmpdir) + assert any("modular_config.json" in k for k in os.listdir(tmpdir)) + + with open(os.path.join(tmpdir, "modular_config.json"), "r") as f: + config = json.load(f) + auto_map = config["auto_map"] + assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"} + + # For now, the Python script that implements the custom block has to be manually pushed to the Hub. + # This is why, we have to separately save the Python script here. + code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py") + with open(code_path, "w") as f: + f.write(CODE_STR) + + loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True) + + pipe = loaded_custom_block.init_pipeline() + prompt = "Diffusers is nice" + output = pipe(prompt=prompt) + + actual_inputs = [inp.name for inp in loaded_custom_block.inputs] + actual_intermediate_outputs = [out.name for out in loaded_custom_block.intermediate_outputs] + assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs) + + output_prompt = output.values["output_prompt"] + assert output_prompt.startswith("Modular diffusers + ") + def test_custom_block_supported_components(self): custom_block = DummyCustomBlockSimple(use_dummy_model_component=True) pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe") From c0ce538afce9b48609c83f6a30fb0ac5295a387c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 3 Nov 2025 08:31:06 +0530 Subject: [PATCH 11/11] Apply suggestions from code review --- src/diffusers/utils/dynamic_modules_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index ce7e117927f2..178657f00ae9 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -299,10 +299,8 @@ def get_cached_module_file( """ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. pretrained_model_name_or_path = str(pretrained_model_name_or_path) - print(f"{pretrained_model_name_or_path=}") module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) - print(f"{module_file_or_url=}") if os.path.isfile(module_file_or_url): resolved_module_file = module_file_or_url submodule = "local"