 Adapted from official impl at https://github.com/jameslahm/RepViT
 """

-__all__ = ['RepViT']
+__all__ = ['RepVit']

 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -81,7 +81,7 @@ def fuse(self):
         return m


-class RepVGGDW(nn.Module):
+class RepVggDw(nn.Module):
     def __init__(self, ed, kernel_size):
         super().__init__()
         self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
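A minimal sketch of the RepVGG-style reparameterization that this module's fuse() relies on; the branch layout and helper below are illustrative, not the module's exact attributes. At inference, conv_k(x) + conv_1(x) + x collapses into a single k x k depthwise conv, because convolution is linear in its weights:

    import torch
    import torch.nn.functional as F

    def fuse_dw_branches(w_k, b_k, w_1, b_1):
        # w_k: (C, 1, k, k) depthwise kernel; w_1: (C, 1, 1, 1) depthwise 1x1 kernel
        k = w_k.shape[-1]
        pad = (k - 1) // 2
        w_1 = F.pad(w_1, [pad, pad, pad, pad])  # embed the 1x1 kernel in a k x k one
        ident = torch.zeros_like(w_k)
        ident[:, :, pad, pad] = 1.0             # identity branch as a k x k kernel
        return w_k + w_1 + ident, b_k + b_1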
@@ -115,7 +115,7 @@ def fuse(self):
         return conv


-class RepViTMlp(nn.Module):
+class RepVitMlp(nn.Module):
     def __init__(self, in_dim, hidden_dim, act_layer):
         super().__init__()
         self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
@@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
     def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
         super(RepViTBlock, self).__init__()

-        self.token_mixer = RepVGGDW(in_dim, kernel_size)
+        self.token_mixer = RepVggDw(in_dim, kernel_size)
         self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
-        self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)
+        self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)

     def forward(self, x):
         x = self.token_mixer(x)
@@ -142,7 +142,7 @@ def forward(self, x):
         return identity + x


-class RepViTStem(nn.Module):
+class RepVitStem(nn.Module):
     def __init__(self, in_chs, out_chs, act_layer):
         super().__init__()
         self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
@@ -154,13 +154,13 @@ def forward(self, x):
         return self.conv2(self.act1(self.conv1(x)))


-class RepViTDownsample(nn.Module):
+class RepVitDownsample(nn.Module):
     def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
         super().__init__()
         self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
         self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
         self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
-        self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)
+        self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)

     def forward(self, x):
         x = self.pre_block(x)
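For orientation, the shape flow through the downsample path (a sketch; B, H, W are illustrative):

    # x:                  (B, in_dim,  H,   W)
    # pre_block:          (B, in_dim,  H,   W)    mixing only, no resizing
    # spatial_downsample: (B, in_dim,  H/2, W/2)  stride-2 depthwise conv
    # channel_downsample: (B, out_dim, H/2, W/2)  1x1 projection
    # ffn (+ residual):   (B, out_dim, H/2, W/2)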
@@ -171,22 +171,25 @@ def forward(self, x):
         return x + identity


-class RepViTClassifier(nn.Module):
-    def __init__(self, dim, num_classes, distillation=False):
+class RepVitClassifier(nn.Module):
+    def __init__(self, dim, num_classes, distillation=False, drop=0.):
         super().__init__()
+        self.head_drop = nn.Dropout(drop)
         self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
         self.distillation = distillation
-        self.num_classes = num_classes
+        self.distilled_training = False
+        self.num_classes = num_classes
         if distillation:
             self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()

     def forward(self, x):
+        x = self.head_drop(x)
         if self.distillation:
             x1, x2 = self.head(x), self.head_dist(x)
-            if (not self.training) or torch.jit.is_scripting():
-                return (x1 + x2) / 2
-            else:
+            if self.training and self.distilled_training and not torch.jit.is_scripting():
                 return x1, x2
+            else:
+                return (x1 + x2) / 2
         else:
             x = self.head(x)
             return x
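The new distilled_training flag only changes the training-time path; eval (and scripted) inference still averages the two heads. A toy run (dims assumed; RepVitClassifier as defined above):

    import torch

    head = RepVitClassifier(dim=64, num_classes=10, distillation=True)
    x = torch.randn(2, 64)

    head.train()
    head.distilled_training = True
    x1, x2 = head(x)     # separate logits for a distillation loss

    head.eval()
    logits = head(x)     # (x1 + x2) / 2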
@@ -207,11 +210,11 @@ def fuse(self):
         return head


-class RepViTStage(nn.Module):
+class RepVitStage(nn.Module):
     def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
         super().__init__()
         if downsample:
-            self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
+            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
         else:
             assert in_dim == out_dim
             self.downsample = nn.Identity()
@@ -230,7 +233,7 @@ def forward(self, x):
         return x


-class RepViT(nn.Module):
+class RepVit(nn.Module):
     def __init__(
         self,
         in_chans=3,
@@ -243,15 +246,16 @@ def __init__(
         num_classes=1000,
         act_layer=nn.GELU,
         distillation=True,
+        drop_rate=0.,
     ):
-        super(RepViT, self).__init__()
+        super(RepVit, self).__init__()
         self.grad_checkpointing = False
         self.global_pool = global_pool
         self.embed_dim = embed_dim
         self.num_classes = num_classes

         in_dim = embed_dim[0]
-        self.stem = RepViTStem(in_chans, in_dim, act_layer)
+        self.stem = RepVitStem(in_chans, in_dim, act_layer)
         stride = self.stem.stride
         resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])

@@ -263,7 +267,7 @@ def __init__(
         for i in range(num_stages):
             downsample = True if i != 0 else False
             stages.append(
-                RepViTStage(
+                RepVitStage(
                     in_dim,
                     embed_dim[i],
                     depth[i],
@@ -281,7 +285,8 @@ def __init__(
         self.stages = nn.Sequential(*stages)

         self.num_features = embed_dim[-1]
-        self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)
+        self.head_drop = nn.Dropout(drop_rate)
+        self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)

     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -304,9 +309,13 @@ def reset_classifier(self, num_classes, global_pool=None, distillation=False):
         if global_pool is not None:
             self.global_pool = global_pool
         self.head = (
-            RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
+            RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
         )

+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.head.distilled_training = enable
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -317,8 +326,9 @@ def forward_features(self, x):

     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool == 'avg':
-            x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
-        return x if pre_logits else self.head(x)
+            x = x.mean((2, 3), keepdim=False)
+        x = self.head_drop(x)
+        return self.head(x)

     def forward(self, x):
         x = self.forward_features(x)
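The mean((2, 3)) form is numerically equivalent to the adaptive_avg_pool2d(x, 1).flatten(1) it replaces, and traces/scripts more simply. A quick check (shapes illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 7, 7)
    a = F.adaptive_avg_pool2d(x, 1).flatten(1)
    b = x.mean((2, 3), keepdim=False)
    assert torch.allclose(a, b, atol=1e-6)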
@@ -373,7 +383,9 @@ def _cfg(url='', **kwargs):
 def _create_repvit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
-        RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs
+        RepVit, variant, pretrained,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs,
     )
     return model

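End-to-end, the renamed entry points are used as before; a sketch, assuming 'repvit_m1' is one of the variants registered later in this file:

    import torch
    import timm

    model = timm.create_model('repvit_m1', num_classes=10)
    model.set_distilled_training(True)  # new hook added above
    model.train()
    x1, x2 = model(torch.randn(1, 3, 224, 224))  # two heads while distilling

    model.eval()
    logits = model(torch.randn(1, 3, 224, 224))  # averaged single output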