11from functools import partial
2- from typing import Callable , List , Optional , Sequence , Tuple , Union
2+ from typing import Callable , Dict , List , Optional , Sequence , Tuple , Union
33
44import torch
55import torch .nn as nn
@@ -129,12 +129,12 @@ def __init__(
129129 num_classes : int = 1000 ,
130130 in_chans : int = 3 ,
131131 stem_size : int = 16 ,
132- stem_bias : bool = False ,
132+ stem_bias : bool = True ,
133133 fix_stem : bool = False ,
134134 num_features : int = 2048 ,
135135 pad_type : str = '' ,
136136 use_msfa : bool = True ,
137- msfa_indices : List [int ] = (- 3 , - 2 , - 1 ),
137+ msfa_indices : List [int ] = (- 2 , - 1 ),
138138 msfa_output_resolution : int = 16 ,
139139 act_layer : Optional [LayerType ] = None ,
140140 norm_layer : Optional [LayerType ] = None ,
@@ -574,6 +574,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
574574 return self .forward_features (x )
575575
576576
def checkpoint_filter_fn(
        state_dict: Dict[str, torch.Tensor],
        model,
) -> Dict[str, torch.Tensor]:
    """ Convert weights from gemma encoders.

    Unwraps a checkpoint nested under a ``'model'`` and/or ``'state_dict'`` key,
    then, if the weights use the Gemma vision tower layout, keeps only the
    vision tower entries and strips their key prefix.

    Args:
        state_dict: Checkpoint mapping, possibly wrapped under ``'model'`` /
            ``'state_dict'``.
        model: Not used by this function; accepted so the signature matches
            how the filter is invoked.

    Returns:
        The unwrapped mapping. When the Gemma layout is detected, a new dict
        containing only the keys that start with the vision tower prefix, with
        that prefix removed. Otherwise the unwrapped mapping is returned as-is.
    """
    prefix = 'model.vision_tower.timm_model.'

    # Unwrap common checkpoint containers; fall through when the key is absent.
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)

    # The stem conv weight under the vision tower prefix marks a Gemma checkpoint.
    if prefix + 'conv_stem.conv.weight' in state_dict:
        # Anchor the match to the start of the key and strip only that leading
        # prefix, so a key that merely contains the prefix somewhere in the
        # middle is neither selected nor renamed to a mangled name.
        prefix_len = len(prefix)
        state_dict = {
            k[prefix_len:]: v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }
    return state_dict
588+
589+
577590def _create_mnv5_encoder (variant : str , pretrained : bool = False , ** kwargs ) -> MobileNetV5Encoder :
578591 out_indices = kwargs .pop ('out_indices' , (0 , 1 , 2 , 3 , 4 ))
579592 feature_cfg = dict (out_indices = out_indices , feature_cls = 'getter' )
@@ -590,21 +603,22 @@ def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> Mo
590603 variant ,
591604 pretrained ,
592605 pretrained_strict = False ,
606+ pretrained_filter_fn = checkpoint_filter_fn ,
593607 feature_cfg = feature_cfg ,
594608 kwargs_filter = kwargs_filter ,
595609 ** kwargs ,
596610 )
597611 return model
598612
599613
600- def _create_mnv5 (variant : str , pretrained : bool = False , ** kwargs ) -> MobileNetV5Encoder :
614+ def _create_mnv5 (variant : str , pretrained : bool = False , ** kwargs ) -> MobileNetV5 :
601615 out_indices = kwargs .pop ('out_indices' , (0 , 1 , 2 , 3 , 4 ))
602616 feature_cfg = dict (out_indices = out_indices , feature_cls = 'getter' )
603617 model = build_model_with_cfg (
604618 MobileNetV5 ,
605619 variant ,
606620 pretrained ,
607- pretrained_strict = False ,
621+ pretrained_filter_fn = checkpoint_filter_fn ,
608622 feature_cfg = feature_cfg ,
609623 ** kwargs ,
610624 )
@@ -809,8 +823,8 @@ def _cfg(url: str = '', **kwargs):
809823 num_classes = 0 ),
810824
811825 # WIP classification configs for testing
812- 'mobilenetv5_300m' : _cfg (
813- # hf_hub_id='timm/',
826+ 'mobilenetv5_300m.gemma3n ' : _cfg (
827+ hf_hub_id = 'timm/' ,
814828 mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ),
815829 input_size = (3 , 768 , 768 ),
816830 num_classes = 0 ),
0 commit comments