diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 5af95cba7490..94dad286e4a3 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -373,6 +373,8 @@
title: QwenImageTransformer2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
+ - local: api/models/sana_video_transformer3d
+ title: SanaVideoTransformer3DModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/skyreels_v2_transformer_3d
@@ -563,6 +565,8 @@
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
+ - local: api/pipelines/sana_video
+ title: Sana Video
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
diff --git a/docs/source/en/api/models/sana_video_transformer3d.md b/docs/source/en/api/models/sana_video_transformer3d.md
new file mode 100644
index 000000000000..0cf1451a2d39
--- /dev/null
+++ b/docs/source/en/api/models/sana_video_transformer3d.md
@@ -0,0 +1,36 @@
+
+
+# SanaVideoTransformer3DModel
+
+A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
+
+The abstract from the paper is:
+
+*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import SanaVideoTransformer3DModel
+import torch
+
+transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## SanaVideoTransformer3DModel
+
+[[autodoc]] SanaVideoTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+
diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md
index 357d7e406dd4..46cdc13302ec 100644
--- a/docs/source/en/api/pipelines/sana_sprint.md
+++ b/docs/source/en/api/pipelines/sana_sprint.md
@@ -24,9 +24,6 @@ The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
-> [!TIP]
-> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
Available models:
diff --git a/docs/source/en/api/pipelines/sana_video.md b/docs/source/en/api/pipelines/sana_video.md
new file mode 100644
index 000000000000..85d77fb2944b
--- /dev/null
+++ b/docs/source/en/api/pipelines/sana_video.md
@@ -0,0 +1,102 @@
+
+
+# SanaVideoPipeline
+
+
+

+

