@@ -212,7 +212,7 @@ class TimeSeasonality(Component):
212212 sigma_level_trend = pm.HalfNormal(
213213 "sigma_level_trend", sigma=1e-6, dims=ss_mod.param_dims["sigma_level_trend"]
214214 )
215- coefs_annual = pm.Normal("coefs_annual ", sigma=1e-2, dims=ss_mod.param_dims["coefs_annual "])
215+ params_annual = pm.Normal("params_annual ", sigma=1e-2, dims=ss_mod.param_dims["params_annual "])
216216
217217 ss_mod.build_statespace_graph(data)
218218 idata = pm.sample(
@@ -298,10 +298,10 @@ def populate_component_properties(self):
298298 for endog_name in self .observed_state_names
299299 for state_name in self .provided_state_names
300300 ]
301- self .param_names = [f"coefs_ { self .name } " ]
301+ self .param_names = [f"params_ { self .name } " ]
302302
303303 self .param_info = {
304- f"coefs_ { self .name } " : {
304+ f"params_ { self .name } " : {
305305 "shape" : (k_states ,) if k_endog == 1 else (k_endog , k_states ),
306306 "constraints" : None ,
307307 "dims" : (f"state_{ self .name } " ,)
@@ -311,7 +311,7 @@ def populate_component_properties(self):
311311 }
312312
313313 self .param_dims = {
314- f"coefs_ { self .name } " : (f"state_{ self .name } " ,)
314+ f"params_ { self .name } " : (f"state_{ self .name } " ,)
315315 if k_endog == 1
316316 else (f"endog_{ self .name } " , f"state_{ self .name } " )
317317 }
@@ -327,12 +327,14 @@ def populate_component_properties(self):
327327
328328 if self .innovations :
329329 self .param_names += [f"sigma_{ self .name } " ]
330+ self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
330331 self .param_info [f"sigma_{ self .name } " ] = {
331- "shape" : (),
332+ "shape" : () if k_endog == 1 else ( k_endog ,) ,
332333 "constraints" : "Positive" ,
333- "dims" : None ,
334+ "dims" : None if k_endog == 1 else ( f"endog_ { self . name } " ,) ,
334335 }
335- self .shock_names = [f"{ self .name } [{ name } ]" for name in self .observed_state_names ]
336+ if k_endog > 1 :
337+ self .param_dims [f"sigma_{ self .name } " ] = (f"endog_{ self .name } " ,)
336338
337339 def make_symbolic_graph (self ) -> None :
338340 k_states = self .k_states // self .k_endog
@@ -377,7 +379,7 @@ def make_symbolic_graph(self) -> None:
377379 self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
378380
379381 initial_states = self .make_and_register_variable (
380- f"coefs_ { self .name } " ,
382+ f"params_ { self .name } " ,
381383 shape = (k_unique_states ,) if k_endog == 1 else (k_endog , k_unique_states ),
382384 )
383385 if k_endog == 1 :
@@ -506,7 +508,7 @@ def make_symbolic_graph(self) -> None:
506508 self .ssm ["design" , :, :] = pt .linalg .block_diag (* [Z for _ in range (k_endog )])
507509
508510 init_state = self .make_and_register_variable (
509- f"{ self .name } " , shape = (n_coefs ,) if k_endog == 1 else (k_endog , n_coefs )
511+ f"params_ { self .name } " , shape = (n_coefs ,) if k_endog == 1 else (k_endog , n_coefs )
510512 )
511513
512514 init_state_idx = np .concatenate (
@@ -535,19 +537,30 @@ def make_symbolic_graph(self) -> None:
535537 def populate_component_properties (self ):
536538 k_endog = self .k_endog
537539 n_coefs = self .n_coefs
538- k_states = self .k_states // k_endog
539540
540541 self .state_names = [
541- f"{ f } _{ self . name } _{ i } [{ obs_state_name } ]"
542+ f"{ f } _{ i } _{ self . name } [{ obs_state_name } ]"
542543 for obs_state_name in self .observed_state_names
543544 for i in range (self .n )
544545 for f in ["Cos" , "Sin" ]
545546 ]
546- self .param_names = [f"{ self .name } " ]
547+ # determine which state names correspond to parameters
548+ # all endog variables use same state structure, so we just need
549+ # the first n_coefs state names (which may be less than total if saturated)
550+ param_state_names = [f"{ f } _{ i } _{ self .name } " for i in range (self .n ) for f in ["Cos" , "Sin" ]][
551+ :n_coefs
552+ ]
553+
554+ self .param_names = [f"params_{ self .name } " ]
555+
556+ self .param_dims = {
557+ f"params_{ self .name } " : (f"state_{ self .name } " ,)
558+ if k_endog == 1
559+ else (f"endog_{ self .name } " , f"state_{ self .name } " )
560+ }
547561
548- self .param_dims = {self .name : (f"state_{ self .name } " ,)}
549562 self .param_info = {
550- f"{ self .name } " : {
563+ f"params_ { self .name } " : {
551564 "shape" : (n_coefs ,) if k_endog == 1 else (k_endog , n_coefs ),
552565 "constraints" : None ,
553566 "dims" : (f"state_{ self .name } " ,)
@@ -556,23 +569,22 @@ def populate_component_properties(self):
556569 }
557570 }
558571
559- # Regardless of whether the fourier basis are saturated, there will always be one symbolic state per basis.
560- # That's why the self.states is just a simple loop over everything. But when saturated, one of those states
561- # doesn't have an associated **parameter**, so the coords need to be adjusted to reflect this.
562- init_state_idx = np .concatenate (
563- [
564- np .arange (k_states * i , (i + 1 ) * k_states , dtype = int )[:n_coefs ]
565- for i in range (k_endog )
566- ],
567- axis = 0 ,
572+ self .coords = (
573+ {f"state_{ self .name } " : param_state_names }
574+ if k_endog == 1
575+ else {
576+ f"endog_{ self .name } " : self .observed_state_names ,
577+ f"state_{ self .name } " : param_state_names ,
578+ }
568579 )
569- self .coords = {f"state_{ self .name } " : [self .state_names [i ] for i in init_state_idx ]}
570580
571581 if self .innovations :
572- self .shock_names = self .state_names .copy ()
573582 self .param_names += [f"sigma_{ self .name } " ]
583+ self .shock_names = self .state_names .copy ()
574584 self .param_info [f"sigma_{ self .name } " ] = {
575- "shape" : () if k_endog == 1 else (k_endog , n_coefs ),
585+ "shape" : () if k_endog == 1 else (k_endog ,),
576586 "constraints" : "Positive" ,
577587 "dims" : None if k_endog == 1 else (f"endog_{ self .name } " ,),
578588 }
589+ if k_endog > 1 :
590+ self .param_dims [f"sigma_{ self .name } " ] = (f"endog_{ self .name } " ,)
0 commit comments