Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Abstract Variables/Symbols #1308

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
implement abstract classes
alanlujan91 committed Jul 25, 2023
commit c523d4b50a819be3cb66e6298c856eeb508d976f
52 changes: 52 additions & 0 deletions HARK/abstract/tests/consindshk.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
---
states: !StateSpace
variables:
- !State
name: m
short_name: money
long_name: market resources
latex_repr: \mNrm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know how you're imagining using the Latex representation.
But my preference would be to not include it in the PR unless you have some demonstration of how it works ready.
4 different ways to name something for a quick demo seems like a lot....

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now these are filler, I want to throw any non-required keys into an attributes dictionary.

- !State
name: &name stigma
short_name: &short_name risky share
long_name: &long_name risky share of portfolio
latex_repr: &latex_repr \stigma

actions: !ActionSpace
variables:
- !Action
name: c
short_name: consumption
long_name: consumption
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these fields optional? Can one spill over to the others as a default?
Basically, how can we make these config files lighter weight?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are optional, only required is name.

latex_repr: \cNrm
- !Action
name: *name
short_name: *short_name
long_name: *long_name
latex_repr: *latex_repr

post_states: !PostStateSpace
variables:
- !PostState
name: a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you've repeated the post_states block twice in this file?

Not sure how @mnwhite feels, but maybe we don't need to draw a firm distinction between states and post states like this.

Or the labels could be inside the variable, not part of the document structure.

Compare:

var_type_1:
   variables:
       - !VarTypeClass1
           details

var_type_2:
   variables:
       - !VarTypeClass2
           details

with

variables:
    - !VarTypeClass1
       details
    - !VarTypeClass2   
       details

The latter is same information, but fewer lines.

short_name: assets
long_name: savings
latex_repr: \aNrm
- !PostState
name: *name
short_name: *short_name
long_name: *long_name
latex_repr: *latex_repr

parameters: !Parameters
variables:
- !Parameter
name: DiscFac
short_name: discount factor
long_name: discount factor
latex_repr: \beta
- !Parameter
name: CRRA
short_name: risk aversion
long_name: coefficient of relative risk aversion
latex_repr: \rho
265 changes: 265 additions & 0 deletions HARK/abstract/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
from dataclasses import dataclass, field
from typing import Any, Mapping, Optional, Union
from warnings import warn

import numpy as np
import xarray as xr
from yaml import SafeLoader, YAMLObject

rng = np.random.default_rng()


@dataclass
class Variable(YAMLObject):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to initialize Variables without parsing them from a YAML file?
I.e., a pure python way to create variables?

I'm a little wary of tying the model objects too tightly to the serial format because it can make it tricky to interoperate with other python libraries.

"""
Abstract class for representing variables. Variables are the building blocks
of models. They can be parameters, states, actions, shocks, or auxiliaries.
"""

name: str # The name of the variable, required
attrs: dict = field(default_factory=dict, kw_only=True)
short_name: str = field(default=None, kw_only=True)
long_name: str = field(default=None, kw_only=True)
latex_repr: str = field(default=None, kw_only=True)
yaml_tag: str = field(default="!Variable", kw_only=False)
yaml_loader = SafeLoader

def __post_init__(self):
for key in ["long_name", "short_name", "latex_repr"]:
self.attrs.setdefault(key, None)

@classmethod
def from_yaml(cls, loader, node):
fields = loader.construct_mapping(node, deep=True)
return cls(**fields)

def __repr__(self):
"""
String representation of the variable.

Returns:
str: The string representation of the variable.
"""
return f"{self.__class__.__name__}({self.name})"


@dataclass
class VariableSpace(YAMLObject):
"""
Abstract class for representing a collection of variables.
"""

variables: list[Variable]
yaml_tag: str = field(default="!VariableSpace", kw_only=True)
yaml_loader = SafeLoader

def __post_init__(self):
"""
Save the variables in a dictionary for easy access.
"""
self.variables = {var.name: var for var in self.variables}


@dataclass
class Parameter(Variable):
"""
A `Parameter` is a variable that has a fixed value.
"""

value: Union[int, float] = 0
yaml_tag: str = field(default="!Parameter", kw_only=True)

def __repr__(self):
"""
String representation of the parameter.

Returns:
str: The string representation of the parameter.
"""
return f"{self.__class__.__name__}({self.name}, {self.value})"


@dataclass
class Parameters(VariableSpace):
"""
A `Parameters` is a collection of parameters.
"""

yaml_tag: str = "!Parameters"


