@@ -649,6 +649,86 @@ def _(
 # ===== Helper functions to use attention backends with templated CP autograd functions =====
 
 
+def _native_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    # Native attention (SDPA) cannot return the logsumexp
+    if return_lse:
+        raise ValueError("Native attention does not support return_lse=True")
+
+    # Save the inputs and attention arguments for the backward pass
+    if _save_ctx:
+        ctx.save_for_backward(query, key, value)
+        ctx.attn_mask = attn_mask
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.enable_gqa = enable_gqa
+
+    # Permute to the (batch, heads, seq_len, head_dim) layout expected by SDPA
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = torch.nn.functional.scaled_dot_product_attention(
+        query=query,
+        key=key,
+        value=value,
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+        enable_gqa=enable_gqa,
+    )
+    out = out.permute(0, 2, 1, 3)
+
+    return out
+
+
+def _native_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    query, key, value = ctx.saved_tensors
+
+    query.requires_grad_(True)
+    key.requires_grad_(True)
+    value.requires_grad_(True)
+
+    # Recompute attention in the (batch, heads, seq_len, head_dim) layout expected by SDPA
+    query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out = torch.nn.functional.scaled_dot_product_attention(
+        query=query_t,
+        key=key_t,
+        value=value_t,
+        attn_mask=ctx.attn_mask,
+        dropout_p=ctx.dropout_p,
+        is_causal=ctx.is_causal,
+        scale=ctx.scale,
+        enable_gqa=ctx.enable_gqa,
+    )
+
+    # Differentiate through the recomputed output; the incoming gradient is permuted into the
+    # same (batch, heads, seq_len, head_dim) layout as `out`, and the resulting input gradients
+    # are permuted back afterwards
+    grad_out_t = grad_out.permute(0, 2, 1, 3)
+    grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
+    )
+
+    grad_query = grad_query_t.permute(0, 2, 1, 3)
+    grad_key = grad_key_t.permute(0, 2, 1, 3)
+    grad_value = grad_value_t.permute(0, 2, 1, 3)
+
+    return grad_query, grad_key, grad_value
+
+
 # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
 # forward declaration:
 # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1523,6 +1603,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
 @_AttentionBackendRegistry.register(
     AttentionBackendName.NATIVE,
     constraints=[_check_device, _check_shape],
+    supports_context_parallel=True,
 )
 def _native_attention(
     query: torch.Tensor,
@@ -1538,18 +1619,35 @@ def _native_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Native attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    out = torch.nn.functional.scaled_dot_product_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=enable_gqa,
-    )
-    out = out.permute(0, 2, 1, 3)
+    if _parallel_config is None:
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op=_native_attention_forward_op,
+            backward_op=_native_attention_backward_op,
+            _parallel_config=_parallel_config,
+        )
+
     return out
 
 
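For reference, the two helper ops added above follow a recompute-based backward scheme: the forward op stashes the raw (batch, seq_len, heads, head_dim) inputs, and the backward op re-runs scaled_dot_product_attention on them and differentiates through that recomputation with torch.autograd.grad. The sketch below shows the same pattern in isolation as a standalone torch.autograd.Function; the class name _RecomputeSDPA, the toy shapes, and the enable_grad wrapper are illustrative assumptions and not part of this commit.

import torch
import torch.nn.functional as F


class _RecomputeSDPA(torch.autograd.Function):
    # Illustrative only: recompute attention in backward instead of saving
    # intermediates, mirroring _native_attention_forward_op / _native_attention_backward_op.

    @staticmethod
    def forward(ctx, query, key, value):
        # Inputs are (batch, seq_len, heads, head_dim); SDPA expects (batch, heads, seq_len, head_dim).
        ctx.save_for_backward(query, key, value)
        q, k, v = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = F.scaled_dot_product_attention(q, k, v)
        return out.permute(0, 2, 1, 3)

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value = ctx.saved_tensors
        # Function.backward runs with grad disabled, so re-enable it for the recompute.
        with torch.enable_grad():
            q = query.detach().requires_grad_(True)
            k = key.detach().requires_grad_(True)
            v = value.detach().requires_grad_(True)
            q_t, k_t, v_t = (x.permute(0, 2, 1, 3) for x in (q, k, v))
            out = F.scaled_dot_product_attention(q_t, k_t, v_t).permute(0, 2, 1, 3)
        # Differentiate through the recomputed attention to obtain the input gradients.
        grad_q, grad_k, grad_v = torch.autograd.grad(out, (q, k, v), grad_out)
        return grad_q, grad_k, grad_v


if __name__ == "__main__":
    q = torch.randn(2, 16, 4, 8, requires_grad=True)
    k = torch.randn(2, 16, 4, 8, requires_grad=True)
    v = torch.randn(2, 16, 4, 8, requires_grad=True)
    _RecomputeSDPA.apply(q, k, v).sum().backward()
    print(q.grad.shape)  # torch.Size([2, 16, 4, 8])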