+
+
+[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
+
+The abstract from the paper is:
+
+*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
+
+This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
+
+Available models:
+
+| Model | Recommended dtype |
+|:-----:|:-----------------:|
+| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
+
+Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
+
+Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = AutoModel.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = SanaVideoPipeline.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+model_score = 30
+prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
+negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+motion_prompt = f" motion score: {model_score}."
+prompt = prompt + motion_prompt
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ num_frames=81,
+ guidance_scale=6.0,
+ num_inference_steps=50
+).frames[0]
+export_to_video(output, "sana-video-output.mp4", fps=16)
+```
+
+## SanaVideoPipeline
+
+[[autodoc]] SanaVideoPipeline
+ - all
+ - __call__
+
+
+## SanaVideoPipelineOutput
+
+[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py
new file mode 100644
index 000000000000..fbb7c1d9e706
--- /dev/null
+++ b/scripts/convert_sana_video_to_diffusers.py
@@ -0,0 +1,324 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+import argparse
+import os
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from termcolor import colored
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from diffusers import (
+ AutoencoderKLWan,
+ DPMSolverMultistepScheduler,
+ FlowMatchEulerDiscreteScheduler,
+ SanaVideoPipeline,
+ SanaVideoTransformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available else nullcontext
+
+ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
+
+
+def main(args):
+ cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
+
+ if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
+ ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
+ snapshot_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ file_path = hf_hub_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ else:
+ file_path = args.orig_ckpt_path
+
+ print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
+ all_state_dict = torch.load(file_path, weights_only=True)
+ state_dict = all_state_dict.pop("state_dict")
+ converted_state_dict = {}
+
+ # Patch embeddings.
+ converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
+
+ # Caption projection.
+ converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
+
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
+
+ # Shared norm.
+ converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
+ converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
+
+ # y norm
+ converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
+
+ # scheduler
+ flow_shift = 8.0
+
+ # model config
+ layer_num = 20
+ # Positional embedding interpolation scale.
+ qk_norm = True
+
+ # sample size
+ if args.video_size == 480:
+ sample_size = 30 # Wan-VAE: 8xp2 downsample factor
+ patch_size = (1, 2, 2)
+ elif args.video_size == 720:
+ sample_size = 22 # Wan-VAE: 32xp1 downsample factor
+ patch_size = (1, 1, 1)
+ else:
+ raise ValueError(f"Video size {args.video_size} is not supported.")
+
+ for depth in range(layer_num):
+ # Transformer blocks.
+ converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
+ f"blocks.{depth}.scale_shift_table"
+ )
+
+ # Linear Attention is all you need 🤘
+ # Self attention.
+ q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ if qk_norm is not None:
+ # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.k_norm.weight"
+ )
+ # Projection.
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.bias"
+ )
+
+ # Feed-forward.
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.point_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.t_conv.weight"
+ )
+
+ # Cross-attention.
+ q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
+ q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
+ k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+ if qk_norm is not None:
+ # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.k_norm.weight"
+ )
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.bias"
+ )
+
+ # Final block.
+ converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
+
+ # Transformer
+ with CTX():
+ transformer_kwargs = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 20,
+ "attention_head_dim": 112,
+ "num_layers": 20,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "caption_channels": 2304,
+ "mlp_ratio": 3.0,
+ "attention_bias": False,
+ "sample_size": sample_size,
+ "patch_size": patch_size,
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 1024,
+ }
+
+ transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
+
+ transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+ try:
+ state_dict.pop("y_embedder.y_embedding")
+ state_dict.pop("pos_embed")
+ state_dict.pop("logvar_linear.weight")
+ state_dict.pop("logvar_linear.bias")
+ except KeyError:
+ print("y_embedder.y_embedding or pos_embed not found in the state_dict")
+
+ assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ transformer = transformer.to(weight_dtype)
+
+ if not args.save_full_pipeline:
+ print(
+ colored(
+ f"Only saving transformer model of {args.model_type}. "
+ f"Set --save_full_pipeline to save the whole Pipeline",
+ "green",
+ attrs=["bold"],
+ )
+ )
+ transformer.save_pretrained(
+ os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
+ )
+ else:
+ print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
+ # VAE
+ vae = AutoencoderKLWan.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+ )
+
+ # Text Encoder
+ text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
+ tokenizer.padding_side = "right"
+ text_encoder = AutoModelForCausalLM.from_pretrained(
+ text_encoder_model_path, torch_dtype=torch.bfloat16
+ ).get_decoder()
+
+ # Choose the appropriate pipeline and scheduler based on model type
+ # Original Sana scheduler
+ if args.scheduler_type == "flow-dpm_solver":
+ scheduler = DPMSolverMultistepScheduler(
+ flow_shift=flow_shift,
+ use_flow_sigmas=True,
+ prediction_type="flow_prediction",
+ )
+ elif args.scheduler_type == "flow-euler":
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
+ elif args.scheduler_type == "uni-pc":
+ scheduler = UniPCMultistepScheduler(
+ prediction_type="flow_prediction",
+ use_flow_sigmas=True,
+ num_train_timesteps=1000,
+ flow_shift=flow_shift,
+ )
+ else:
+ raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
+
+ pipe = SanaVideoPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--video_size",
+ default=480,
+ type=int,
+ choices=[480, 720],
+ required=False,
+ help="Video size of pretrained model, 480 or 720.",
+ )
+ parser.add_argument(
+ "--model_type",
+ default="SanaVideo",
+ type=str,
+ choices=[
+ "SanaVideo",
+ ],
+ )
+ parser.add_argument(
+ "--scheduler_type",
+ default="flow-dpm_solver",
+ type=str,
+ choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
+ help="Scheduler type to use.",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
+ parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
+
+ args = parser.parse_args()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ weight_dtype = DTYPE_MAPPING[args.dtype]
+
+ main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 94104667b541..572aad4bd3f1 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -246,6 +246,7 @@
"QwenImageTransformer2DModel",
"SanaControlNetModel",
"SanaTransformer2DModel",
+ "SanaVideoTransformer3DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
@@ -544,6 +545,7 @@
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
+ "SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -951,6 +953,7 @@
QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
+ SanaVideoTransformer3DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -1219,6 +1222,7 @@
SanaPipeline,
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
+ SanaVideoPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index e3b297464143..202e77fd197d 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -102,6 +102,7 @@
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
+ _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -204,6 +205,7 @@
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
+ SanaVideoTransformer3DModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 2fe1159eec4c..15408a4b15cc 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -36,6 +36,7 @@
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_prx import PRXTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
+ from .transformer_sana_video import SanaVideoTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py
new file mode 100644
index 000000000000..aaf96175c0e8
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_sana_video.py
@@ -0,0 +1,703 @@
+# Copyright 2025 The HuggingFace Team and SANA-Video Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class GLUMBTempConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expand_ratio: float = 4,
+ norm_type: Optional[str] = None,
+ residual_connection: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_channels = int(expand_ratio * in_channels)
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ self.nonlinearity = nn.SiLU()
+ self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+ self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+ self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+
+ self.norm = None
+ if norm_type == "rms_norm":
+ self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
+
+ self.conv_temp = nn.Conv2d(
+ out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.residual_connection:
+ residual = hidden_states
+ batch_size, num_frames, height, width, num_channels = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2)
+
+ hidden_states = self.conv_inverted(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv_depth(hidden_states)
+ hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+ hidden_states = hidden_states * self.nonlinearity(gate)
+
+ hidden_states = self.conv_point(hidden_states)
+
+ # Temporal aggregation
+ hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(
+ 0, 2, 1, 3
+ )
+ hidden_states = hidden_states_temporal + self.conv_temp(hidden_states_temporal)
+ hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels)
+
+ if self.norm_type == "rms_norm":
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class SanaLinearAttnProcessor3_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+ # B,N,H,C
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query_rotate = apply_rotary_emb(query, *rotary_emb)
+ key_rotate = apply_rotary_emb(key, *rotary_emb)
+
+ # B,H,C,N
+ query = query.permute(0, 2, 3, 1)
+ key = key.permute(0, 2, 3, 1)
+ query_rotate = query_rotate.permute(0, 2, 3, 1)
+ key_rotate = key_rotate.permute(0, 2, 3, 1)
+ value = value.permute(0, 2, 3, 1)
+
+ query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float()
+
+ z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15)
+
+ scores = torch.matmul(value, key_rotate.transpose(-1, -2))
+ hidden_states = torch.matmul(scores, query_rotate)
+
+ hidden_states = hidden_states * z
+ # B,H,C,N
+ hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
+ hidden_states = hidden_states.to(original_dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
+class SanaModulatedNorm(nn.Module):
+ def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+ ) -> torch.Tensor:
+ hidden_states = self.norm(hidden_states)
+ shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+ return hidden_states
+
+
+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+ def __init__(self, embedding_dim):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ guidance_proj = self.guidance_condition_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+ conditioning = timesteps_emb + guidance_emb
+
+ return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SanaVideoTransformerBlock(nn.Module):
+ r"""
+ Transformer block introduced in [Sana-Video](https://huggingface.co/papers/2509.24695).
+ """
+
+ def __init__(
+ self,
+ dim: int = 2240,
+ num_attention_heads: int = 20,
+ attention_head_dim: int = 112,
+ dropout: float = 0.0,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ attention_bias: bool = True,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ attention_out_bias: bool = True,
+ mlp_ratio: float = 3.0,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ kv_heads=num_attention_heads if qk_norm is not None else None,
+ qk_norm=qk_norm,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ processor=SanaLinearAttnProcessor3_0(),
+ )
+
+ # 2. Cross Attention
+ if cross_attention_dim is not None:
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.attn2 = Attention(
+ query_dim=dim,
+ qk_norm=qk_norm,
+ kv_heads=num_cross_attention_heads if qk_norm is not None else None,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_cross_attention_heads,
+ dim_head=cross_attention_head_dim,
+ dropout=dropout,
+ bias=True,
+ out_bias=attention_out_bias,
+ processor=SanaAttnProcessor2_0(),
+ )
+
+ # 3. Feed-forward
+ self.ff = GLUMBTempConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ frames: int = None,
+ height: int = None,
+ width: int = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ # 1. Modulation
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+
+ # 2. Self Attention
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
+
+ attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb)
+ hidden_states = hidden_states + gate_msa * attn_output
+
+ # 3. Cross Attention
+ if self.attn2 is not None:
+ attn_output = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = ff_output.flatten(1, 3)
+ hidden_states = hidden_states + gate_mlp * ff_output
+
+ return hidden_states
+
+
+class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
+ r"""
+ A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models.
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `20`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `112`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of Transformer blocks to use.
+ num_cross_attention_heads (`int`, *optional*, defaults to `20`):
+ The number of heads to use for cross-attention.
+ cross_attention_head_dim (`int`, *optional*, defaults to `112`):
+ The number of channels in each head for cross-attention.
+ cross_attention_dim (`int`, *optional*, defaults to `2240`):
+ The number of channels in the cross-attention output.
+ caption_channels (`int`, defaults to `2304`):
+ The number of channels in the caption embeddings.
+ mlp_ratio (`float`, defaults to `2.5`):
+ The expansion ratio to use in the GLUMBConv layer.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability.
+ attention_bias (`bool`, defaults to `False`):
+ Whether to use bias in the attention layer.
+ sample_size (`int`, defaults to `32`):
+ The base size of the input latent.
+ patch_size (`int`, defaults to `1`):
+ The size of the patches to use in the patch embedding layer.
+ norm_elementwise_affine (`bool`, defaults to `False`):
+ Whether to use elementwise affinity in the normalization layer.
+ norm_eps (`float`, defaults to `1e-6`):
+ The epsilon value for the normalization layer.
+ qk_norm (`str`, *optional*, defaults to `None`):
+ The normalization to use for the query and key.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["SanaVideoTransformerBlock", "SanaModulatedNorm"]
+ _skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ num_attention_heads: int = 20,
+ attention_head_dim: int = 112,
+ num_layers: int = 20,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ caption_channels: int = 2304,
+ mlp_ratio: float = 2.5,
+ dropout: float = 0.0,
+ attention_bias: bool = False,
+ sample_size: int = 30,
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ interpolation_scale: Optional[int] = None,
+ guidance_embeds: bool = False,
+ guidance_embeds_scale: float = 0.1,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Additional condition embeddings
+ if guidance_embeds:
+ self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+ else:
+ self.time_embed = AdaLayerNormSingle(inner_dim)
+
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+ self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+ # 3. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ SanaVideoTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ num_cross_attention_heads=num_cross_attention_heads,
+ cross_attention_head_dim=cross_attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output blocks
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ guidance: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ if guidance is not None:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+ # 2. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for index_block, block in enumerate(self.transformer_blocks):
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ rotary_emb,
+ )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+
+ else:
+ for index_block, block in enumerate(self.transformer_blocks):
+ hidden_states = block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ rotary_emb,
+ )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+
+ # 3. Normalization
+ hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index db357669b6f3..87d953845e21 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -308,6 +308,7 @@
"SanaSprintPipeline",
"SanaControlNetPipeline",
"SanaSprintImg2ImgPipeline",
+ "SanaVideoPipeline",
]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -735,7 +736,13 @@
QwenImageInpaintPipeline,
QwenImagePipeline,
)
- from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
+ from .sana import (
+ SanaControlNetPipeline,
+ SanaPipeline,
+ SanaSprintImg2ImgPipeline,
+ SanaSprintPipeline,
+ SanaVideoPipeline,
+ )
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py
index 91684f35f153..d5571ab12fac 100644
--- a/src/diffusers/pipelines/sana/__init__.py
+++ b/src/diffusers/pipelines/sana/__init__.py
@@ -26,6 +26,7 @@
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
+ _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -39,6 +40,7 @@
from .pipeline_sana_controlnet import SanaControlNetPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
+ from .pipeline_sana_video import SanaVideoPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/sana/pipeline_output.py b/src/diffusers/pipelines/sana/pipeline_output.py
index f8ac12951644..8021b7738755 100644
--- a/src/diffusers/pipelines/sana/pipeline_output.py
+++ b/src/diffusers/pipelines/sana/pipeline_output.py
@@ -3,6 +3,7 @@
import numpy as np
import PIL.Image
+import torch
from ...utils import BaseOutput
@@ -19,3 +20,18 @@ class SanaPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+@dataclass
+class SanaVideoPipelineOutput(BaseOutput):
+ r"""
+ Output class for Sana-Video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index ac979305ca6d..2beff802c6e0 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -1,4 +1,4 @@
-# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 62b978829271..04f45f817efb 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -1,4 +1,4 @@
-# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py
new file mode 100644
index 000000000000..5ec498faffb9
--- /dev/null
+++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py
@@ -0,0 +1,1017 @@
+# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...schedulers import DPMSolverMultistepScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SanaVideoPipelineOutput
+
+
+ASPECT_RATIO_480_BIN = {
+ "0.5": [448.0, 896.0],
+ "0.57": [480.0, 832.0],
+ "0.68": [528.0, 768.0],
+ "0.78": [560.0, 720.0],
+ "1.0": [624.0, 624.0],
+ "1.13": [672.0, 592.0],
+ "1.29": [720.0, 560.0],
+ "1.46": [768.0, 528.0],
+ "1.67": [816.0, 496.0],
+ "1.75": [832.0, 480.0],
+ "2.0": [896.0, 448.0],
+}
+
+
+ASPECT_RATIO_720_BIN = {
+ "0.5": [672.0, 1344.0],
+ "0.57": [704.0, 1280.0],
+ "0.68": [800.0, 1152.0],
+ "0.78": [832.0, 1088.0],
+ "1.0": [960.0, 960.0],
+ "1.13": [1024.0, 896.0],
+ "1.29": [1088.0, 832.0],
+ "1.46": [1152.0, 800.0],
+ "1.67": [1248.0, 736.0],
+ "1.75": [1280.0, 704.0],
+ "2.0": [1344.0, 672.0],
+}
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaVideoPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
+ >>> pipe = SanaVideoPipeline.from_pretrained(model_id)
+ >>> pipe.transformer.to(torch.bfloat16)
+ >>> pipe.text_encoder.to(torch.bfloat16)
+ >>> pipe.vae.to(torch.float32)
+ >>> pipe.to("cuda")
+ >>> model_score = 30
+
+ >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
+ >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+ >>> motion_prompt = f" motion score: {model_score}."
+ >>> prompt = prompt + motion_prompt
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... frames=81,
+ ... guidance_scale=6,
+ ... num_inference_steps=50,
+ ... generator=torch.Generator(device="cuda").manual_seed(42),
+ ... ).frames[0]
+
+ >>> export_to_video(output, "sana-video-output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model inherits
+ from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all
+ pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]):
+ The tokenizer used to tokenize the prompt.
+ text_encoder ([`Gemma2PreTrainedModel`]):
+ Text encoder model to encode the input prompts.
+ vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer ([`SanaVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`DPMSolverMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: Union[AutoencoderDC, AutoencoderKLWan],
+ transformer: SanaVideoTransformer3DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.vae_scale_factor = self.vae_scale_factor_spatial
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ number of videos that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=False,
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ height: int = 480,
+ width: int = 832,
+ frames: int = 81,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaVideoPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ height (`int`, *optional*, defaults to 480):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to 832):
+ The width in pixels of the generated video.
+ frames (`int`, *optional*, defaults to 81):
+ The number of frames in the generated video.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between mp4 or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos,
+ they are resized back to the requested resolution. Useful for generating non-square videos.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated videos
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 30:
+ aspect_ratio_bin = ASPECT_RATIO_480_BIN
+ elif self.transformer.config.sample_size == 22:
+ aspect_ratio_bin = ASPECT_RATIO_720_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ height,
+ width,
+ frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ transformer_dtype = self.transformer.dtype
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ try:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ except oom_error as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+
+ if use_resolution_binning:
+ video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height)
+
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SanaVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 3c426d503996..22d2d8c0a5ea 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1308,6 +1308,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class SanaVideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class SD3ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 20575ff2294d..e8209403de75 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2177,6 +2177,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class SanaVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py
index 59b59b47d2c7..abeb30bca102 100644
--- a/src/diffusers/video_processor.py
+++ b/src/diffusers/video_processor.py
@@ -13,11 +13,12 @@
# limitations under the License.
import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
+import torch.nn.functional as F
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
@@ -111,3 +112,65 @@ def postprocess_video(
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
+
+ @staticmethod
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+ r"""
+ Returns the binned height and width based on the aspect ratio.
+
+ Args:
+ height (`int`): The height of the image.
+ width (`int`): The width of the image.
+ ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
+
+ Returns:
+ `Tuple[int, int]`: The closest binned height and width.
+ """
+ ar = float(height / width)
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+ default_hw = ratios[closest_ratio]
+ return int(default_hw[0]), int(default_hw[1])
+
+ @staticmethod
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+ r"""
+ Resizes and crops a tensor of videos to the specified dimensions.
+
+ Args:
+ samples (`torch.Tensor`):
+ A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the
+ number of frames, H is the height, and W is the width.
+ new_width (`int`): The desired width of the output videos.
+ new_height (`int`): The desired height of the output videos.
+
+ Returns:
+ `torch.Tensor`: A tensor containing the resized and cropped videos.
+ """
+ orig_height, orig_width = samples.shape[3], samples.shape[4]
+
+ # Check if resizing is needed
+ if orig_height != new_height or orig_width != new_width:
+ ratio = max(new_height / orig_height, new_width / orig_width)
+ resized_width = int(orig_width * ratio)
+ resized_height = int(orig_height * ratio)
+
+ # Reshape to (N*T, C, H, W) for interpolation
+ n, c, t, h, w = samples.shape
+ samples = samples.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w)
+
+ # Resize
+ samples = F.interpolate(
+ samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ # Center Crop
+ start_x = (resized_width - new_width) // 2
+ end_x = start_x + new_width
+ start_y = (resized_height - new_height) // 2
+ end_y = start_y + new_height
+ samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+ # Reshape back to (N, C, T, H, W)
+ samples = samples.reshape(n, t, c, new_height, new_width).permute(0, 2, 1, 3, 4)
+
+ return samples
diff --git a/tests/models/transformers/test_models_transformer_sana_video.py b/tests/models/transformers/test_models_transformer_sana_video.py
new file mode 100644
index 000000000000..ff564ed8918d
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_sana_video.py
@@ -0,0 +1,97 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import SanaVideoTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = SanaVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 16
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 2, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (16, 2, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "num_layers": 2,
+ "num_cross_attention_heads": 2,
+ "cross_attention_head_dim": 12,
+ "cross_attention_dim": 24,
+ "caption_channels": 16,
+ "mlp_ratio": 2.5,
+ "dropout": 0.0,
+ "attention_bias": False,
+ "sample_size": 8,
+ "patch_size": (1, 2, 2),
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SanaVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = SanaVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py
new file mode 100644
index 000000000000..9f360a942a64
--- /dev/null
+++ b/tests/pipelines/sana/test_sana_video.py
@@ -0,0 +1,225 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import AutoencoderKLWan, DPMSolverMultistepScheduler, SanaVideoPipeline, SanaVideoTransformer3DModel
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = DPMSolverMultistepScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ torch.manual_seed(0)
+ transformer = SanaVideoTransformer3DModel(
+ in_channels=16,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=2,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=12,
+ cross_attention_dim=24,
+ caption_channels=8,
+ mlp_ratio=2.5,
+ dropout=0.0,
+ attention_bias=False,
+ sample_size=8,
+ patch_size=(1, 2, 2),
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": [],
+ "use_resolution_binning": False,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+ def test_save_load_float16(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_save_load_float16(expected_max_diff=0.2)
+
+
+@slow
+@require_torch_accelerator
+class SanaVideoPipelineIntegrationTests(unittest.TestCase):
+ prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_sana_video_480p(self):
+ pass