Skip to content

Commit

Permalink
Split varspec param (#351)
Browse files Browse the repository at this point in the history
* Split varspec parameter into config

* Update documentation and mypy/ruff
  • Loading branch information
qubixes authored Dec 19, 2024
1 parent ab14341 commit bce1e04
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 17 deletions.
4 changes: 2 additions & 2 deletions docs/source/improve_synth.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use the configuration file is a more appropriate interface (see also our :doc:`c
MetaFrame.fit_dataframe(
df,
var_specs="your_config_file.toml"
config="your_config_file.toml"
)
This refers to a configuration file called ``your_config_file.toml``:
Expand Down Expand Up @@ -177,7 +177,7 @@ The most common use-case for this is to set the distribution type and/or paramet
.. code-block:: python
# In this example you put the specifications in the toml file.
MetaFrame.fit_dataframe(df, var_specs="your_config_file.toml")
MetaFrame.fit_dataframe(df, config="your_config_file.toml")
.. code-block:: toml
Expand Down
2 changes: 1 addition & 1 deletion metasyn/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def create_metadata() -> None:
data_frame = pl.read_csv(args.input, try_parse_dates=True, infer_schema_length=10000,
null_values=["", "na", "NA", "N/A", "Na"],
ignore_errors=True)
meta_frame = MetaFrame.fit_dataframe(data_frame, meta_config)
meta_frame = MetaFrame.fit_dataframe(data_frame, config=meta_config)
meta_frame.save(args.output)


Expand Down
12 changes: 11 additions & 1 deletion metasyn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
self.config_version = config_version

@staticmethod
def _parse_var_spec(var_spec):
def _parse_var_spec(var_spec) -> VarSpec:
if isinstance(var_spec, VarSpec):
return var_spec
return VarSpec.from_dict(var_spec)
Expand All @@ -72,6 +72,16 @@ def dist_providers(self, dist_providers):
else:
self._dist_providers = dist_providers

def update_varspecs(self, new_var_specs: Union[list[dict], list[VarSpec]]):
new_var_specs = [self._parse_var_spec(v) for v in new_var_specs]
for cur_new_var_spec in new_var_specs:
# Check if currently in varspecs and pop if it exists.
for i_var, old_var_spec in enumerate(self.var_specs):
if old_var_spec.name == cur_new_var_spec.name:
self.var_specs.pop(i_var)
break
self.var_specs.append(cur_new_var_spec)

@classmethod
def from_toml(cls, config_fp: Union[str, Path]) -> MetaConfig:
"""Create a MetaConfig class from a .toml file.
Expand Down
32 changes: 23 additions & 9 deletions metasyn/metaframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,12 @@ def n_columns(self) -> int:
def fit_dataframe( # noqa: PLR0912
cls,
df: Optional[pl.DataFrame],
var_specs: Optional[Union[list[VarSpec], pathlib.Path, str, MetaConfig]] = None,
var_specs: Optional[Union[list[VarSpec]]] = None,
dist_providers: Optional[list[str]] = None,
privacy: Optional[Union[BasePrivacy, dict]] = None,
n_rows: Optional[int] = None,
progress_bar: bool = True):
progress_bar: bool = True,
config: Optional[Union[pathlib.Path, str, MetaConfig]] = None):
"""Create a metasyn object from a polars (or pandas) dataframe.
The Polars dataframe should be formatted already with the correct
Expand Down Expand Up @@ -100,21 +101,34 @@ def fit_dataframe( # noqa: PLR0912
of rows in the input dataframe.
progress_bar:
Whether to create a progress bar.
config:
A path or MetaConfig object that contains information about the variable specifications
, defaults, etc. Variable specs in the config parameter will be overwritten by the
var_specs parameter.
Returns
-------
MetaFrame:
Initialized metasyn metaframe.
"""
if isinstance(var_specs, (str, pathlib.Path, MetaConfig)) and config is None:
warn("Supplying the configuration through var_specs is deprecated and will be removed"
f" in metasyn version 2.0. Use config={var_specs} instead.",
DeprecationWarning, stacklevel=2)
config = var_specs
var_specs = None
# Parse the var_specs into a MetaConfig instance.
if isinstance(var_specs, (pathlib.Path, str)):
meta_config = MetaConfig.from_toml(var_specs)
elif isinstance(var_specs, MetaConfig):
meta_config = var_specs
elif var_specs is None:
if config is None:
meta_config = MetaConfig([], dist_providers, defaults = {"privacy": privacy})
elif isinstance(config, (pathlib.Path, str)):
meta_config = MetaConfig.from_toml(config)
else:
meta_config = MetaConfig(var_specs, dist_providers, defaults = {"privacy": privacy})
meta_config = config

# var_specs overrules variable specifications in the configuration (file).
if var_specs is not None:
meta_config.update_varspecs(var_specs)

if dist_providers is not None:
meta_config.dist_providers = dist_providers # type: ignore
if privacy is not None:
Expand Down Expand Up @@ -175,7 +189,7 @@ def from_config(cls, meta_config: MetaConfig) -> MetaFrame:
-------
A created MetaFrame.
"""
return cls.fit_dataframe(None, meta_config)
return cls.fit_dataframe(None, config=meta_config)

def to_dict(self) -> Dict[str, Any]:
"""Create dictionary with the properties for recreation."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def test_create_meta(tmp_dir, config):
]
if config:
cmd.extend(["--config", Path(tmp_dir) / 'config.ini'])
result = subprocess.run(cmd, check=False, capture_output=True)
assert result.returncode == 0
result = subprocess.run(cmd, check=True, capture_output=True)
assert result.returncode == 0, result.stdout
assert out_file.is_file()
meta_frame = MetaFrame.load_json(out_file)
assert len(meta_frame.meta_vars) == 12
Expand Down
24 changes: 22 additions & 2 deletions tests/test_toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_datafree_create(tmpdir):
temp_toml = tmpdir / "test.toml"
create_input_toml(temp_toml)
assert cmp(temp_toml, Path("examples", "config_files", "example_all.toml"))
mf = MetaFrame.fit_dataframe(None, var_specs=Path(temp_toml))
mf = MetaFrame.fit_dataframe(None, config=Path(temp_toml))

assert isinstance(mf, MetaFrame)
assert mf.n_columns == len(BuiltinDistributionProvider.distributions)
Expand All @@ -35,11 +35,31 @@ def test_datafree_create(tmpdir):
)
def test_toml_save_load(tmpdir, toml_input, data):
"""Test whether TOML GMF files can be saved/loaded."""
mf = MetaFrame.fit_dataframe(data, toml_input)
mf = MetaFrame.fit_dataframe(data, config=toml_input)
mf.save(tmpdir/"test.toml")
new_mf = MetaFrame.load(tmpdir/"test.toml")
assert mf.n_columns == new_mf.n_columns

def test_varspec_update():
"""Check whether overwriting the varspecs with the var_specs parameter works."""
toml_input = Path("examples", "config_files", "example_all.toml")
var_specs = [{
"name": "DiscreteTruncatedNormalDistribution",
"var_type": "discrete",
"distribution": {
"implements": "core.normal",
"unique": False,
"parameters": {
"mean": 0,
"sd": 1,
}
}
}]
mf_normal = MetaFrame.fit_dataframe(None, config=toml_input)
mf_varspec = MetaFrame.fit_dataframe(None, var_specs=var_specs, config=toml_input)
assert mf_normal["DiscreteTruncatedNormalDistribution"].distribution.implements == "core.truncated_normal"
assert mf_varspec["DiscreteTruncatedNormalDistribution"].distribution.implements == "core.normal"

@mark.parametrize(
"gmf_file", [
Path("examples", "gmf_files", "example_gmf_simple.json"),
Expand Down

0 comments on commit bce1e04

Please sign in to comment.