@@ -157,11 +157,7 @@ def __init__(
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic
 
         # make compile optional (for bwd compat)
         if has_dynamo:
@@ -178,7 +174,6 @@ def __init__(
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict
 
     def state_dict(self) -> Dict[str, Any]:
@@ -187,28 +182,21 @@ def state_dict(self) -> Dict[str, Any]:
 
         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
-
+
         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back
 
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])
 
     def __setstate__(self, state):
         super().__setstate__(state)
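
A side note on the state-dict simplification above: a random.Random state is a plain picklable tuple, while a torch.Generator state is a device-bound byte tensor, which is what made saving and restoring torch_rng across resume awkward. A quick standalone illustration (assumes only the stdlib and torch, not part of the commit):

import random
import torch

py_state = random.Random(1337).getstate()  # nested tuple of ints, trivially picklable
pt_state = torch.Generator().get_state()   # torch.ByteTensor; moving it back onto a device is the messy part
print(type(py_state), type(pt_state))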
@@ -317,15 +305,11 @@ def step(self, closure=None):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    if self.deterministic:
+                        torch_rng = torch.Generator(device=debiased_momentum.device).manual_seed(self.rng.randint(0, 2 ** 31))
                     else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
-                        V = torch.randn(
-                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                        V = V.to(debiased_momentum.device)
+                        torch_rng = None
+                    V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
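
Taken together, the commit replaces a persistent CPU torch.Generator with a per-step, device-local Generator seeded from the checkpointed Python RNG, so the noise is sampled on-device (no CPU round trip) while staying reproducible across save/load. A minimal sketch of the pattern (standalone, not part of the commit; sample_noise and m are hypothetical names standing in for the step() internals):

import random
import torch

def sample_noise(rng: random.Random, m: torch.Tensor) -> torch.Tensor:
    # Fresh Generator on the tensor's device, seeded from the checkpointed Python RNG.
    g = torch.Generator(device=m.device).manual_seed(rng.randint(0, 2 ** 31))
    return torch.randn(m.shape, generator=g, dtype=m.dtype, device=m.device)

rng = random.Random(1337)
m = torch.zeros(4, 4)        # stand-in for debiased_momentum

saved = rng.getstate()       # what state_dict() stores as 'rng_state'
v1 = sample_noise(rng, m)

rng.setstate(saved)          # what load_state_dict() restores on resume
v2 = sample_noise(rng, m)

assert torch.equal(v1, v2)   # identical noise after a save/load round trip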