From af9814b9073df6e92b4f388eba3e92b83aa73aeb Mon Sep 17 00:00:00 2001
From: DefTruth
Date: Tue, 4 Nov 2025 10:21:56 +0000
Subject: [PATCH 1/2] feat: enable attention dispatch for hunyuan video

---
 .../transformers/transformer_hunyuan_video.py | 59 ++++++++++++-------
 1 file changed, 39 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index bc857ccab463..8af1e0770e41 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -24,6 +24,7 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..embeddings import (
@@ -42,6 +43,9 @@
 
 
 class HunyuanVideoAttnProcessor2_0:
+    _attention_backend = None
+    _parallel_config = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -64,9 +68,9 @@ def __call__(
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         # 2. QK normalization
         if attn.norm_q is not None:
@@ -81,21 +85,29 @@ def __call__(
             if attn.add_q_proj is None and encoder_hidden_states is not None:
                 query = torch.cat(
                     [
-                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                        query[:, :, -encoder_hidden_states.shape[1] :],
+                        apply_rotary_emb(
+                            query[:, :-encoder_hidden_states.shape[1]],
+                            image_rotary_emb,
+                            sequence_dim=1,
+                        ),
+                        query[:, -encoder_hidden_states.shape[1] :],
                     ],
-                    dim=2,
+                    dim=1,
                 )
                 key = torch.cat(
                     [
-                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                        key[:, :, -encoder_hidden_states.shape[1] :],
+                        apply_rotary_emb(
+                            key[:, : -encoder_hidden_states.shape[1]],
+                            image_rotary_emb,
+                            sequence_dim=1,
+                        ),
+                        key[:, -encoder_hidden_states.shape[1] :],
                     ],
-                    dim=2,
+                    dim=1,
                 )
             else:
-                query = apply_rotary_emb(query, image_rotary_emb)
-                key = apply_rotary_emb(key, image_rotary_emb)
+                query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+                key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
 
         # 4. Encoder condition QKV projection and normalization
         if attn.add_q_proj is not None and encoder_hidden_states is not None:
@@ -103,24 +115,31 @@ def __call__(
             encoder_key = attn.add_k_proj(encoder_hidden_states)
             encoder_value = attn.add_v_proj(encoder_hidden_states)
 
-            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
 
             if attn.norm_added_q is not None:
                 encoder_query = attn.norm_added_q(encoder_query)
             if attn.norm_added_k is not None:
                 encoder_key = attn.norm_added_k(encoder_key)
 
-            query = torch.cat([query, encoder_query], dim=2)
-            key = torch.cat([key, encoder_key], dim=2)
-            value = torch.cat([value, encoder_value], dim=2)
+            query = torch.cat([query, encoder_query], dim=1)
+            key = torch.cat([key, encoder_key], dim=1)
+            value = torch.cat([value, encoder_value], dim=1)
 
         # 5. Attention
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         # 6. Output projection

From cf296d0196598287b1cf290deec1e242153647d0 Mon Sep 17 00:00:00 2001
From: DefTruth
Date: Tue, 4 Nov 2025 10:21:56 +0000
Subject: [PATCH 2/2] feat: enable attention dispatch for hunyuan video

---
 docs/source/en/_toctree.yml                          | 4 ++--
 .../models/transformers/transformer_hunyuan_video.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 251eb25899ce..5af95cba7490 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -636,6 +636,8 @@
         title: HunyuanVideo
       - local: api/pipelines/i2vgenxl
         title: I2VGen-XL
+      - local: api/pipelines/kandinsky5_video
+        title: Kandinsky 5.0 Video
       - local: api/pipelines/latte
         title: Latte
       - local: api/pipelines/ltx_video
@@ -654,8 +656,6 @@
         title: Text2Video-Zero
       - local: api/pipelines/wan
         title: Wan
-      - local: api/pipelines/kandinsky5_video
-        title: Kandinsky 5.0 Video
       title: Video
   title: Pipelines
 - sections:
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 8af1e0770e41..c564d4e40db0 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -86,7 +86,7 @@ def __call__(
             query = torch.cat(
                 [
                     apply_rotary_emb(
-                        query[:, :-encoder_hidden_states.shape[1]],
+                        query[:, : -encoder_hidden_states.shape[1]],
                         image_rotary_emb,
                         sequence_dim=1,
                     ),
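
Reviewer note (not part of the patch): the subtle change here is the tensor layout. `dispatch_attention_fn` consumes sequence-first tensors of shape (batch, seq_len, heads, head_dim), which is why the `.transpose(1, 2)` calls disappear, the concatenations move from `dim=2` to `dim=1`, and `apply_rotary_emb` is called with `sequence_dim=1`. A minimal local sanity check is sketched below; it assumes the default (`backend=None`) dispatch path mirrors `F.scaled_dot_product_attention`, and the shapes are illustrative only:

# Sanity-check sketch: the default dispatch path should agree with SDPA
# once the layout difference is accounted for. Not part of the patch.
import torch
import torch.nn.functional as F

from diffusers.models.attention_dispatch import dispatch_attention_fn

batch, seq_len, heads, head_dim = 1, 128, 24, 128
q = torch.randn(batch, seq_len, heads, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Omitting backend= picks the default backend, matching the processor's
# behavior when _attention_backend is None.
out = dispatch_attention_fn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Reference path: F.scaled_dot_product_attention expects head-first
# (batch, heads, seq_len, head_dim), hence the transposes around it.
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

If the two outputs disagree, the sequence-first layout assumption above does not hold for the installed diffusers version. Once the processor routes through the dispatcher as in this patch, backend selection mechanisms such as a `with attention_backend(...)` context (if available in the installed version) apply to HunyuanVideo attention as well.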