@dataclass
class Auxiliary(Variable):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Chris that 'auxiliary' is more like a macro.
I wouldn't use 'auxiliary' here, though it makes sense to have this sort of object.

"""
Class for representing auxiliaries. Auxiliaries are abstract variables that
have an array structure but are not states, actions, or shocks. They may
include information like domain, measure (discrete or continuous), etc.
"""

array: Union[list, np.ndarray, xr.DataArray] = None
domain: Union[list, tuple] = field(default=None, kw_only=True)
is_discrete: bool = field(default=False, kw_only=True)
yaml_tag: str = field(default="!Auxiliary", kw_only=True)


@dataclass
class AuxiliarySpace(VariableSpace):
"""
A `AuxiliarySpace` is a collection of auxiliary variables.
"""

yaml_tag: str = "!AuxiliarySpace"


@dataclass(kw_only=True)
class State(Auxiliary):
"""
Class for representing a state variable.
"""

yaml_tag: str = "!State"

def assign_values(self, values):
return make_state_array(values, self.name, self.attrs)

def discretize(self, min, max, N, method):
# linear for now
self.assign_values(np.linspace(min, max, N))


@dataclass(kw_only=True)
class StateSpace(AuxiliarySpace):
states: Mapping[str, State] = field(init=False)
yaml_tag: str = "!StateSpace"

def __post_init__(self):
super().__post_init__()
self.states = xr.merge([self.variables])


@dataclass(kw_only=True)
class PostState(State):
yaml_tag: str = "!PostState"


@dataclass(kw_only=True)
class PostStateSpace(StateSpace):
post_states: Mapping[str, State] = field(init=False)
yaml_tag: str = "!PostStateSpace"

def __post_init__(self):
self.post_states = xr.merge(self.variables)


@dataclass(kw_only=True)
class Action(Auxiliary):
"""
Class for representing actions. Actions are variables that are chosen by the agent.
Can also be called a choice, control, decision, or a policy.

Args:
Variable (_type_): _description_
"""

is_optimal: bool = True
yaml_tag: str = "!Action"

def discretize(self, *args, **kwargs):
warn("Actions cannot be discretized.")


@dataclass(kw_only=True)
class ActionSpace(AuxiliarySpace):
actions: Mapping[str, State] = field(init=False)
yaml_tag: str = "!ActionSpace"

def __post_init__(self):
self.actions = xr.merge(self.variables)


@dataclass(kw_only=True)
class Shock(Variable):
"""
Class for representing shocks. Shocks are variables that are not
chosen by the agent.
Can also be called a random variable, or a state variable.

Args:
Variable (_type_): _description_
"""

yaml_tag: str = "!Shock"


@dataclass(kw_only=True)
class ShockSpace(VariableSpace):
shocks: list[Shock]
yaml_tag: str = "!ShockSpace"

def __post_init__(self):
self.shocks = xr.merge(self.shocks)


def make_state_array(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this method duplicated?

values: np.ndarray,
name: Optional[str] = None,
attrs: Optional[dict] = None,
) -> xr.Dataset:
"""
Function to create a state with given values, name and attrs.

Parameters:
values (np.ndarray): The values for the state.
name (str, optional): The name of the state. Defaults to 'state'.
attrs (dict, optional): The attrs for the state. Defaults to None.

Returns:
State: An xarray DataArray representing the state.
"""
# Use a default name only when no name is provided
name = name or f"state{rng.integers(0, 100)}"
attrs = attrs or {}

return xr.Dataset(
{
name: xr.DataArray(
values,
name=name,
dims=(name,),
attrs=attrs,
)
}
)


def make_states_array(
values: Union[np.ndarray, list],
names: Optional[list[str]] = None,
attrs: Optional[list[dict]] = None,
) -> xr.Dataset:
"""
Function to create states with given values, names and attrs.

Parameters:
values (Union[np.ndarray, States]): The values for the states.
names (list[str], optional): The names of the states. Defaults to None.
attrs (list[dict], optional): The attrs for the states. Defaults to None.

Returns:
States: An xarray Dataset representing the states.
"""
if isinstance(values, list):
values_len = len(values)
elif isinstance(values, np.ndarray):
values_len = values.shape[0]

# Use default names and attrs only when they are not provided
names = names or [f"state{rng.integers(0, 100)}" for _ in range(values_len)]
attrs = attrs or [{}] * values_len

states = [
make_state_array(value, name, attr)
for value, name, attr in zip(values, names, attrs)
]

return xr.merge(states)
Loading