import torch
import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2

import importlib
import os
@@ -121,6 +121,23 @@ def test_get_act_fn_none():
    assert get_act_fn('') is None


+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("expand_first", [True, False])
@pytest.mark.parametrize("head_first", [True, False])
@@ -141,6 +158,3 @@ def test_attn2d(bias, expand_first, head_first, attn_mask):
    o2 = attn(x, mask)

    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
-
-
-
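For reference, a minimal standalone sketch of the call pattern the new test_mqa_v2 exercises, assuming only what the diff shows: an NCHW input x, an optional (possibly lower-resolution) memory tensor m, and an output with dim_out channels at the input's spatial size.

    import torch
    from timm.layers import MultiQueryAttentionV2

    # Sketch mirroring the shapes parametrized in test_mqa_v2 above.
    mqa = MultiQueryAttentionV2(128, 256)
    x = torch.randn(1, 128, 32, 48)   # query feature map (N, C, H, W)
    m = torch.randn(1, 128, 16, 24)   # optional memory/context map

    y_none = mqa(x, m=None)           # use_m=False case in the test
    y_mem = mqa(x, m=m)               # use_m=True case in the test
    assert y_none.shape == y_mem.shape == (1, 256, 32, 48)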