File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -429,7 +429,19 @@ def build_model_with_cfg(
429429 else :
430430 model = model_cls (cfg = model_cfg , ** kwargs )
431431 if pretrained :
432+ # .to_empty() also moves params/buffers that are already on the CPU to uninitialized storage.
433+ # This is problematic for non-persistent buffers: they are not part of the state_dict,
434+ # so they do not get reloaded from the pretrained weights later. Hence, we have
435+ # to save them before calling .to_empty() and fill them back in afterwards.
436+ buffers = {k : v for k , v in model .named_buffers () if not v .is_meta }
432437 model .to_empty (device = "cpu" )
438+ for k , v in model .named_buffers ():
439+ if k in buffers :
440+ v .data = buffers [k ]
441+
442+ # Alternative: rely on the internal method ._apply() instead.
443+ # model._apply(lambda t: torch.empty_like(t, device="cpu") if t.is_meta else t)
444+
433445 model .pretrained_cfg = pretrained_cfg
434446 model .default_cfg = model .pretrained_cfg # alias for backwards compat
435447
You can’t perform that action at this time.
0 commit comments