@@ -157,11 +157,7 @@ def __init__(
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic

         # make compile optional (for bwd compat)
         if has_dynamo:
@@ -178,7 +174,6 @@ def __init__(
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict

     def state_dict(self) -> Dict[str, Any]:
@@ -187,28 +182,21 @@ def state_dict(self) -> Dict[str, Any]:

         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
-
+
         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back

         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])

     def __setstate__(self, state):
         super().__setstate__(state)
@@ -317,15 +305,17 @@ def step(self, closure=None):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    if self.deterministic:
+                        torch_rng = torch.Generator(device=debiased_momentum.device)
+                        torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                     else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
-                        V = torch.randn(
-                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                        V = V.to(debiased_momentum.device)
+                        torch_rng = None
+                    V = torch.randn(
+                        debiased_momentum.shape,
+                        generator=torch_rng,
+                        dtype=precond_dtype,
+                        device=debiased_momentum.device,
+                    )
                     G = debiased_momentum if momentum_into_precond_update else grad

                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
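For reference, here is a minimal standalone sketch of the pattern this diff switches to: the only persistent RNG state is a pickleable `random.Random`, and a fresh `torch.Generator` is seeded from it on the target device at each step, so no device generator state has to survive a save/load cycle. The `sample_noise` helper below is illustrative only and not part of the optimizer.

```python
# Minimal sketch of the per-step seeding pattern (standalone, not the optimizer).
# A Python random.Random is the only persistent RNG state; it pickles cleanly,
# while a device-local torch.Generator is recreated and reseeded each call.
import random

import torch


def sample_noise(rng: random.Random, shape, dtype, device, deterministic: bool = True):
    if deterministic:
        gen = torch.Generator(device=device)
        gen.manual_seed(rng.randint(0, 2 ** 31))  # seed drawn from the Python RNG stream
    else:
        gen = None  # fall back to torch's global RNG
    return torch.randn(shape, generator=gen, dtype=dtype, device=device)


rng = random.Random(1337)
saved = rng.getstate()                       # what state_dict() stores as 'rng_state'
a = sample_noise(rng, (4,), torch.float32, "cpu")

rng.setstate(saved)                          # what load_state_dict() restores
b = sample_noise(rng, (4,), torch.float32, "cpu")
assert torch.equal(a, b)                     # same noise draw after a "resume"
```

Because the Python RNG alone drives the seeds, serializing `rng.getstate()` in `state_dict()` and restoring it in `load_state_dict()` is sufficient for resumed runs to reproduce the same noise draws, without the CPU-generator round trip the removed code needed.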