From 13e516ce50678d58d941c37ed15061ae87ffeb3c Mon Sep 17 00:00:00 2001 From: junsong Date: Tue, 9 Sep 2025 16:41:53 +0000 Subject: [PATCH 01/36] 1. add `SanaVideoTransformer3DModel` in transformer_sana_video.py 2. add `SanaVideoPipeline` in pipeline_sana_video.py 3. add all code we need for import `SanaVideoPipeline` --- scripts/convert_sana_video_to_diffusers.py | 347 ++++++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_sana_video.py | 848 ++++++++++++++ src/diffusers/pipelines/__init__.py | 3 +- src/diffusers/pipelines/sana/__init__.py | 2 + .../pipelines/sana/pipeline_output.py | 16 + .../pipelines/sana/pipeline_sana_video.py | 1009 +++++++++++++++++ 9 files changed, 2231 insertions(+), 1 deletion(-) create mode 100644 scripts/convert_sana_video_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_sana_video.py create mode 100644 src/diffusers/pipelines/sana/pipeline_sana_video.py diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py new file mode 100644 index 000000000000..d6d349c96c5c --- /dev/null +++ b/scripts/convert_sana_video_to_diffusers.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python +from __future__ import annotations + +import argparse +import os +from contextlib import nullcontext + +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download, snapshot_download +from termcolor import colored +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import ( + AutoencoderDC, + AutoencoderKLWan, + DPMSolverMultistepScheduler, + FlowMatchEulerDiscreteScheduler, + UniPCMultistepScheduler, + SanaVideoPipeline, + SanaVideoTransformer3DModel, +) +from diffusers.models.model_loading_utils import load_model_dict_into_meta +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +ckpt_ids = [ + "Efficient-Large-Model/SanaVideo_willquant/checkpoints/model.pth" +] +# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py + + +def main(args): + cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub") + + if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids: + ckpt_id = args.orig_ckpt_path or ckpt_ids[0] + snapshot_download( + repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}", + cache_dir=cache_dir_path, + repo_type="model", + ) + file_path = hf_hub_download( + repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}", + filename=f"{'/'.join(ckpt_id.split('/')[2:])}", + cache_dir=cache_dir_path, + repo_type="model", + ) + else: + file_path = args.orig_ckpt_path + + print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"])) + all_state_dict = torch.load(file_path, weights_only=True) + state_dict = all_state_dict.pop("state_dict") + converted_state_dict = {} + + # Patch embeddings. + converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight") + converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias") + + # Caption projection. + converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") + + # Handle different time embedding structure based on model type + + if args.model_type in ["SanaVideo"]: + # For Sana Sprint, the time embedding structure is different + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + + # Guidance embedder for Sana Sprint + converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop( + "cfg_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias") + converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop( + "cfg_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias") + else: + # Original Sana time embedding structure + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop( + "t_embedder.mlp.0.bias" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop( + "t_embedder.mlp.2.bias" + ) + + # Shared norm. + converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight") + converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias") + + # y norm + converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") + + # scheduler + flow_shift = 6.0 + + # model config + layer_num = 20 + # Positional embedding interpolation scale. + qk_norm = True + + for depth in range(layer_num): + # Transformer blocks. + converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( + f"blocks.{depth}.scale_shift_table" + ) + + # Linear Attention is all you need 🤘 + # Self attention. + q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + if qk_norm is not None: + # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5 + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.weight" + ) + # Projection. + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.attn.proj.bias" + ) + + # Feed-forward. + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.inverted_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.inverted_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.depth_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop( + f"blocks.{depth}.mlp.depth_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.point_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop( + f"blocks.{depth}.mlp.t_conv.weight" + ) + + # Cross-attention. + q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight") + q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias") + k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0) + k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0) + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias + if qk_norm is not None: + # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5 + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.k_norm.weight" + ) + + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop( + f"blocks.{depth}.cross_attn.proj.bias" + ) + + # Final block. + converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table") + + # Transformer + with CTX(): + transformer_kwargs = { + "in_channels": 16, + "out_channels": 16, + "num_attention_heads": 20, + "attention_head_dim": 112, + "num_layers": 20, + "num_cross_attention_heads": 20, + "cross_attention_head_dim": 112, + "cross_attention_dim": 2240, + "caption_channels": 2304, + "mlp_ratio": 3.0, + "attention_bias": False, + "sample_size": args.image_size // 16, + "patch_size": (1, 2, 2), + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 1024, + } + + transformer = SanaVideoTransformer3DModel(**transformer_kwargs) + + transformer.load_state_dict(converted_state_dict, strict=True, assign=True) + + try: + state_dict.pop("y_embedder.y_embedding") + state_dict.pop("pos_embed") + state_dict.pop("logvar_linear.weight") + state_dict.pop("logvar_linear.bias") + except KeyError: + print("y_embedder.y_embedding or pos_embed not found in the state_dict") + + assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}" + + num_model_params = sum(p.numel() for p in transformer.parameters()) + print(f"Total number of transformer parameters: {num_model_params}") + + transformer = transformer.to(weight_dtype) + + if not args.save_full_pipeline: + print( + colored( + f"Only saving transformer model of {args.model_type}. " + f"Set --save_full_pipeline to save the whole Pipeline", + "green", + attrs=["bold"], + ) + ) + transformer.save_pretrained( + os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB" + ) + else: + print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) + # VAE + # vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32) + vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) + + # Text Encoder + text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" + tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path) + tokenizer.padding_side = "right" + text_encoder = AutoModelForCausalLM.from_pretrained( + text_encoder_model_path, torch_dtype=torch.bfloat16 + ).get_decoder() + + # Choose the appropriate pipeline and scheduler based on model type + # Original Sana scheduler + if args.scheduler_type == "flow-dpm_solver": + scheduler = DPMSolverMultistepScheduler( + flow_shift=flow_shift, + use_flow_sigmas=True, + prediction_type="flow_prediction", + ) + elif args.scheduler_type == "flow-euler": + scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) + elif args.scheduler_type == "uni-pc": + scheduler = UniPCMultistepScheduler( + prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift + ) + else: + raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") + + pipe = SanaVideoPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=vae, + scheduler=scheduler, + ) + + pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB") + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--image_size", + default=1024, + type=int, + choices=[512, 1024, 2048, 4096], + required=False, + help="Image size of pretrained model, 512, 1024, 2048 or 4096.", + ) + parser.add_argument( + "--model_type", + default="SanaMS_1600M_P1_D20", + type=str, + choices=[ + "SanaMS_1600M_P1_D20", + "SanaMS_600M_P1_D28", + "SanaMS1.5_1600M_P1_D20", + "SanaMS1.5_4800M_P1_D60", + "SanaSprint_1600M_P1_D20", + "SanaSprint_600M_P1_D28", + ], + ) + parser.add_argument( + "--scheduler_type", + default="flow-dpm_solver", + type=str, + choices=["flow-dpm_solver", "flow-euler", "uni-pc"], + help="Scheduler type to use.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") + parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.") + parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.") + + args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + weight_dtype = DTYPE_MAPPING[args.dtype] + + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 94104667b541..09860f771ba7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -246,6 +246,7 @@ "QwenImageTransformer2DModel", "SanaControlNetModel", "SanaTransformer2DModel", + "SanaVideoTransformer3DModel", "SD3ControlNetModel", "SD3MultiControlNetModel", "SD3Transformer2DModel", @@ -542,6 +543,7 @@ "SanaControlNetPipeline", "SanaPAGPipeline", "SanaPipeline", + "SanaVideoPipeline", "SanaSprintImg2ImgPipeline", "SanaSprintPipeline", "SemanticStableDiffusionPipeline", @@ -951,6 +953,7 @@ QwenImageTransformer2DModel, SanaControlNetModel, SanaTransformer2DModel, + SanaVideoTransformer3DModel, SD3ControlNetModel, SD3MultiControlNetModel, SD3Transformer2DModel, @@ -1217,6 +1220,7 @@ SanaControlNetPipeline, SanaPAGPipeline, SanaPipeline, + SanaVideoPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline, SemanticStableDiffusionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e3b297464143..808e6f253003 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -79,6 +79,7 @@ _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"] + _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] @@ -204,6 +205,7 @@ PRXTransformer2DModel, QwenImageTransformer2DModel, SanaTransformer2DModel, + SanaVideoTransformer3DModel, SD3Transformer2DModel, SkyReelsV2Transformer3DModel, StableAudioDiTModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 2fe1159eec4c..0c9809bc3a62 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -13,6 +13,7 @@ from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer from .sana_transformer import SanaTransformer2DModel + from .transformer_sana_video import SanaVideoTransformer3DModel from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py new file mode 100644 index 000000000000..556af8b1e1ca --- /dev/null +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -0,0 +1,848 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention_processor import ( + Attention, + AttentionProcessor, +) +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class GLUMBTempConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + expand_ratio: float = 4, + norm_type: Optional[str] = None, + residual_connection: bool = True, + ) -> None: + super().__init__() + + hidden_channels = int(expand_ratio * in_channels) + self.norm_type = norm_type + self.residual_connection = residual_connection + + self.nonlinearity = nn.SiLU() + self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0) + self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2) + self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False) + + self.norm = None + if norm_type == "rms_norm": + self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True) + + self.conv_temp = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.residual_connection: + residual = hidden_states + batch_size, num_frames, height, width, num_channels = hidden_states.shape + hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2) + + hidden_states = self.conv_inverted(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv_depth(hidden_states) + hidden_states, gate = torch.chunk(hidden_states, 2, dim=1) + hidden_states = hidden_states * self.nonlinearity(gate) + + hidden_states = self.conv_point(hidden_states) + + # Temporal aggregation + hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(0, 2, 1, 3) + hidden_states = hidden_states_temporal + self.conv_temp(hidden_states_temporal) + hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels) + + if self.norm_type == "rms_norm": + # move channel to the last dimension so we apply RMSnorm across channel dimension + hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.residual_connection: + hidden_states = hidden_states + residual + + return hidden_states + + +class SanaLinearAttnProcessor3_0: + r""" + Processor for implementing scaled dot-product linear attention. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = hidden_states.dtype + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # B,N,C + # B,H,C,N + # query = query.transpose(1, 2).unflatten(1, (attn.heads, -1)) + # B,H,N,C + # key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3) + # B,N,H,C + # value = value.transpose(1, 2).unflatten(1, (attn.heads, -1)) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + # B,N,H,C + + query = F.relu(query) + key = F.relu(key) + + if rotary_emb is not None: + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + query_rotate = apply_rotary_emb(query, *rotary_emb) + key_rotate = apply_rotary_emb(key, *rotary_emb) + + # B,H,C,N + query = query.permute(0, 2, 3, 1) + key = key.permute(0, 2, 3, 1) + query_rotate = query_rotate.permute(0, 2, 3, 1) + key_rotate = key_rotate.permute(0, 2, 3, 1) + value = value.permute(0, 2, 3, 1) + + query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float() + + z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15) + + scores = torch.matmul(value, key_rotate.transpose(-1, -2)) + hidden_states = torch.matmul(scores, query_rotate) + + hidden_states = hidden_states * z + # B,H,C,N + hidden_states = hidden_states.flatten(1, 2).transpose(1, 2) + hidden_states = hidden_states.to(original_dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class SanaLinearAttnProcessor3_1: + r""" + Processor for implementing scaled dot-product linear attention. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = hidden_states.dtype + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # B,N,C + # B,H,C,N + # query = query.transpose(1, 2).unflatten(1, (attn.heads, -1)) + # B,H,N,C + # key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3) + # B,N,H,C + # value = value.transpose(1, 2).unflatten(1, (attn.heads, -1)) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + # B,N,H,C + + query = F.relu(query) + key = F.relu(key) + + # if rotary_emb is not None: + + # def apply_rotary_emb( + # hidden_states: torch.Tensor, + # freqs_cos: torch.Tensor, + # freqs_sin: torch.Tensor, + # ): + # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + # cos = freqs_cos[..., 0::2] + # sin = freqs_sin[..., 1::2] + # out = torch.empty_like(hidden_states) + # out[..., 0::2] = x1 * cos - x2 * sin + # out[..., 1::2] = x1 * sin + x2 * cos + # return out.type_as(hidden_states) + + # query_rotate = apply_rotary_emb(query, *rotary_emb) + # key_rotate = apply_rotary_emb(key, *rotary_emb) + + # B,H,C,N + # query_rotate = query_rotate.permute(0, 2, 3, 1) + # key_rotate = key_rotate.permute(0, 2, 3, 1) + # value = value.permute(0, 2, 3, 1) + + query = query.permute(0, 2, 3, 1) + key = key.permute(0, 2, 3, 1) + query_rotate = query + key_rotate = key + value = value.permute(0, 2, 3, 1) + + query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float() + + z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15) + + scores = torch.matmul(value, key_rotate.transpose(-1, -2)) + hidden_states = torch.matmul(scores, query_rotate) + + hidden_states = hidden_states * z + # B,H,C,N + hidden_states = hidden_states.flatten(1, 2).transpose(1, 2) + hidden_states = hidden_states.to(original_dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +# copy from https://github.com/huggingface/diffusers/blob/11d22e0e809d1219a067ded8a18f7b0129fc58c7/src/diffusers/models/transformers/transformer_wan.py#L410 +class WanRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, + ) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + split_sizes = [ + self.attention_head_dim - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 3, + self.attention_head_dim // 3, + ] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + + return freqs_cos, freqs_sin + + +class SanaModulatedNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6): + super().__init__() + self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states + + +class SanaCombinedTimestepGuidanceEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + guidance_proj = self.guidance_condition_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype)) + conditioning = timesteps_emb + guidance_emb + + return self.linear(self.silu(conditioning)), conditioning + + +class SanaAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class SanaVideoTransformerBlock(nn.Module): + r""" + Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629). + """ + + def __init__( + self, + dim: int = 2240, + num_attention_heads: int = 20, + attention_head_dim: int = 112, + dropout: float = 0.0, + num_cross_attention_heads: Optional[int] = 20, + cross_attention_head_dim: Optional[int] = 112, + cross_attention_dim: Optional[int] = 2240, + attention_bias: bool = True, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + attention_out_bias: bool = True, + mlp_ratio: float = 3.0, + qk_norm: Optional[str] = "rms_norm_across_heads", + rope_max_seq_len: int = 1024, + ) -> None: + super().__init__() + + # 1. Self Attention + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_attention_heads if qk_norm is not None else None, + qk_norm=qk_norm, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + processor=SanaLinearAttnProcessor3_0(), + ) + + # 2. Cross Attention + if cross_attention_dim is not None: + self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.attn2 = Attention( + query_dim=dim, + qk_norm=qk_norm, + kv_heads=num_cross_attention_heads if qk_norm is not None else None, + cross_attention_dim=cross_attention_dim, + heads=num_cross_attention_heads, + dim_head=cross_attention_head_dim, + dropout=dropout, + bias=True, + out_bias=attention_out_bias, + processor=SanaAttnProcessor2_0(), + ) + + # 3. Feed-forward + self.ff = GLUMBTempConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False) + + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + frames: int = None, + height: int = None, + width: int = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # 1. Modulation + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + # 2. Self Attention + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.to(hidden_states.dtype) + + attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb) + hidden_states = hidden_states + gate_msa * attn_output + + # 3. Cross Attention + if self.attn2 is not None: + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width)) + ff_output = self.ff(norm_hidden_states) + ff_output = ff_output.flatten(1, 3) + hidden_states = hidden_states + gate_mlp * ff_output + + return hidden_states + + +class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + r""" + A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `20`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `112`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of Transformer blocks to use. + num_cross_attention_heads (`int`, *optional*, defaults to `20`): + The number of heads to use for cross-attention. + cross_attention_head_dim (`int`, *optional*, defaults to `112`): + The number of channels in each head for cross-attention. + cross_attention_dim (`int`, *optional*, defaults to `2240`): + The number of channels in the cross-attention output. + caption_channels (`int`, defaults to `2304`): + The number of channels in the caption embeddings. + mlp_ratio (`float`, defaults to `2.5`): + The expansion ratio to use in the GLUMBConv layer. + dropout (`float`, defaults to `0.0`): + The dropout probability. + attention_bias (`bool`, defaults to `False`): + Whether to use bias in the attention layer. + sample_size (`int`, defaults to `32`): + The base size of the input latent. + patch_size (`int`, defaults to `1`): + The size of the patches to use in the patch embedding layer. + norm_elementwise_affine (`bool`, defaults to `False`): + Whether to use elementwise affinity in the normalization layer. + norm_eps (`float`, defaults to `1e-6`): + The epsilon value for the normalization layer. + qk_norm (`str`, *optional*, defaults to `None`): + The normalization to use for the query and key. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["SanaVideoTransformerBlock", "SanaModulatedNorm"] + _skip_layerwise_casting_patterns = ["patch_embedding", "norm"] + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: Optional[int] = 16, + num_attention_heads: int = 20, + attention_head_dim: int = 112, + num_layers: int = 20, + num_cross_attention_heads: Optional[int] = 20, + cross_attention_head_dim: Optional[int] = 112, + cross_attention_dim: Optional[int] = 2240, + caption_channels: int = 2304, + mlp_ratio: float = 2.5, + dropout: float = 0.0, + attention_bias: bool = False, + sample_size: int = 32, + patch_size: Tuple[int, int, int] = (1, 2, 2), + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + interpolation_scale: Optional[int] = None, + guidance_embeds: bool = False, + guidance_embeds_scale: float = 0.1, + qk_norm: Optional[str] = "rms_norm_across_heads", + rope_max_seq_len: int = 1024, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + inner_dim = num_attention_heads * attention_head_dim + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Additional condition embeddings + if guidance_embeds: + self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim) + else: + self.time_embed = AdaLayerNormSingle(inner_dim) + + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + SanaVideoTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + num_cross_attention_heads=num_cross_attention_heads, + cross_attention_head_dim=cross_attention_head_dim, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + ) + for _ in range(num_layers) + ] + ) + + # 4. Output blocks + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + guidance: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + if guidance is not None: + timestep, embedded_timestep = self.time_embed( + timestep, guidance=guidance, hidden_dtype=hidden_states.dtype + ) + else: + timestep, embedded_timestep = self.time_embed( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + encoder_hidden_states = self.caption_norm(encoder_hidden_states) + + # 2. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for index_block, block in enumerate(self.transformer_blocks): + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + post_patch_num_frames, + post_patch_height, + post_patch_width, + rotary_emb, + ) + if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples): + hidden_states = hidden_states + controlnet_block_samples[index_block - 1] + + else: + for index_block, block in enumerate(self.transformer_blocks): + hidden_states = block( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + post_patch_num_frames, + post_patch_height, + post_patch_width, + rotary_emb, + ) + if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples): + hidden_states = hidden_states + controlnet_block_samples[index_block - 1] + + # 3. Normalization + hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table) + + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index db357669b6f3..0b7dfa579881 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -308,6 +308,7 @@ "SanaSprintPipeline", "SanaControlNetPipeline", "SanaSprintImg2ImgPipeline", + "SanaVideoPipeline", ] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] @@ -735,7 +736,7 @@ QwenImageInpaintPipeline, QwenImagePipeline, ) - from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline + from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline, SanaVideoPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py index 91684f35f153..d5571ab12fac 100644 --- a/src/diffusers/pipelines/sana/__init__.py +++ b/src/diffusers/pipelines/sana/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"] _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"] _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"] + _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -39,6 +40,7 @@ from .pipeline_sana_controlnet import SanaControlNetPipeline from .pipeline_sana_sprint import SanaSprintPipeline from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline + from .pipeline_sana_video import SanaVideoPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/pipeline_output.py b/src/diffusers/pipelines/sana/pipeline_output.py index f8ac12951644..8021b7738755 100644 --- a/src/diffusers/pipelines/sana/pipeline_output.py +++ b/src/diffusers/pipelines/sana/pipeline_output.py @@ -3,6 +3,7 @@ import numpy as np import PIL.Image +import torch from ...utils import BaseOutput @@ -19,3 +20,18 @@ class SanaPipelineOutput(BaseOutput): """ images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class SanaVideoPipelineOutput(BaseOutput): + r""" + Output class for Sana-Video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py new file mode 100644 index 000000000000..60ec749163dc --- /dev/null +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -0,0 +1,1009 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaVideoTransformer3DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...video_processor import VideoProcessor +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pipeline_output import SanaVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +ASPECT_RATIO_VIDEO_480_MS = { + "0.5": [448.0, 896.0], + "0.57": [480.0, 832.0], + "0.68": [528.0, 768.0], + "0.78": [560.0, 720.0], + "1.0": [624.0, 624.0], + "1.13": [672.0, 592.0], + "1.29": [720.0, 560.0], + "1.46": [768.0, 528.0], + "1.67": [816.0, 496.0], + "1.75": [832.0, 480.0], + "2.0": [896.0, 448.0], +} + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaPipeline + + >>> pipe = SanaPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 + ... ) + >>> pipe.to("cuda") + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) + + >>> video = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"', frames=81)[0] + >>> # Save video frames or process as needed + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2410.10629). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): + The tokenizer used to tokenize the prompt. + text_encoder ([`Gemma2PreTrainedModel`]): + Text encoder model to encode the input prompts. + vae ([`AutoencoderDC`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer ([`SanaVideoTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`DPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaVideoTransformer3DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + # 使用与Wan相同的VAE,设置时间和空间缩放因子 + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + # 兼容原有的vae_scale_factor属性 + self.vae_scale_factor = self.vae_scale_factor_spatial + + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + height: int = 480, + width: int = 720, + frames: int = 81, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: List[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> Union[SanaVideoPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to 480): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 720): + The width in pixels of the generated video. + frames (`int`, *optional*, defaults to 81): + The number of frames in the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`List[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated videos + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SanaVideoPipelineOutput(frames=video) From 5eb53546ae931a4986fcbecf05da08a55c470480 Mon Sep 17 00:00:00 2001 From: junsong Date: Tue, 9 Sep 2025 16:51:43 +0000 Subject: [PATCH 02/36] add a sample about how to use sana-video; --- tmp.py | 123 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 tmp.py diff --git a/tmp.py b/tmp.py new file mode 100644 index 000000000000..fc8225f29f7a --- /dev/null +++ b/tmp.py @@ -0,0 +1,123 @@ +import torch +from diffusers import SanaPipeline, SanaVideoPipeline, UniPCMultistepScheduler +from diffusers import AutoencoderKLWan +from diffusers.utils import export_to_video + + +def sana_video(): + # pipe = SanaPipeline.from_pretrained( + # "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers", + # torch_dtype=torch.bfloat16, + # ) + + model_id = "sana_video" + # model_id = "sana_video_unipc" + pipe = SanaVideoPipeline.from_pretrained( + model_id, + vae=None, + torch_dtype=torch.bfloat16, + ) + vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) + pipe.vae=vae + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) + + pipe.text_encoder.to(torch.bfloat16) + + pipe.to("cuda") + + prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + frames=81, + guidance_scale=6, + num_inference_steps=30, + generator=torch.Generator(device="cuda").manual_seed(42), + ).frames[0] + + export_to_video(video, "sana.mp4", fps=16) + + +def profile_sana_video(): + from tqdm import tqdm + import time + model_id = "sana_video" + pipe = SanaVideoPipeline.from_pretrained( + model_id, + vae=None, + ) + vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) + pipe.vae=vae + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) + pipe.text_encoder.to(torch.bfloat16) + pipe.to("cuda") + + prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + for i in tqdm(range(1), desc="Warmup"): + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + frames=81, + guidance_scale=6, + num_inference_steps=30, + generator=torch.Generator(device="cuda").manual_seed(42), + ).frames[0] + + n = 10 + time_start = time.time() + for i in tqdm(range(n), desc=f"Inference {n} times"): + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + frames=81, + guidance_scale=6, + num_inference_steps=30, + generator=torch.Generator(device="cuda").manual_seed(42), + ).frames[0] + + time_end = time.time() + print(f"Time taken: {(time_end - time_start)/n} seconds/video, {n / (time_end - time_start) * 81} fps") + + +def wan(): + import torch + from diffusers.utils import export_to_video + from diffusers import AutoencoderKLWan, WanPipeline + from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + + # model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + pipe.to("cuda") + + prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=720, + width=1280, + num_frames=81, + guidance_scale=5.0, + ).frames[0] + export_to_video(output, "output.mp4", fps=16) + + +if __name__ == "__main__": + sana_video() + # profile_sana_video() + # wan() \ No newline at end of file From c6d78763ece478fb921674d8fdb2552e6d7e02b8 Mon Sep 17 00:00:00 2001 From: junsong Date: Thu, 11 Sep 2025 01:39:27 +0000 Subject: [PATCH 03/36] code update; --- scripts/convert_sana_video_to_diffusers.py | 61 +++++-------------- .../pipelines/sana/pipeline_sana_video.py | 38 ++++++++---- tmp.py | 37 ++++------- 3 files changed, 55 insertions(+), 81 deletions(-) diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py index d6d349c96c5c..41d33f16b53a 100644 --- a/scripts/convert_sana_video_to_diffusers.py +++ b/scripts/convert_sana_video_to_diffusers.py @@ -27,7 +27,8 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext ckpt_ids = [ - "Efficient-Large-Model/SanaVideo_willquant/checkpoints/model.pth" + # "Efficient-Large-Model/SanaVideo_willquant/checkpoints/model.pth" + "Efficient-Large-Model/SanaVideo_willquant_v2/checkpoints/model.pth" ] # https://github.com/NVlabs/Sana/blob/main/scripts/inference.py @@ -66,42 +67,18 @@ def main(args): converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") - # Handle different time embedding structure based on model type - - if args.model_type in ["SanaVideo"]: - # For Sana Sprint, the time embedding structure is different - converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") - - # Guidance embedder for Sana Sprint - converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop( - "cfg_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias") - converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop( - "cfg_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias") - else: - # Original Sana time embedding structure - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop( - "t_embedder.mlp.0.bias" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop( - "t_embedder.mlp.2.bias" - ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop( + "t_embedder.mlp.0.bias" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop( + "t_embedder.mlp.2.bias" + ) # Shared norm. converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight") @@ -255,7 +232,6 @@ def main(args): else: print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE - # vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32) vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) # Text Encoder @@ -317,15 +293,10 @@ def main(args): ) parser.add_argument( "--model_type", - default="SanaMS_1600M_P1_D20", + default="SanaVideo", type=str, choices=[ - "SanaMS_1600M_P1_D20", - "SanaMS_600M_P1_D28", - "SanaMS1.5_1600M_P1_D20", - "SanaMS1.5_4800M_P1_D60", - "SanaSprint_1600M_P1_D20", - "SanaSprint_600M_P1_D28", + "SanaVideo", ], ) parser.add_argument( diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index 60ec749163dc..707d4481e411 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -25,7 +25,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor from ...loaders import SanaLoraLoaderMixin -from ...models import AutoencoderDC, SanaVideoTransformer3DModel +from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel from ...schedulers import DPMSolverMultistepScheduler from ...video_processor import VideoProcessor from ...utils import ( @@ -82,17 +82,31 @@ Examples: ```py >>> import torch - >>> from diffusers import SanaPipeline - - >>> pipe = SanaPipeline.from_pretrained( - ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 + >>> from diffusers import SanaVideoPipeline + >>> from diffusers.utils import export_to_video + >>> model_id = "sana_video" + >>> pipe = SanaVideoPipeline.from_pretrained( + ... model_id, ... ) - >>> pipe.to("cuda") - >>> pipe.text_encoder.to(torch.bfloat16) - >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) - - >>> video = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"', frames=81)[0] - >>> # Save video frames or process as needed + ... pipe.transformer.to(torch.bfloat16) + ... pipe.text_encoder.to(torch.bfloat16) + ... pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... frames=81, + ... guidance_scale=6, + ... num_inference_steps=50, + ... generator=torch.Generator(device="cuda").manual_seed(42), + ... ).frames[0] + + >>> export_to_video(output, "sana-video-output.mp4", fps=16) ``` """ @@ -188,7 +202,7 @@ def __init__( self, tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], text_encoder: Gemma2PreTrainedModel, - vae: AutoencoderDC, + vae: Union[AutoencoderDC, AutoencoderKLWan], transformer: SanaVideoTransformer3DModel, scheduler: DPMSolverMultistepScheduler, ): diff --git a/tmp.py b/tmp.py index fc8225f29f7a..38b3e4a50d2c 100644 --- a/tmp.py +++ b/tmp.py @@ -5,28 +5,18 @@ def sana_video(): - # pipe = SanaPipeline.from_pretrained( - # "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers", - # torch_dtype=torch.bfloat16, - # ) - model_id = "sana_video" + # model_id = "sana_video" + model_id = "sana_video_v2" # model_id = "sana_video_unipc" - pipe = SanaVideoPipeline.from_pretrained( - model_id, - vae=None, - torch_dtype=torch.bfloat16, - ) - vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) - pipe.vae=vae - pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) - + pipe = SanaVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) + pipe.vae.to(torch.bfloat32) pipe.text_encoder.to(torch.bfloat16) - pipe.to("cuda") - prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." - negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + prompt = "Extreme close-up of a thoughtful, gray-haired professor in his 60s, sitting motionless in a Paris café, dressed in a wool coat and beret, pondering the universe. His subtle closed-mouth smile reveals an answer. Golden light, cinematic depth of field, Paris streets blurred in the background. Cinematic 35mm film." + negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." video = pipe( prompt=prompt, @@ -35,11 +25,11 @@ def sana_video(): width=832, frames=81, guidance_scale=6, - num_inference_steps=30, + num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(42), ).frames[0] - export_to_video(video, "sana.mp4", fps=16) + export_to_video(video, "sana_v2.mp4", fps=16) def profile_sana_video(): @@ -48,11 +38,10 @@ def profile_sana_video(): model_id = "sana_video" pipe = SanaVideoPipeline.from_pretrained( model_id, - vae=None, + torch_dtype=torch.bfloat16, ) - vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) - pipe.vae=vae - pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) + # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) + pipe.vae.to(torch.float32) pipe.text_encoder.to(torch.bfloat16) pipe.to("cuda") @@ -81,7 +70,7 @@ def profile_sana_video(): width=832, frames=81, guidance_scale=6, - num_inference_steps=30, + num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(42), ).frames[0] From d67ab2a6e2fde04184886517d6ba793c18c3ff9f Mon Sep 17 00:00:00 2001 From: junsong Date: Thu, 11 Sep 2025 01:46:40 +0000 Subject: [PATCH 04/36] update hf model path; --- tmp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tmp.py b/tmp.py index 38b3e4a50d2c..eba9224b57f2 100644 --- a/tmp.py +++ b/tmp.py @@ -7,7 +7,7 @@ def sana_video(): # model_id = "sana_video" - model_id = "sana_video_v2" + model_id = "hf://Efficient-Large-Model/sana_video_v2" # model_id = "sana_video_unipc" pipe = SanaVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) From a5f19e09269481a35148863d9b4d24c2bd03e7c1 Mon Sep 17 00:00:00 2001 From: junsong Date: Fri, 3 Oct 2025 16:33:05 +0000 Subject: [PATCH 05/36] update code; --- tmp.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tmp.py b/tmp.py index eba9224b57f2..2110073dfc48 100644 --- a/tmp.py +++ b/tmp.py @@ -6,12 +6,10 @@ def sana_video(): - # model_id = "sana_video" - model_id = "hf://Efficient-Large-Model/sana_video_v2" - # model_id = "sana_video_unipc" + model_id = "Efficient-Large-Model/sana_video_v2_diffusers" pipe = SanaVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) - pipe.vae.to(torch.bfloat32) + pipe.vae.to(torch.float32) pipe.text_encoder.to(torch.bfloat16) pipe.to("cuda") @@ -60,7 +58,7 @@ def profile_sana_video(): generator=torch.Generator(device="cuda").manual_seed(42), ).frames[0] - n = 10 + n = 1 time_start = time.time() for i in tqdm(range(n), desc=f"Inference {n} times"): video = pipe( @@ -109,4 +107,4 @@ def wan(): if __name__ == "__main__": sana_video() # profile_sana_video() - # wan() \ No newline at end of file + # wan() From c15ae23c2646ba77d7747118f7faf5fbb338db7c Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Mon, 3 Nov 2025 19:02:21 -0800 Subject: [PATCH 06/36] sana-video can run now; --- scripts/convert_sana_video_to_diffusers.py | 33 ++---- src/diffusers/__init__.py | 2 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_sana_video.py | 2 +- src/diffusers/pipelines/__init__.py | 8 +- .../pipelines/sana/pipeline_sana_video.py | 31 +++-- tmp.py | 110 ------------------ 7 files changed, 36 insertions(+), 152 deletions(-) delete mode 100644 tmp.py diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py index 41d33f16b53a..1df6483f0f07 100644 --- a/scripts/convert_sana_video_to_diffusers.py +++ b/scripts/convert_sana_video_to_diffusers.py @@ -12,25 +12,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from diffusers import ( - AutoencoderDC, AutoencoderKLWan, DPMSolverMultistepScheduler, FlowMatchEulerDiscreteScheduler, - UniPCMultistepScheduler, SanaVideoPipeline, SanaVideoTransformer3DModel, + UniPCMultistepScheduler, ) -from diffusers.models.model_loading_utils import load_model_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available CTX = init_empty_weights if is_accelerate_available else nullcontext -ckpt_ids = [ - # "Efficient-Large-Model/SanaVideo_willquant/checkpoints/model.pth" - "Efficient-Large-Model/SanaVideo_willquant_v2/checkpoints/model.pth" -] -# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py +ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"] +# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py def main(args): @@ -67,18 +62,10 @@ def main(args): converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop( - "t_embedder.mlp.0.bias" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop( - "t_embedder.mlp.2.bias" - ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") # Shared norm. converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight") @@ -88,7 +75,7 @@ def main(args): converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") # scheduler - flow_shift = 6.0 + flow_shift = 8.0 # model config layer_num = 20 @@ -232,7 +219,9 @@ def main(args): else: print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE - vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) + vae = AutoencoderKLWan.from_pretrained( + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 + ) # Text Encoder text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 09860f771ba7..f1000ff689c7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1220,9 +1220,9 @@ SanaControlNetPipeline, SanaPAGPipeline, SanaPipeline, - SanaVideoPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline, + SanaVideoPipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 0c9809bc3a62..15408a4b15cc 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -13,7 +13,6 @@ from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer from .sana_transformer import SanaTransformer2DModel - from .transformer_sana_video import SanaVideoTransformer3DModel from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel @@ -37,6 +36,7 @@ from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_prx import PRXTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel + from .transformer_sana_video import SanaVideoTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index 556af8b1e1ca..7ca43dabf90e 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -263,7 +263,7 @@ def __call__( hidden_states = attn.to_out[1](hidden_states) return hidden_states - + # copy from https://github.com/huggingface/diffusers/blob/11d22e0e809d1219a067ded8a18f7b0129fc58c7/src/diffusers/models/transformers/transformer_wan.py#L410 class WanRotaryPosEmbed(nn.Module): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0b7dfa579881..87d953845e21 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -736,7 +736,13 @@ QwenImageInpaintPipeline, QwenImagePipeline, ) - from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline, SanaVideoPipeline + from .sana import ( + SanaControlNetPipeline, + SanaPipeline, + SanaSprintImg2ImgPipeline, + SanaSprintPipeline, + SanaVideoPipeline, + ) from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index 707d4481e411..fdaf338a9e9c 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -16,7 +16,6 @@ import inspect import re import urllib.parse as ul -import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -27,7 +26,6 @@ from ...loaders import SanaLoraLoaderMixin from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel from ...schedulers import DPMSolverMultistepScheduler -from ...video_processor import VideoProcessor from ...utils import ( BACKENDS_MAPPING, USE_PEFT_BACKEND, @@ -39,7 +37,8 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_512_BIN, @@ -84,16 +83,18 @@ >>> import torch >>> from diffusers import SanaVideoPipeline >>> from diffusers.utils import export_to_video - >>> model_id = "sana_video" - >>> pipe = SanaVideoPipeline.from_pretrained( - ... model_id, - ... ) - ... pipe.transformer.to(torch.bfloat16) - ... pipe.text_encoder.to(torch.bfloat16) - ... pipe.to("cuda") - - >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers" + >>> pipe = SanaVideoPipeline.from_pretrained(model_id) + >>> pipe.transformer.to(torch.bfloat16) + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + >>> model_score = 30 + + >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm色调。A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional." >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." + >>> motion_prompt = f" motion score: {model_score}." + >>> prompt = prompt + motion_prompt >>> output = pipe( ... prompt=prompt, @@ -212,13 +213,11 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - # 使用与Wan相同的VAE,设置时间和空间缩放因子 self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 - - # 兼容原有的vae_scale_factor属性 + self.vae_scale_factor = self.vae_scale_factor_spatial - + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) diff --git a/tmp.py b/tmp.py deleted file mode 100644 index 2110073dfc48..000000000000 --- a/tmp.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -from diffusers import SanaPipeline, SanaVideoPipeline, UniPCMultistepScheduler -from diffusers import AutoencoderKLWan -from diffusers.utils import export_to_video - - -def sana_video(): - - model_id = "Efficient-Large-Model/sana_video_v2_diffusers" - pipe = SanaVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) - # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) - pipe.vae.to(torch.float32) - pipe.text_encoder.to(torch.bfloat16) - pipe.to("cuda") - - prompt = "Extreme close-up of a thoughtful, gray-haired professor in his 60s, sitting motionless in a Paris café, dressed in a wool coat and beret, pondering the universe. His subtle closed-mouth smile reveals an answer. Golden light, cinematic depth of field, Paris streets blurred in the background. Cinematic 35mm film." - negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." - - video = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - height=480, - width=832, - frames=81, - guidance_scale=6, - num_inference_steps=50, - generator=torch.Generator(device="cuda").manual_seed(42), - ).frames[0] - - export_to_video(video, "sana_v2.mp4", fps=16) - - -def profile_sana_video(): - from tqdm import tqdm - import time - model_id = "sana_video" - pipe = SanaVideoPipeline.from_pretrained( - model_id, - torch_dtype=torch.bfloat16, - ) - # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) - pipe.vae.to(torch.float32) - pipe.text_encoder.to(torch.bfloat16) - pipe.to("cuda") - - prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." - negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" - - for i in tqdm(range(1), desc="Warmup"): - video = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - height=480, - width=832, - frames=81, - guidance_scale=6, - num_inference_steps=30, - generator=torch.Generator(device="cuda").manual_seed(42), - ).frames[0] - - n = 1 - time_start = time.time() - for i in tqdm(range(n), desc=f"Inference {n} times"): - video = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - height=480, - width=832, - frames=81, - guidance_scale=6, - num_inference_steps=50, - generator=torch.Generator(device="cuda").manual_seed(42), - ).frames[0] - - time_end = time.time() - print(f"Time taken: {(time_end - time_start)/n} seconds/video, {n / (time_end - time_start) * 81} fps") - - -def wan(): - import torch - from diffusers.utils import export_to_video - from diffusers import AutoencoderKLWan, WanPipeline - from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler - - # model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" - model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" - vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) - pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) - flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P - pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) - pipe.to("cuda") - - prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." - negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" - - output = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - height=720, - width=1280, - num_frames=81, - guidance_scale=5.0, - ).frames[0] - export_to_video(output, "output.mp4", fps=16) - - -if __name__ == "__main__": - sana_video() - # profile_sana_video() - # wan() From ee79af30193c2046924ee528da47fcb8f476b3a3 Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Mon, 3 Nov 2025 21:12:38 -0800 Subject: [PATCH 07/36] 1. add aspect ratio in sana-video-pipeline; 2. add reshape function in sana-video-processor; 3. fix convert pth to safetensor bugs; --- scripts/convert_sana_video_to_diffusers.py | 22 ++- .../pipelines/sana/pipeline_sana_video.py | 142 ++++++++++-------- src/diffusers/video_processor.py | 65 +++++++- 3 files changed, 163 insertions(+), 66 deletions(-) diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py index 1df6483f0f07..67fffceb563b 100644 --- a/scripts/convert_sana_video_to_diffusers.py +++ b/scripts/convert_sana_video_to_diffusers.py @@ -82,6 +82,16 @@ def main(args): # Positional embedding interpolation scale. qk_norm = True + # sample size + if args.video_size == 480: + sample_size = 30 # Wan-VAE: 8xp2 downsample factor + patch_size = (1, 2, 2) + elif args.video_size == 720: + sample_size = 22 # Wan-VAE: 32xp1 downsample factor + patch_size = (1, 1, 1) + else: + raise ValueError(f"Video size {args.video_size} is not supported.") + for depth in range(layer_num): # Transformer blocks. converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( @@ -177,8 +187,8 @@ def main(args): "caption_channels": 2304, "mlp_ratio": 3.0, "attention_bias": False, - "sample_size": args.image_size // 16, - "patch_size": (1, 2, 2), + "sample_size": sample_size, + "patch_size": patch_size, "norm_elementwise_affine": False, "norm_eps": 1e-6, "qk_norm": "rms_norm_across_heads", @@ -273,12 +283,12 @@ def main(args): "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." ) parser.add_argument( - "--image_size", - default=1024, + "--video_size", + default=480, type=int, - choices=[512, 1024, 2048, 4096], + choices=[480, 720], required=False, - help="Image size of pretrained model, 512, 1024, 2048 or 4096.", + help="Video size of pretrained model, 480 or 720.", ) parser.add_argument( "--model_type", diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index fdaf338a9e9c..6d671cf8aff0 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -16,6 +16,7 @@ import inspect import re import urllib.parse as ul +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -37,17 +38,41 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from ..pixart_alpha.pipeline_pixart_alpha import ( - ASPECT_RATIO_512_BIN, - ASPECT_RATIO_1024_BIN, -) -from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN from .pipeline_output import SanaVideoPipelineOutput +ASPECT_RATIO_480_BIN = { + "0.5": [448.0, 896.0], + "0.57": [480.0, 832.0], + "0.68": [528.0, 768.0], + "0.78": [560.0, 720.0], + "1.0": [624.0, 624.0], + "1.13": [672.0, 592.0], + "1.29": [720.0, 560.0], + "1.46": [768.0, 528.0], + "1.67": [816.0, 496.0], + "1.75": [832.0, 480.0], + "2.0": [896.0, 448.0], +} + + +ASPECT_RATIO_720_BIN = { + "0.5": [672.0, 1344.0], + "0.57": [704.0, 1280.0], + "0.68": [800.0, 1152.0], + "0.78": [832.0, 1088.0], + "1.0": [960.0, 960.0], + "1.13": [1024.0, 896.0], + "1.29": [1088.0, 832.0], + "1.46": [1152.0, 800.0], + "1.67": [1248.0, 736.0], + "1.75": [1280.0, 704.0], + "2.0": [1344.0, 672.0], +} + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -63,19 +88,6 @@ if is_ftfy_available(): import ftfy -ASPECT_RATIO_VIDEO_480_MS = { - "0.5": [448.0, 896.0], - "0.57": [480.0, 832.0], - "0.68": [528.0, 768.0], - "0.78": [560.0, 720.0], - "1.0": [624.0, 624.0], - "1.13": [672.0, 592.0], - "1.29": [720.0, 560.0], - "1.46": [768.0, 528.0], - "1.67": [816.0, 496.0], - "1.75": [832.0, 480.0], - "2.0": [896.0, 448.0], -} EXAMPLE_DOC_STRING = """ Examples: @@ -174,7 +186,7 @@ def retrieve_timesteps( class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): r""" - Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2410.10629). + Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -184,7 +196,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): The tokenizer used to tokenize the prompt. text_encoder ([`Gemma2PreTrainedModel`]): Text encoder model to encode the input prompts. - vae ([`AutoencoderDC`]): + vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. transformer ([`SanaVideoTransformer3DModel`]): Conditional Transformer to denoise the input latents. @@ -218,7 +230,6 @@ def __init__( self.vae_scale_factor = self.vae_scale_factor_spatial - self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def enable_vae_slicing(self): @@ -239,7 +250,7 @@ def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. + processing larger videos. """ self.vae.enable_tiling() @@ -313,7 +324,7 @@ def encode_prompt( prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, negative_prompt: str = "", - num_images_per_prompt: int = 1, + num_videos_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, @@ -331,13 +342,13 @@ def encode_prompt( prompt (`str` or `List[str]`, *optional*): prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): - The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For PixArt-Alpha, this should be "". do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not - num_images_per_prompt (`int`, *optional*, defaults to 1): - number of images that should be generated per prompt + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of videos that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): @@ -399,10 +410,10 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -422,11 +433,11 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) else: negative_prompt_embeds = None negative_prompt_attention_mask = None @@ -721,13 +732,13 @@ def __call__( self, prompt: Union[str, List[str]] = None, negative_prompt: str = "", - num_inference_steps: int = 20, + num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, height: int = 480, - width: int = 720, + width: int = 832, frames: int = 81, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -766,8 +777,8 @@ def __call__( The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_inference_steps (`int`, *optional*, defaults to 20): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument @@ -787,7 +798,7 @@ def __call__( The number of videos to generate per prompt. height (`int`, *optional*, defaults to 480): The height in pixels of the generated video. - width (`int`, *optional*, defaults to 720): + width (`int`, *optional*, defaults to 832): The width in pixels of the generated video. frames (`int`, *optional*, defaults to 81): The number of frames in the generated video. @@ -811,10 +822,9 @@ def __call__( negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated video. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + The output format of the generated video. Choose between mp4 or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple. attention_kwargs: A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -825,8 +835,8 @@ def __call__( prompt. use_resolution_binning (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using - `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to - the requested resolution. Useful for generating non-square images. + `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, they are resized back to + the requested resolution. Useful for generating non-square videos. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -845,8 +855,8 @@ def __call__( Examples: Returns: - [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated videos """ @@ -855,18 +865,14 @@ def __call__( # 1. Check inputs. Raise error if not correct if use_resolution_binning: - if self.transformer.config.sample_size == 128: - aspect_ratio_bin = ASPECT_RATIO_4096_BIN - elif self.transformer.config.sample_size == 64: - aspect_ratio_bin = ASPECT_RATIO_2048_BIN - elif self.transformer.config.sample_size == 32: - aspect_ratio_bin = ASPECT_RATIO_1024_BIN - elif self.transformer.config.sample_size == 16: - aspect_ratio_bin = ASPECT_RATIO_512_BIN + if self.transformer.config.sample_size == 30: + aspect_ratio_bin = ASPECT_RATIO_480_BIN + elif self.transformer.config.sample_size == 22: + aspect_ratio_bin = ASPECT_RATIO_720_BIN else: raise ValueError("Invalid sample size") orig_height, orig_width = height, width - height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) self.check_inputs( prompt, @@ -905,7 +911,7 @@ def __call__( prompt, self.do_classifier_free_guidance, negative_prompt=negative_prompt, - num_images_per_prompt=num_videos_per_prompt, + num_videos_per_prompt=num_videos_per_prompt, device=device, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -997,8 +1003,16 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - if not output_type == "latent": + if output_type == "latent": + video = latents + else: latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) @@ -1008,10 +1022,20 @@ def __call__( latents.device, latents.dtype ) latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] + try: + video = self.vae.decode(latents, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + + if use_resolution_binning: + video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height) + + if not output_type == "latent": video = self.video_processor.postprocess_video(video, output_type=output_type) - else: - video = latents # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 59b59b47d2c7..54e3cc4a70a4 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -13,11 +13,12 @@ # limitations under the License. import warnings -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np import PIL import torch +import torch.nn.functional as F from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist @@ -111,3 +112,65 @@ def postprocess_video( raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") return outputs + + @staticmethod + def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]: + r""" + Returns the binned height and width based on the aspect ratio. + + Args: + height (`int`): The height of the image. + width (`int`): The width of the image. + ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width). + + Returns: + `Tuple[int, int]`: The closest binned height and width. + """ + ar = float(height / width) + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + default_hw = ratios[closest_ratio] + return int(default_hw[0]), int(default_hw[1]) + + @staticmethod + def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor: + r""" + Resizes and crops a tensor of videos to the specified dimensions. + + Args: + samples (`torch.Tensor`): + A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the + number of frames, H is the height, and W is the width. + new_width (`int`): The desired width of the output videos. + new_height (`int`): The desired height of the output videos. + + Returns: + `torch.Tensor`: A tensor containing the resized and cropped videos. + """ + orig_height, orig_width = samples.shape[3], samples.shape[4] + + # Check if resizing is needed + if orig_height != new_height or orig_width != new_width: + ratio = max(new_height / orig_height, new_width / orig_width) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + # Reshape to (N*T, C, H, W) for interpolation + n, c, t, h, w = samples.shape + samples = samples.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w) + + # Resize + samples = F.interpolate( + samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + # Center Crop + start_x = (resized_width - new_width) // 2 + end_x = start_x + new_width + start_y = (resized_height - new_height) // 2 + end_y = start_y + new_height + samples = samples[:, :, start_y:end_y, start_x:end_x] + + # Reshape back to (N, C, T, H, W) + samples = samples.reshape(n, t, c, new_height, new_width).permute(0, 2, 1, 3, 4) + + return samples \ No newline at end of file From 49557c107bf3094e640f04ce3d41ba4eec69130d Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Mon, 3 Nov 2025 21:22:28 -0800 Subject: [PATCH 08/36] default to use `use_resolution_binning`; --- src/diffusers/pipelines/sana/pipeline_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index 6d671cf8aff0..f7f0f1e80b75 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -750,7 +750,7 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, clean_caption: bool = False, - use_resolution_binning: bool = False, + use_resolution_binning: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], From 857ca301ff45d5719894f28458f8faaf60dfd802 Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Mon, 3 Nov 2025 21:22:46 -0800 Subject: [PATCH 09/36] make style; --- src/diffusers/pipelines/sana/pipeline_sana_video.py | 1 - src/diffusers/video_processor.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index f7f0f1e80b75..aaf9255e7f67 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -23,7 +23,6 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PixArtImageProcessor from ...loaders import SanaLoraLoaderMixin from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel from ...schedulers import DPMSolverMultistepScheduler diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 54e3cc4a70a4..881f5dfe88b5 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -173,4 +173,4 @@ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: in # Reshape back to (N, C, T, H, W) samples = samples.reshape(n, t, c, new_height, new_width).permute(0, 2, 1, 3, 4) - return samples \ No newline at end of file + return samples From 3ed7000741aaeba6c3c5f63ffd4da27c55b9273b Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Mon, 3 Nov 2025 21:31:41 -0800 Subject: [PATCH 10/36] remove unused code; --- .../transformers/transformer_sana_video.py | 96 ------------------- 1 file changed, 96 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index 7ca43dabf90e..b8683924d0fe 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -118,13 +118,6 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - # B,N,C - # B,H,C,N - # query = query.transpose(1, 2).unflatten(1, (attn.heads, -1)) - # B,H,N,C - # key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3) - # B,N,H,C - # value = value.transpose(1, 2).unflatten(1, (attn.heads, -1)) query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) @@ -176,95 +169,6 @@ def apply_rotary_emb( return hidden_states -class SanaLinearAttnProcessor3_1: - r""" - Processor for implementing scaled dot-product linear attention. - """ - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - original_dtype = hidden_states.dtype - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # B,N,C - # B,H,C,N - # query = query.transpose(1, 2).unflatten(1, (attn.heads, -1)) - # B,H,N,C - # key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3) - # B,N,H,C - # value = value.transpose(1, 2).unflatten(1, (attn.heads, -1)) - query = query.unflatten(2, (attn.heads, -1)) - key = key.unflatten(2, (attn.heads, -1)) - value = value.unflatten(2, (attn.heads, -1)) - # B,N,H,C - - query = F.relu(query) - key = F.relu(key) - - # if rotary_emb is not None: - - # def apply_rotary_emb( - # hidden_states: torch.Tensor, - # freqs_cos: torch.Tensor, - # freqs_sin: torch.Tensor, - # ): - # x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) - # cos = freqs_cos[..., 0::2] - # sin = freqs_sin[..., 1::2] - # out = torch.empty_like(hidden_states) - # out[..., 0::2] = x1 * cos - x2 * sin - # out[..., 1::2] = x1 * sin + x2 * cos - # return out.type_as(hidden_states) - - # query_rotate = apply_rotary_emb(query, *rotary_emb) - # key_rotate = apply_rotary_emb(key, *rotary_emb) - - # B,H,C,N - # query_rotate = query_rotate.permute(0, 2, 3, 1) - # key_rotate = key_rotate.permute(0, 2, 3, 1) - # value = value.permute(0, 2, 3, 1) - - query = query.permute(0, 2, 3, 1) - key = key.permute(0, 2, 3, 1) - query_rotate = query - key_rotate = key - value = value.permute(0, 2, 3, 1) - - query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float() - - z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15) - - scores = torch.matmul(value, key_rotate.transpose(-1, -2)) - hidden_states = torch.matmul(scores, query_rotate) - - hidden_states = hidden_states * z - # B,H,C,N - hidden_states = hidden_states.flatten(1, 2).transpose(1, 2) - hidden_states = hidden_states.to(original_dtype) - - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - # copy from https://github.com/huggingface/diffusers/blob/11d22e0e809d1219a067ded8a18f7b0129fc58c7/src/diffusers/models/transformers/transformer_wan.py#L410 class WanRotaryPosEmbed(nn.Module): def __init__( From 439bf58d2702e2747ac7b1834de4b29b74e5a5a9 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 11:59:58 +0800 Subject: [PATCH 11/36] Update src/diffusers/models/transformers/transformer_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index b8683924d0fe..9f46f8f585a4 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -169,7 +169,7 @@ def apply_rotary_emb( return hidden_states -# copy from https://github.com/huggingface/diffusers/blob/11d22e0e809d1219a067ded8a18f7b0129fc58c7/src/diffusers/models/transformers/transformer_wan.py#L410 +# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed class WanRotaryPosEmbed(nn.Module): def __init__( self, From de4cf3180de49618bf3daca162728ba2e9e11982 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:01:05 +0800 Subject: [PATCH 12/36] Update src/diffusers/models/transformers/transformer_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index 9f46f8f585a4..ba74898ccd0c 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -339,7 +339,7 @@ def __call__( class SanaVideoTransformerBlock(nn.Module): r""" - Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629). + Transformer block introduced in [Sana-Video](https://huggingface.co/papers/2509.24695). """ def __init__( From fe73287ec72ca989affcce29746804290bcd1bdd Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:01:27 +0800 Subject: [PATCH 13/36] Update src/diffusers/models/transformers/transformer_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index ba74898ccd0c..e98aa1401ad5 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -446,7 +446,7 @@ def forward( class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" - A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. + A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models. Args: in_channels (`int`, defaults to `16`): From 118677a45cf70028d5f2a84160dda31815075489 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:23:58 +0800 Subject: [PATCH 14/36] Update src/diffusers/pipelines/sana/pipeline_sana_video.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/sana/pipeline_sana_video.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index aaf9255e7f67..89c7144aec9d 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -1033,7 +1033,6 @@ def __call__( if use_resolution_binning: video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height) - if not output_type == "latent": video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models From 3546c443bf3ec066a80cf96d92e4020731554c42 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:24:10 +0800 Subject: [PATCH 15/36] Update src/diffusers/models/transformers/transformer_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index e98aa1401ad5..6146f26402e7 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -444,7 +444,7 @@ def forward( return hidden_states -class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin): r""" A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models. From f845bba0c8ce0480b6e87e13f26505b18b6042d7 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:24:47 +0800 Subject: [PATCH 16/36] Update src/diffusers/models/transformers/transformer_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../transformers/transformer_sana_video.py | 60 ------------------- 1 file changed, 60 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index 6146f26402e7..4e3dc9acbd65 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -560,66 +560,6 @@ def __init__( self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - def forward( self, hidden_states: torch.Tensor, From 77714ba7a4e169ad4c5c90aa0fade7d961776a49 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:26:47 +0800 Subject: [PATCH 17/36] Update src/diffusers/models/transformers/transformer_sana_video.py --- src/diffusers/models/transformers/transformer_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index 4e3dc9acbd65..c0001eaede8a 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -504,7 +504,7 @@ def __init__( mlp_ratio: float = 2.5, dropout: float = 0.0, attention_bias: bool = False, - sample_size: int = 32, + sample_size: int = 30, patch_size: Tuple[int, int, int] = (1, 2, 2), norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, From b536cfdf86dba21104ae9c81546c2d68586dc245 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:27:12 +0800 Subject: [PATCH 18/36] Update src/diffusers/pipelines/sana/pipeline_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/sana/pipeline_sana_video.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index 89c7144aec9d..1dc60eb250fd 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -702,7 +702,10 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) return latents @property From b0f4866797af17d391e2b1b16eb204dcdc26bd2c Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:28:20 +0800 Subject: [PATCH 19/36] Update src/diffusers/models/transformers/transformer_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_sana_video.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index c0001eaede8a..dd58b4ec8567 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -234,6 +234,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin +# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm class SanaModulatedNorm(nn.Module): def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6): super().__init__() From 7205eee0a81c3aa566400b588f737e37b2ad2cae Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 12:29:05 +0800 Subject: [PATCH 20/36] Update src/diffusers/pipelines/sana/pipeline_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../pipelines/sana/pipeline_sana_video.py | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index 1dc60eb250fd..b7325df7896b 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -231,35 +231,6 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger videos. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], From fd5cff2e1c349c43d44e63705c3846b8d7425e2f Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Tue, 4 Nov 2025 21:12:30 -0800 Subject: [PATCH 21/36] support `dispatch_attention_fn` --- .../transformers/transformer_sana_video.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index dd58b4ec8567..bd1efa45129e 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -22,10 +22,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention_processor import ( - Attention, - AttentionProcessor, -) +from ..attention import AttentionMixin +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -276,6 +275,8 @@ class SanaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ + _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -314,19 +315,23 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) # linear proj hidden_states = attn.to_out[0](hidden_states) From f2a9d0bfbb01ea376a5b4101612599b2545300d9 Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Tue, 4 Nov 2025 21:38:19 -0800 Subject: [PATCH 22/36] 1. add sana-video markdown; 2. fix typos; --- docs/source/en/api/pipelines/sana_sprint.md | 3 - docs/source/en/api/pipelines/sana_video.md | 102 ++++++++++++++++++ .../transformers/transformer_sana_video.py | 2 +- src/diffusers/pipelines/sana/pipeline_sana.py | 2 +- .../pipelines/sana/pipeline_sana_sprint.py | 2 +- .../pipelines/sana/pipeline_sana_video.py | 4 +- 6 files changed, 107 insertions(+), 8 deletions(-) create mode 100644 docs/source/en/api/pipelines/sana_video.md diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md index 357d7e406dd4..46cdc13302ec 100644 --- a/docs/source/en/api/pipelines/sana_sprint.md +++ b/docs/source/en/api/pipelines/sana_sprint.md @@ -24,9 +24,6 @@ The abstract from the paper is: *This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.* -> [!TIP] -> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. - This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/). Available models: diff --git a/docs/source/en/api/pipelines/sana_video.md b/docs/source/en/api/pipelines/sana_video.md new file mode 100644 index 000000000000..85d77fb2944b --- /dev/null +++ b/docs/source/en/api/pipelines/sana_video.md @@ -0,0 +1,102 @@ + + +# SanaVideoPipeline + +
+ LoRA + MPS +
+ +[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie. + +The abstract from the paper is: + +*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).* + +This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video). + +Available models: + +| Model | Recommended dtype | +|:-----:|:-----------------:| +| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` | + +Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information. + +Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype. + +## Quantization + +Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model. + +Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes. + +```py +import torch +from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline +from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel + +quant_config = BitsAndBytesConfig(load_in_8bit=True) +text_encoder_8bit = AutoModel.from_pretrained( + "Efficient-Large-Model/SANA-Video_2B_480p_diffusers", + subfolder="text_encoder", + quantization_config=quant_config, + torch_dtype=torch.float16, +) + +quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True) +transformer_8bit = SanaVideoTransformer3DModel.from_pretrained( + "Efficient-Large-Model/SANA-Video_2B_480p_diffusers", + subfolder="transformer", + quantization_config=quant_config, + torch_dtype=torch.float16, +) + +pipeline = SanaVideoPipeline.from_pretrained( + "Efficient-Large-Model/SANA-Video_2B_480p_diffusers", + text_encoder=text_encoder_8bit, + transformer=transformer_8bit, + torch_dtype=torch.float16, + device_map="balanced", +) + +model_score = 30 +prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional." +negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." +motion_prompt = f" motion score: {model_score}." +prompt = prompt + motion_prompt + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + num_frames=81, + guidance_scale=6.0, + num_inference_steps=50 +).frames[0] +export_to_video(output, "sana-video-output.mp4", fps=16) +``` + +## SanaVideoPipeline + +[[autodoc]] SanaVideoPipeline + - all + - __call__ + + +## SanaVideoPipelineOutput + +[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index bd1efa45129e..c3843bfac177 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team and SANA-Video Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index ac979305ca6d..2beff802c6e0 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -1,4 +1,4 @@ -# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 62b978829271..04f45f817efb 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -1,4 +1,4 @@ -# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index b7325df7896b..3aeea25a3f26 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -1,4 +1,4 @@ -# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -102,7 +102,7 @@ >>> pipe.to("cuda") >>> model_score = 30 - >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm色调。A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional." + >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional." >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." >>> motion_prompt = f" motion score: {model_score}." >>> prompt = prompt + motion_prompt From d98f93c1b4dadacb7a10690ad888570fced8dce5 Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Tue, 4 Nov 2025 21:48:40 -0800 Subject: [PATCH 23/36] add two test case for sana-video (need check) --- .../test_models_transformer_sana_video.py | 98 +++++++++ tests/pipelines/sana/test_sana_video.py | 190 ++++++++++++++++++ 2 files changed, 288 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_sana_video.py create mode 100644 tests/pipelines/sana/test_sana_video.py diff --git a/tests/models/transformers/test_models_transformer_sana_video.py b/tests/models/transformers/test_models_transformer_sana_video.py new file mode 100644 index 000000000000..68a66d0910a9 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_sana_video.py @@ -0,0 +1,98 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import SanaVideoTransformer3DModel + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = SanaVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 16 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (16, 2, 16, 16) + + @property + def output_shape(self): + return (16, 2, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 16, + "out_channels": 16, + "num_attention_heads": 2, + "attention_head_dim": 12, + "num_layers": 2, + "num_cross_attention_heads": 2, + "cross_attention_head_dim": 12, + "cross_attention_dim": 24, + "caption_channels": 16, + "mlp_ratio": 2.5, + "dropout": 0.0, + "attention_bias": False, + "sample_size": 8, + "patch_size": (1, 2, 2), + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"SanaVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = SanaVideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() + diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py new file mode 100644 index 000000000000..78c898a40c37 --- /dev/null +++ b/tests/pipelines/sana/test_sana_video.py @@ -0,0 +1,190 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoModel, AutoTokenizer + +from diffusers import AutoencoderKLWan, DPMSolverMultistepScheduler, SanaVideoPipeline, SanaVideoTransformer3DModel + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class SanaVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SanaVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = DPMSolverMultistepScheduler() + + text_encoder = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gemma2") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gemma2") + + torch.manual_seed(0) + transformer = SanaVideoTransformer3DModel( + in_channels=16, + out_channels=16, + num_attention_heads=2, + attention_head_dim=12, + num_layers=2, + num_cross_attention_heads=2, + cross_attention_head_dim=12, + cross_attention_dim=24, + caption_channels=32, + mlp_ratio=2.5, + dropout=0.0, + attention_bias=False, + sample_size=8, + patch_size=(1, 2, 2), + norm_elementwise_affine=False, + norm_eps=1e-6, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A cat playing with a ball", + "negative_prompt": "blurry, low quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + "complex_human_instruction": [], + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + def test_save_load_local(self, expected_max_difference=5e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) + + +@slow +@require_torch_accelerator +class SanaVideoPipelineIntegrationTests(unittest.TestCase): + prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest." + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + @unittest.skip("TODO: test needs to be implemented") + def test_sana_video_480p(self): + pass + From 4569d0b7e723ee969ba17746f52676e2011ead1b Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Tue, 4 Nov 2025 22:06:48 -0800 Subject: [PATCH 24/36] fix text-encoder in test-sana-video; --- tests/pipelines/sana/test_sana_video.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index 78c898a40c37..9fd0b13343ce 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -18,7 +18,7 @@ import numpy as np import torch -from transformers import AutoModel, AutoTokenizer +from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer from diffusers import AutoencoderKLWan, DPMSolverMultistepScheduler, SanaVideoPipeline, SanaVideoTransformer3DModel @@ -68,8 +68,22 @@ def get_dummy_components(self): torch.manual_seed(0) scheduler = DPMSolverMultistepScheduler() - text_encoder = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gemma2") - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gemma2") + torch.manual_seed(0) + text_encoder_config = Gemma2Config( + head_dim=16, + hidden_size=8, + initializer_range=0.02, + intermediate_size=64, + max_position_embeddings=8192, + model_type="gemma2", + num_attention_heads=2, + num_hidden_layers=1, + num_key_value_heads=2, + vocab_size=8, + attn_implementation="eager", + ) + text_encoder = Gemma2Model(text_encoder_config) + tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") torch.manual_seed(0) transformer = SanaVideoTransformer3DModel( From 137939179b0a9f75dc38c352611f65aa62c5bc50 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 16:04:08 +0800 Subject: [PATCH 25/36] Update tests/pipelines/sana/test_sana_video.py --- tests/pipelines/sana/test_sana_video.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index 9fd0b13343ce..be1f6261b24d 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -133,6 +133,7 @@ def get_dummy_inputs(self, device, seed=0): "max_sequence_length": 16, "output_type": "pt", "complex_human_instruction": [], + "use_resolution_binning": False, } return inputs From b359240a1f46d271e992daba95f5f18001a0316d Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 16:04:41 +0800 Subject: [PATCH 26/36] Update tests/pipelines/sana/test_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/sana/test_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index be1f6261b24d..dd1451d42549 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -95,7 +95,7 @@ def get_dummy_components(self): num_cross_attention_heads=2, cross_attention_head_dim=12, cross_attention_dim=24, - caption_channels=32, + caption_channels=8, mlp_ratio=2.5, dropout=0.0, attention_bias=False, From 7256023129a0a532e0ca9bab83d04b6a53512cfb Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 16:06:25 +0800 Subject: [PATCH 27/36] Update tests/pipelines/sana/test_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/sana/test_sana_video.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index dd1451d42549..1153919383c3 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -122,8 +122,8 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { - "prompt": "A cat playing with a ball", - "negative_prompt": "blurry, low quality", + "prompt": "", + "negative_prompt": "", "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, From 25d1a4c5ff4f04b7d31ae90d9beb5b7ca88b48c7 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 17:07:20 +0800 Subject: [PATCH 28/36] Update tests/pipelines/sana/test_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/sana/test_sana_video.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index 1153919383c3..67baa3a1e94d 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -127,8 +127,8 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "height": 16, - "width": 16, + "height": 32, + "width": 32, "frames": 9, "max_sequence_length": 16, "output_type": "pt", From a9c16ebf4c0d0fd60dc8c301f375f2f1369700f5 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 17:07:28 +0800 Subject: [PATCH 29/36] Update tests/pipelines/sana/test_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/sana/test_sana_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index 67baa3a1e94d..1aedae8564f3 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -148,7 +148,7 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): From 8a27d582404a2fdf8fa3497ec5b234e44f381000 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 17:07:56 +0800 Subject: [PATCH 30/36] Update tests/pipelines/sana/test_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- tests/pipelines/sana/test_sana_video.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index 1aedae8564f3..34d8efae2292 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -183,6 +183,26 @@ def test_save_load_local(self, expected_max_difference=5e-4): max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() self.assertLess(max_diff, expected_max_difference) + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass + + def test_float16_inference(self): + # Requires higher tolerance as model seems very sensitive to dtype + super().test_float16_inference(expected_max_diff=0.08) + + def test_save_load_float16(self): + # Requires higher tolerance as model seems very sensitive to dtype + super().test_save_load_float16(expected_max_diff=0.2) @slow @require_torch_accelerator From 4c25427a19307ef0fa19f15021f07c4167cfc1cf Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 17:08:23 +0800 Subject: [PATCH 31/36] Update src/diffusers/pipelines/sana/pipeline_sana_video.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/pipelines/sana/pipeline_sana_video.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index 3aeea25a3f26..ebab9273ae4e 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -186,7 +186,6 @@ def retrieve_timesteps( class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): r""" Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). From 31c9fa5efd9b0a1e07e8bc597bf18e8d8378762a Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Wed, 5 Nov 2025 17:08:37 +0800 Subject: [PATCH 32/36] Update src/diffusers/video_processor.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- src/diffusers/video_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 881f5dfe88b5..abeb30bca102 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -138,7 +138,7 @@ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: in Args: samples (`torch.Tensor`): - A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the + A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the number of frames, H is the height, and W is the width. new_width (`int`): The desired width of the output videos. new_height (`int`): The desired height of the output videos. From 0ed7eeee7172b7a061ef3b8cf4a74f856fae923a Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Wed, 5 Nov 2025 01:12:38 -0800 Subject: [PATCH 33/36] make style make quality make fix-copies --- scripts/convert_sana_video_to_diffusers.py | 17 ++++++++++++----- .../transformers/transformer_sana_video.py | 9 +++++++-- .../pipelines/sana/pipeline_sana_video.py | 1 - src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ .../test_models_transformer_sana_video.py | 1 - tests/pipelines/sana/test_sana_video.py | 4 ++-- 7 files changed, 51 insertions(+), 11 deletions(-) diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py index 67fffceb563b..fbb7c1d9e706 100644 --- a/scripts/convert_sana_video_to_diffusers.py +++ b/scripts/convert_sana_video_to_diffusers.py @@ -62,9 +62,13 @@ def main(args): converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") # Shared norm. @@ -84,10 +88,10 @@ def main(args): # sample size if args.video_size == 480: - sample_size = 30 # Wan-VAE: 8xp2 downsample factor + sample_size = 30 # Wan-VAE: 8xp2 downsample factor patch_size = (1, 2, 2) elif args.video_size == 720: - sample_size = 22 # Wan-VAE: 32xp1 downsample factor + sample_size = 22 # Wan-VAE: 32xp1 downsample factor patch_size = (1, 1, 1) else: raise ValueError(f"Video size {args.video_size} is not supported.") @@ -253,7 +257,10 @@ def main(args): scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) elif args.scheduler_type == "uni-pc": scheduler = UniPCMultistepScheduler( - prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift + prediction_type="flow_prediction", + use_flow_sigmas=True, + num_train_timesteps=1000, + flow_shift=flow_shift, ) else: raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index c3843bfac177..aaf96175c0e8 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -58,7 +58,9 @@ def __init__( if norm_type == "rms_norm": self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True) - self.conv_temp = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False) + self.conv_temp = nn.Conv2d( + out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.residual_connection: @@ -76,7 +78,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_point(hidden_states) # Temporal aggregation - hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(0, 2, 1, 3) + hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute( + 0, 2, 1, 3 + ) hidden_states = hidden_states_temporal + self.conv_temp(hidden_states_temporal) hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels) @@ -275,6 +279,7 @@ class SanaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ + _attention_backend = None _parallel_config = None diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index ebab9273ae4e..59e17ecc8a69 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -642,7 +642,6 @@ def _clean_caption(self, caption): return caption.strip() - def prepare_latents( self, batch_size: int, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3c426d503996..22d2d8c0a5ea 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1308,6 +1308,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SanaVideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SD3ControlNetModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 20575ff2294d..e8209403de75 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2177,6 +2177,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SanaVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_sana_video.py b/tests/models/transformers/test_models_transformer_sana_video.py index 68a66d0910a9..ff564ed8918d 100644 --- a/tests/models/transformers/test_models_transformer_sana_video.py +++ b/tests/models/transformers/test_models_transformer_sana_video.py @@ -95,4 +95,3 @@ class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas def prepare_init_args_and_inputs_for_common(self): return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() - diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py index 34d8efae2292..9f360a942a64 100644 --- a/tests/pipelines/sana/test_sana_video.py +++ b/tests/pipelines/sana/test_sana_video.py @@ -67,7 +67,7 @@ def get_dummy_components(self): torch.manual_seed(0) scheduler = DPMSolverMultistepScheduler() - + torch.manual_seed(0) text_encoder_config = Gemma2Config( head_dim=16, @@ -204,6 +204,7 @@ def test_save_load_float16(self): # Requires higher tolerance as model seems very sensitive to dtype super().test_save_load_float16(expected_max_diff=0.2) + @slow @require_torch_accelerator class SanaVideoPipelineIntegrationTests(unittest.TestCase): @@ -222,4 +223,3 @@ def tearDown(self): @unittest.skip("TODO: test needs to be implemented") def test_sana_video_480p(self): pass - From e31f91b3936e1e830846279fed11a14362bb603a Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Wed, 5 Nov 2025 01:14:20 -0800 Subject: [PATCH 34/36] toctree yaml update; --- docs/source/en/_toctree.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 5af95cba7490..74a1e0ef2faa 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -373,6 +373,8 @@ title: QwenImageTransformer2DModel - local: api/models/sana_transformer2d title: SanaTransformer2DModel + - local: api/models/sana_video_transformer3d + title: SanaVideoTransformer3DModel - local: api/models/sd3_transformer2d title: SD3Transformer2DModel - local: api/models/skyreels_v2_transformer_3d @@ -561,6 +563,8 @@ title: QwenImage - local: api/pipelines/sana title: Sana + - local: api/pipelines/sana_video + title: Sana Video - local: api/pipelines/sana_sprint title: Sana Sprint - local: api/pipelines/self_attention_guidance From cb31fc256a8f05a4c06427a0821deb37bfcbf926 Mon Sep 17 00:00:00 2001 From: Lawrence-cj Date: Wed, 5 Nov 2025 18:39:55 -0800 Subject: [PATCH 35/36] add sana-video-transformer3d markdown; --- .../en/api/models/sana_video_transformer3d.md | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 docs/source/en/api/models/sana_video_transformer3d.md diff --git a/docs/source/en/api/models/sana_video_transformer3d.md b/docs/source/en/api/models/sana_video_transformer3d.md new file mode 100644 index 000000000000..0cf1451a2d39 --- /dev/null +++ b/docs/source/en/api/models/sana_video_transformer3d.md @@ -0,0 +1,36 @@ + + +# SanaVideoTransformer3DModel + +A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie. + +The abstract from the paper is: + +*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.* + +The model can be loaded with the following code snippet. + +```python +from diffusers import SanaVideoTransformer3DModel +import torch + +transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## SanaVideoTransformer3DModel + +[[autodoc]] SanaVideoTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput + From f3c87f48b6c74b4331667a4a5e9fc54611388773 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 6 Nov 2025 04:35:12 +0000 Subject: [PATCH 36/36] Apply style fixes --- docs/source/en/_toctree.yml | 4 +-- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- .../pipelines/sana/pipeline_sana_video.py | 27 ++++++++++--------- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 74a1e0ef2faa..94dad286e4a3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -563,10 +563,10 @@ title: QwenImage - local: api/pipelines/sana title: Sana - - local: api/pipelines/sana_video - title: Sana Video - local: api/pipelines/sana_sprint title: Sana Sprint + - local: api/pipelines/sana_video + title: Sana Video - local: api/pipelines/self_attention_guidance title: Self-Attention Guidance - local: api/pipelines/semantic_stable_diffusion diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f1000ff689c7..572aad4bd3f1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -543,9 +543,9 @@ "SanaControlNetPipeline", "SanaPAGPipeline", "SanaPipeline", - "SanaVideoPipeline", "SanaSprintImg2ImgPipeline", "SanaSprintPipeline", + "SanaVideoPipeline", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 808e6f253003..202e77fd197d 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -79,7 +79,6 @@ _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"] - _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] @@ -103,6 +102,7 @@ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] + _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py index 59e17ecc8a69..5ec498faffb9 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py @@ -94,6 +94,7 @@ >>> import torch >>> from diffusers import SanaVideoPipeline >>> from diffusers.utils import export_to_video + >>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers" >>> pipe = SanaVideoPipeline.from_pretrained(model_id) >>> pipe.transformer.to(torch.bfloat16) @@ -108,14 +109,14 @@ >>> prompt = prompt + motion_prompt >>> output = pipe( - ... prompt=prompt, - ... negative_prompt=negative_prompt, - ... height=480, - ... width=832, - ... frames=81, - ... guidance_scale=6, - ... num_inference_steps=50, - ... generator=torch.Generator(device="cuda").manual_seed(42), + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... frames=81, + ... guidance_scale=6, + ... num_inference_steps=50, + ... generator=torch.Generator(device="cuda").manual_seed(42), ... ).frames[0] >>> export_to_video(output, "sana-video-output.mp4", fps=16) @@ -185,9 +186,9 @@ def retrieve_timesteps( class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): r""" - Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). + Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model inherits + from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all + pipelines (downloading, saving, running on a particular device, etc.). Args: tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): @@ -806,8 +807,8 @@ def __call__( prompt. use_resolution_binning (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using - `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, they are resized back to - the requested resolution. Useful for generating non-square videos. + `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, + they are resized back to the requested resolution. Useful for generating non-square videos. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,