-
Notifications
You must be signed in to change notification settings - Fork 72
Represent statespace metadata with dataclasses #607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Represent statespace metadata with dataclasses #607
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great first pass, much cleaner than what we have now.
|
|
||
|
|
||
| @dataclass | ||
| class ParameterProperty: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| class ParameterProperty: | |
| class Parameter(StatespaceProperty): |
Let's make the names really dumb and simple
|
|
||
|
|
||
| @dataclass | ||
| class ParameterProperties: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| class ParameterProperties: | |
| class ParameterInfo(Info[StatespaceParameter]): |
Same here. I'm also not against just calling it StatespaceParameters (plural), but that's a bit less obvious.
|
|
||
| from pymc_extras.statespace.models.structural.core import Component | ||
| from pymc_extras.statespace.utils.constants import TIME_DIM | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
High level comment: There's a lot of code duplication here because of all the shared functionality between the different types of parameters. Use a superclass:
@dataclass(frozen=True)
class Property:
def __str__(self) -> str:
return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))
T = TypeVar("T", bound=Property)
@dataclass(frozen=True)
class Info(Generic[T]):
items: list[T]
key_field: str = "name"
def _key(self, item: T) -> str:
return getattr(item, self.key_field)
def get(self, key: str) -> T | None:
return next((i for i in self.items if self._key(i) == key), None)
def __getitem__(self, key: str) -> T:
result = self.get(key)
if result is None:
raise KeyError(f"No {self.key_field} '{key}'")
return result
def __contains__(self, key: str) -> bool:
return any(self._key(i) == key for i in self.items)
def __str__(self) -> str:
return f"{self.key_field}s: {[self._key(i) for i in self.items]}"| needs_exogenous_data: bool = field(default=False, init=False) | ||
|
|
||
| def __post_init__(self): | ||
| for d in self.data: | ||
| if d.is_exogenous: | ||
| self.needs_exogenous_data = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| needs_exogenous_data: bool = field(default=False, init=False) | |
| def __post_init__(self): | |
| for d in self.data: | |
| if d.is_exogenous: | |
| self.needs_exogenous_data = True | |
| @property | |
| def needs_exogenous_data(self) -> bool: | |
| return any(d.is_exogenous for d in self.items) |
Just compute this on demand with a property instead of going through post init hoops. In this case you can also set frozen=True.
| @dataclass | ||
| class DataProperty: | ||
| name: str | ||
| shape: tuple[int, ...] | ||
| dims: tuple[str, ...] | ||
| is_exogenous: bool | ||
|
|
||
| def __str__(self): | ||
| base = f"name: {self.name}\nshape: {self.shape}\ndims: {self.dims}\nis_exogenous: {self.is_exogenous}" | ||
| return base |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As an example, after you have the base classes, this just becomes:
@dataclass(frozen=True)
class Data(Property):
name: str
shape: tuple[int, ...]
dims: tuple[str, ...]
is_exogenous: bool|
|
||
| self.coords = CoordProperties(coords=[regression_state_prop, endogenous_state_prop]) | ||
|
|
||
| def populate_component_properties(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be in the Component superclass, and not touched at all by the subclasses.
You will also have to add each _set_foo() method to the superclass, along with some kind of default behavior.
|
We can also keep all of the existing properties like |
|
Reflecting on it, I am convinced this is the way to go. It's 1000x more ergonomic. I made some changes to your initial code to make the API more "dictionary like", and to reduce code duplication. I moved everything to |
|
@jessegrabowski, this is looking really cool! What can I do to help push this forward? |
|
Delete the new We should keep your notebook with the plan to add it as a new example for the docs. Or it can be merged into the custom statespace notebook. So that should also be updated to import from the new |
|
Perfect! I'll work on that today!! It is really looking cool! |
This is a draft proposal for #598
The idea is to handle each component separately using
_set_{component}methods and all information are stored using data classes for easy mapping.I believe this will simplify our tests of these components and will reduce redundancies where we have the same information spread across multiple sub-components like
data_namesanddata_info.@jessegrabowski let me know what you think I put a little notebook together to showcase the changes.