diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 94dad286e4a3..a9fe9b142451 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -628,6 +628,8 @@
- sections:
- local: api/pipelines/allegro
title: Allegro
+ - local: api/pipelines/chronoedit
+ title: ChronoEdit
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/consisid
diff --git a/docs/source/en/api/models/chronoedit_transformer_3d.md b/docs/source/en/api/models/chronoedit_transformer_3d.md
new file mode 100644
index 000000000000..94982821795d
--- /dev/null
+++ b/docs/source/en/api/models/chronoedit_transformer_3d.md
@@ -0,0 +1,32 @@
+
+
+# ChronoEditTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
+
+> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import ChronoEditTransformer3DModel
+
+transformer = ChronoEditTransformer3DModel.from_pretrained("nvidia/ChronoEdit-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## ChronoEditTransformer3DModel
+
+[[autodoc]] ChronoEditTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/chronoedit.md b/docs/source/en/api/pipelines/chronoedit.md
new file mode 100644
index 000000000000..48e70ab9e55e
--- /dev/null
+++ b/docs/source/en/api/pipelines/chronoedit.md
@@ -0,0 +1,156 @@
+
+
+
+
+# ChronoEdit
+
+[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
+
+> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
+
+*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*
+
+The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
+
+
+### Image Editing
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+from PIL import Image
+
+model_id = "nvidia/ChronoEdit-14B-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+image = load_image(
+ "https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
+)
+max_area = 720 * 1280
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+print("width", width, "height", height)
+image = image.resize((width, height))
+prompt = (
+ "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
+ "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
+)
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=5,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ enable_temporal_reasoning=False,
+ num_temporal_reasoning_steps=0,
+).frames[0]
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
+```
+
+Optionally, enable **temporal reasoning** for improved physical consistency:
+```py
+output = pipe(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=29,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ enable_temporal_reasoning=True,
+ num_temporal_reasoning_steps=50,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
+```
+
+### Inference with 8-Step Distillation Lora
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+from PIL import Image
+
+model_id = "nvidia/ChronoEdit-14B-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
+lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
+pipe.load_lora_weights(lora_path)
+pipe.fuse_lora(lora_scale=1.0)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
+pipe.to("cuda")
+
+image = load_image(
+ "https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
+)
+max_area = 720 * 1280
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+print("width", width, "height", height)
+image = image.resize((width, height))
+prompt = (
+ "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
+ "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
+)
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=5,
+ num_inference_steps=8,
+ guidance_scale=1.0,
+ enable_temporal_reasoning=False,
+ num_temporal_reasoning_steps=0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
+```
+
+## ChronoEditPipeline
+
+[[autodoc]] ChronoEditPipeline
+ - all
+ - __call__
+
+## ChronoEditPipelineOutput
+
+[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput
\ No newline at end of file
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 572aad4bd3f1..a6b90831f743 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -202,6 +202,7 @@
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
+ "ChronoEditTransformer3DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
@@ -436,6 +437,7 @@
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
+ "ChronoEditPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
@@ -909,6 +911,7 @@
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
+ ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -1113,6 +1116,7 @@
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
+ ChronoEditPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 202e77fd197d..e97ab8bd1d2a 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -86,6 +86,7 @@
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
+ _import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -179,6 +180,7 @@
BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
+ ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 15408a4b15cc..66daf56e23b2 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -20,6 +20,7 @@
from .transformer_bria import BriaTransformer2DModel
from .transformer_bria_fibo import BriaFiboTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
+ from .transformer_chronoedit import ChronoEditTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_cosmos import CosmosTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_chronoedit.py b/src/diffusers/models/transformers/transformer_chronoedit.py
new file mode 100644
index 000000000000..9c0b883d61e0
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_chronoedit.py
@@ -0,0 +1,734 @@
+# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+class WanAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "WanAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor2_0
+class WanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+ "Please use WanAttnProcessor instead. "
+ )
+ deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+ return WanAttnProcessor(*args, **kwargs)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttention
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = WanAttnProcessor
+ _available_processors = [WanAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding
+class WanImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
+class WanTimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ timestep_seq_len: Optional[int] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+ if timestep_seq_len is not None:
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class ChronoEditRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ temporal_skip_len: int = 8,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+ self.temporal_skip_len = temporal_skip_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ if num_frames == 2:
+ freqs_cos_f = freqs_cos[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ else:
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ if num_frames == 2:
+ freqs_sin_f = freqs_sin[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ else:
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock
+@maybe_allow_in_graph
+class WanTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=WanAttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ if temb.ndim == 4:
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class ChronoEditTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the ChronoEdit model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `40`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `512`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`bool`, defaults to `True`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ add_img_emb (`bool`, defaults to `False`):
+ Whether to use img_emb.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanTransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["WanTransformerBlock"]
+ _cp_plan = {
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "blocks.*": {
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ rope_temporal_skip_len: int = 8,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = ChronoEditRotaryPosEmbed(
+ attention_head_dim, patch_size, rope_max_seq_len, temporal_skip_len=rope_temporal_skip_len
+ )
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ # image_embedding_dim=1280 for I2V model
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanTransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+ if timestep.ndim == 2:
+ ts_seq_len = timestep.shape[1]
+ timestep = timestep.flatten() # batch_size * seq_len
+ else:
+ ts_seq_len = None
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
+ )
+ if ts_seq_len is not None:
+ # batch_size, seq_len, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+ else:
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ # 5. Output norm, projection & unpatchify
+ if temb.ndim == 3:
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+ shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 87d953845e21..495753041f10 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -404,6 +404,7 @@
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
]
+ _import_structure["chronoedit"] = ["ChronoEditPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -566,6 +567,7 @@
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
+ from .chronoedit import ChronoEditPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
diff --git a/src/diffusers/pipelines/chronoedit/__init__.py b/src/diffusers/pipelines/chronoedit/__init__.py
new file mode 100644
index 000000000000..cffe4660977f
--- /dev/null
+++ b/src/diffusers/pipelines/chronoedit/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_chronoedit"] = ["ChronoEditPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_chronoedit import ChronoEditPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py b/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py
new file mode 100644
index 000000000000..79f6580fbed6
--- /dev/null
+++ b/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py
@@ -0,0 +1,752 @@
+# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, ChronoEditTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import ChronoEditPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> import numpy as np
+ >>> from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+ >>> from transformers import CLIPVisionModel
+
+ >>> # Available models: nvidia/ChronoEdit-14B-Diffusers
+ >>> model_id = "nvidia/ChronoEdit-14B-Diffusers"
+ >>> image_encoder = CLIPVisionModel.from_pretrained(
+ ... model_id, subfolder="image_encoder", torch_dtype=torch.float32
+ ... )
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> transformer = ChronoEditTransformer3DModel.from_pretrained(
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = ChronoEditPipeline.from_pretrained(
+ ... model_id, vae=vae, image_encoder=image_encoder, transformer=transformer, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image("https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png")
+ >>> max_area = 720 * 1280
+ >>> aspect_ratio = image.height / image.width
+ >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ >>> image = image.resize((width, height))
+ >>> prompt = (
+ ... "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
+ ... "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
+ ... )
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... height=height,
+ ... width=width,
+ ... num_frames=5,
+ ... guidance_scale=5.0,
+ ... enable_temporal_reasoning=False,
+ ... num_temporal_reasoning_steps=0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class ChronoEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
+ the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ image_encoder: CLIPVisionModel,
+ image_processor: CLIPImageProcessor,
+ transformer: ChronoEditTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+ latent_condition = latent_condition.to(dtype)
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
+
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ enable_temporal_reasoning: bool = False,
+ num_temporal_reasoning_steps: int = 0,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `480`):
+ The height of the generated video.
+ width (`int`, defaults to `832`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`ChronoEditPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+ enable_temporal_reasoning (`bool`, *optional*, defaults to `False`):
+ Whether to enable temporal reasoning.
+ num_temporal_reasoning_steps (`int`, *optional*, defaults to `0`):
+ The number of steps to enable temporal reasoning.
+
+ Examples:
+
+ Returns:
+ [`~ChronoEditPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`ChronoEditPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ num_frames = 5 if not enable_temporal_reasoning else num_frames
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ if image_embeds is None:
+ image_embeds = self.encode_image(image, device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+ latents, condition = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ if enable_temporal_reasoning and i == num_temporal_reasoning_steps:
+ latents = latents[:, :, [0, -1]]
+ condition = condition[:, :, [0, -1]]
+
+ for j in range(len(self.scheduler.model_outputs)):
+ if self.scheduler.model_outputs[j] is not None:
+ if latents.shape[-3] != self.scheduler.model_outputs[j].shape[-3]:
+ self.scheduler.model_outputs[j] = self.scheduler.model_outputs[j][:, :, [0, -1]]
+ if self.scheduler.last_sample is not None:
+ self.scheduler.last_sample = self.scheduler.last_sample[:, :, [0, -1]]
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ if enable_temporal_reasoning and latents.shape[2] > 2:
+ video_edit = self.vae.decode(latents[:, :, [0, -1]], return_dict=False)[0]
+ video_reason = self.vae.decode(latents[:, :, :-1], return_dict=False)[0]
+ video = torch.cat([video_reason, video_edit[:, :, 1:]], dim=2)
+ else:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return ChronoEditPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/chronoedit/pipeline_output.py b/src/diffusers/pipelines/chronoedit/pipeline_output.py
new file mode 100644
index 000000000000..b1df5b9de35d
--- /dev/null
+++ b/src/diffusers/pipelines/chronoedit/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class ChronoEditPipelineOutput(BaseOutput):
+ r"""
+ Output class for ChronoEdit pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/tests/pipelines/chronoedit/__init__.py b/tests/pipelines/chronoedit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/chronoedit/test_chronoedit.py b/tests/pipelines/chronoedit/test_chronoedit.py
new file mode 100644
index 000000000000..a88c2e73e9b0
--- /dev/null
+++ b/tests/pipelines/chronoedit/test_chronoedit.py
@@ -0,0 +1,172 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ ChronoEditPipeline,
+ ChronoEditTransformer3DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+
+from ...testing_utils import enable_full_determinism
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class ChronoEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = ChronoEditPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = ChronoEditTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=32,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=32, size=32)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "enable_temporal_reasoning": True,
+ "num_temporal_reasoning_steps": 2,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass