110 changes: 76 additions & 34 deletions tests/modular_pipelines/flux/test_modular_pipeline_flux.py
@@ -15,7 +15,6 @@
 
 import random
 import tempfile
-import unittest
 
 import numpy as np
 import PIL
@@ -34,21 +33,16 @@
 from ..test_modular_pipelines_common import ModularPipelineTesterMixin
 
 
-class FluxModularTests:
+class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
     pipeline_blocks_class = FluxAutoBlocks
     repo = "hf-internal-testing/tiny-flux-modular"
 
-    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
-        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
-        pipeline.load_components(torch_dtype=torch_dtype)
-        return pipeline
+    params = frozenset(["prompt", "height", "width", "guidance_scale"])
+    batch_params = frozenset(["prompt"])
 
-    def get_dummy_inputs(self, device, seed=0):
-        if str(device).startswith("mps"):
-            generator = torch.manual_seed(seed)
-        else:
-            generator = torch.Generator(device=device).manual_seed(seed)
+    def get_dummy_inputs(self, seed=0):
+        generator = self.get_generator(seed)
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "generator": generator,
@@ -57,36 +51,47 @@ def get_dummy_inputs(self, device, seed=0):
             "height": 8,
             "width": 8,
             "max_sequence_length": 48,
-            "output_type": "np",
+            "output_type": "pt",
         }
         return inputs
 
 
-class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
-    params = frozenset(["prompt", "height", "width", "guidance_scale"])
-    batch_params = frozenset(["prompt"])
-
+class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
+    pipeline_class = FluxModularPipeline
+    pipeline_blocks_class = FluxAutoBlocks
+    repo = "hf-internal-testing/tiny-flux-modular"
+
-class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
     params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
     batch_params = frozenset(["prompt", "image"])
 
     def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
         pipeline = super().get_pipeline(components_manager, torch_dtype)
 
         # Override `vae_scale_factor` here as currently, `image_processor` is initialized with
         # fixed constants instead of
         # https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
         pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
         return pipeline
 
-    def get_dummy_inputs(self, device, seed=0):
-        inputs = super().get_dummy_inputs(device, seed)
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
-        image = image / 2 + 0.5
-        inputs["image"] = image
-        inputs["strength"] = 0.8
-        inputs["height"] = 8
-        inputs["width"] = 8
+    def get_dummy_inputs(self, seed=0):
+        generator = self.get_generator(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 4,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "max_sequence_length": 48,
+            "output_type": "pt",
+        }
+        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+        image = image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
+
+        inputs["image"] = init_image
+        inputs["strength"] = 0.5
+
         return inputs
 
     def test_save_from_pretrained(self):
@@ -96,6 +101,7 @@ def test_save_from_pretrained(self):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             base_pipe.save_pretrained(tmpdirname)
+
             pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
         pipe.load_components(torch_dtype=torch.float32)
         pipe.to(torch_device)
@@ -105,26 +111,62 @@ def test_save_from_pretrained(self):
 
         image_slices = []
         for pipe in pipes:
-            inputs = self.get_dummy_inputs(torch_device)
+            inputs = self.get_dummy_inputs()
             image = pipe(**inputs, output="images")
+
             image_slices.append(image[0, -3:, -3:, -1].flatten())
+
-        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
 
 
-class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests):
+class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxKontextModularPipeline
     pipeline_blocks_class = FluxKontextAutoBlocks
     repo = "hf-internal-testing/tiny-flux-kontext-pipe"
 
-    def get_dummy_inputs(self, device, seed=0):
-        inputs = super().get_dummy_inputs(device, seed)
+    params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
+    batch_params = frozenset(["prompt", "image"])
+
+    def get_dummy_inputs(self, seed=0):
+        generator = self.get_generator(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "max_sequence_length": 48,
+            "output_type": "pt",
+        }
         image = PIL.Image.new("RGB", (32, 32), 0)
-        _ = inputs.pop("strength")
 
         inputs["image"] = image
-        inputs["height"] = 8
-        inputs["width"] = 8
-        inputs["max_area"] = 8 * 8
+        inputs["max_area"] = inputs["height"] * inputs["width"]
         inputs["_auto_resize"] = False
 
         return inputs
+
+    def test_save_from_pretrained(self):
+        pipes = []
+        base_pipe = self.get_pipeline().to(torch_device)
+        pipes.append(base_pipe)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            base_pipe.save_pretrained(tmpdirname)
+
+            pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+        pipe.load_components(torch_dtype=torch.float32)
+        pipe.to(torch_device)
+        pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
+
+        pipes.append(pipe)
+
+        image_slices = []
+        for pipe in pipes:
+            inputs = self.get_dummy_inputs()
+            image = pipe(**inputs, output="images")
+
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
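
For context, a minimal sketch of the save/load round-trip that test_save_from_pretrained exercises, outside the test harness. The `from diffusers import ...` path and the CPU device are assumptions for illustration (the file's real imports and the mixin's `get_pipeline`/`get_generator` helpers are not shown in the hunks above); every call used here appears in the diff itself.

import tempfile

import torch
from diffusers import FluxAutoBlocks, ModularPipeline  # import path assumed; see the file's own imports

repo = "hf-internal-testing/tiny-flux-modular"

# Build a modular pipeline from its block definition, then materialize the components.
pipe = FluxAutoBlocks().init_pipeline(repo)
pipe.load_components(torch_dtype=torch.float32)

with tempfile.TemporaryDirectory() as tmpdirname:
    # Save the modular config, then restore it as a generic ModularPipeline.
    pipe.save_pretrained(tmpdirname)
    restored = ModularPipeline.from_pretrained(tmpdirname)
restored.load_components(torch_dtype=torch.float32)

# Outputs are requested by name, mirroring pipe(**inputs, output="images") in the tests.
image = restored(
    prompt="A painting of a squirrel eating a burger",
    num_inference_steps=2,
    guidance_scale=5.0,
    height=8,
    width=8,
    max_sequence_length=48,
    output_type="pt",
    generator=torch.Generator("cpu").manual_seed(0),
    output="images",
)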