Commit 1ec28a2

sywangyi and sayakpaul authored
ulysses enabling in native attention path (#12563)
* ulysses enabling in native attention path
* address review comment
* add supports_context_parallel for native attention
* update templated attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent de6173c commit 1ec28a2
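
This commit registers the native (torch.nn.functional.scaled_dot_product_attention) backend as context-parallel capable and routes it through the templated context-parallel attention path whenever a parallel config is active, so Ulysses-style sequence parallelism no longer requires switching to a flash or sage backend. A hypothetical usage sketch follows; ContextParallelConfig and enable_parallelism are assumptions about the diffusers context-parallel API and may differ across versions, only the backend wiring in the diff below is from this commit.

# Hypothetical sketch only: launch with `torchrun --nproc_per_node=2 demo.py`.
# The config/method names below are assumptions, not guaranteed by this commit.
import torch
import torch.distributed as dist
from diffusers import FluxPipeline, ContextParallelConfig  # ContextParallelConfig: assumed name

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Keep the default native SDPA backend; shard the sequence over 2 ranks (Ulysses all-to-all).
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))  # assumed API

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
if dist.get_rank() == 0:
    image.save("output.png")
dist.destroy_process_group()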

File tree

1 file changed: +110 −12 lines changed


src/diffusers/models/attention_dispatch.py

Lines changed: 110 additions & 12 deletions
@@ -649,6 +649,86 @@ def _(
 # ===== Helper functions to use attention backends with templated CP autograd functions =====
 
 
+def _native_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    # Native attention does not return_lse
+    if return_lse:
+        raise ValueError("Native attention does not support return_lse=True")
+
+    # used for backward pass
+    if _save_ctx:
+        ctx.save_for_backward(query, key, value)
+        ctx.attn_mask = attn_mask
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.enable_gqa = enable_gqa
+
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = torch.nn.functional.scaled_dot_product_attention(
+        query=query,
+        key=key,
+        value=value,
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+        enable_gqa=enable_gqa,
+    )
+    out = out.permute(0, 2, 1, 3)
+
+    return out
+
+
+def _native_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    query, key, value = ctx.saved_tensors
+
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+
+    query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = torch.nn.functional.scaled_dot_product_attention(
+        query=query_t,
+        key=key_t,
+        value=value_t,
+        attn_mask=ctx.attn_mask,
+        dropout_p=ctx.dropout_p,
+        is_causal=ctx.is_causal,
+        scale=ctx.scale,
+        enable_gqa=ctx.enable_gqa,
+    )
+    out = out.permute(0, 2, 1, 3)
+
+    grad_out_t = grad_out.permute(0, 2, 1, 3)
+    grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
+    )
+
+    grad_query = grad_query_t.permute(0, 2, 1, 3)
+    grad_key = grad_key_t.permute(0, 2, 1, 3)
+    grad_value = grad_value_t.permute(0, 2, 1, 3)
+
+    return grad_query, grad_key, grad_value
+
+
 # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
 # forward declaration:
 # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
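
Note on the backward op added above: it does not retain the autograd graph from the original forward call. It re-enables grad on the saved query/key/value, recomputes scaled_dot_product_attention locally, and asks torch.autograd.grad for the input gradients. A minimal standalone sketch of that recompute-then-differentiate pattern (shapes and names are illustrative, outside any diffusers machinery):

import torch
import torch.nn.functional as F

# Pretend these are the tensors saved during the forward pass, in [batch, heads, seq, dim] layout.
q = torch.randn(2, 8, 128, 64, requires_grad=True)
k = torch.randn(2, 8, 128, 64, requires_grad=True)
v = torch.randn(2, 8, 128, 64, requires_grad=True)
grad_out = torch.randn(2, 8, 128, 64)  # gradient flowing back into the attention output

# Recompute the forward locally, then request gradients w.r.t. the inputs only.
out = F.scaled_dot_product_attention(q, k, v)
dq, dk, dv = torch.autograd.grad(outputs=out, inputs=[q, k, v], grad_outputs=grad_out, retain_graph=False)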
@@ -1523,6 +1603,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
 @_AttentionBackendRegistry.register(
     AttentionBackendName.NATIVE,
     constraints=[_check_device, _check_shape],
+    supports_context_parallel=True,
 )
 def _native_attention(
     query: torch.Tensor,
@@ -1538,18 +1619,35 @@ def _native_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Native attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    out = torch.nn.functional.scaled_dot_product_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=enable_gqa,
-    )
-    out = out.permute(0, 2, 1, 3)
+    if _parallel_config is None:
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op=_native_attention_forward_op,
+            backward_op=_native_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
+
     return out
 
 
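
For context, the templated context-parallel wrapper supplies the communication; the backend only contributes plain local attention through the forward/backward ops above. A rough, self-contained sketch of the Ulysses idea the wrapper implements (illustrative code, not the diffusers implementation): one all-to-all trades the sequence split for a head split so every rank runs full-sequence attention on a subset of heads, and a second all-to-all restores the sequence sharding of the output.

# Illustrative sketch of Ulysses-style attention with torch.distributed; assumes
# num_heads is divisible by the world size and each rank holds a contiguous,
# rank-ordered chunk of the sequence. Run under torchrun on GPUs (NCCL all-to-all).
import torch
import torch.distributed as dist
import torch.nn.functional as F


def _seq_to_head_shard(x: torch.Tensor, group=None) -> torch.Tensor:
    # [batch, seq_local, heads, dim] -> [batch, seq_full, heads_local, dim]
    world = dist.get_world_size(group)
    b, s, h, d = x.shape
    x = x.reshape(b, s, world, h // world, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    return out.permute(1, 0, 2, 3, 4).reshape(b, world * s, h // world, d)


def _head_to_seq_shard(x: torch.Tensor, group=None) -> torch.Tensor:
    # [batch, seq_full, heads_local, dim] -> [batch, seq_local, heads, dim]
    world = dist.get_world_size(group)
    b, s_full, h_local, d = x.shape
    x = x.reshape(b, world, s_full // world, h_local, d).permute(1, 0, 2, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    return out.permute(1, 2, 0, 3, 4).reshape(b, s_full // world, world * h_local, d)


def ulysses_attention(q, k, v, group=None):
    # q, k, v: [batch, seq_local, heads, dim], sequence-sharded across the group
    q, k, v = (_seq_to_head_shard(t, group) for t in (q, k, v))
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))  # SDPA expects [batch, heads, seq, dim]
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.permute(0, 2, 1, 3)  # back to [batch, seq_full, heads_local, dim]
    return _head_to_seq_shard(out, group)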
0 commit comments