File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change 1212
1313class Mlp (nn .Module ):
1414 """ MLP as used in Vision Transformer, MLP-Mixer and related networks
15+
16+ NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
1517 """
1618 def __init__ (
1719 self ,
@@ -51,6 +53,8 @@ def forward(self, x):
5153class GluMlp (nn .Module ):
5254 """ MLP w/ GLU style gating
5355 See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
56+
57+ NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
5458 """
5559 def __init__ (
5660 self ,
@@ -192,7 +196,7 @@ def forward(self, x):
192196
193197
194198class ConvMlp (nn .Module ):
195- """ MLP using 1x1 convs that keeps spatial dims
199+ """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors)
196200 """
197201 def __init__ (
198202 self ,
@@ -226,6 +230,8 @@ def forward(self, x):
226230
227231class GlobalResponseNormMlp (nn .Module ):
228232 """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
233+
234+ NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
229235 """
230236 def __init__ (
231237 self ,
You can’t perform that action at this time.
0 commit comments