Skip to content

Commit

Permalink
Allow the synthesized output to be seeded
Browse files: browse the repository at this point in the history
  • Loading branch information
qubixes committed Dec 19, 2024
1 parent 0eadc64 commit 874150e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/basic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
mf_out = MetaFrame.load_json(gmf_path)

# create a fake dataset
df_syn = mf_out.synthesize(10)
df_syn = mf_out.synthesize(10, seed=1234)
4 changes: 2 additions & 2 deletions examples/gmf_files/example_gmf_simple.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"provenance": {
"created by": {
"name": "metasyn",
"version": "1.0.2.dev34+gd68929e"
"version": "1.1.0"
},
"creation time": "2024-10-01T09:57:15.595769"
"creation time": "2024-12-18T14:54:05.300334"
},
"vars": [
{
Expand Down
4 changes: 2 additions & 2 deletions metasyn/metaframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def load_toml(cls, fp: Union[pathlib.Path, str],
meta_vars = [MetaVar.from_dict(d) for d in self_dict["vars"]]
return cls(meta_vars, n_rows)

def synthesize(self, n: Optional[int] = None) -> pl.DataFrame:
def synthesize(self, n: Optional[int] = None, seed: Optional[int] = None) -> pl.DataFrame:
"""Create a synthetic Polars dataframe.
Parameters
Expand All @@ -467,7 +467,7 @@ def synthesize(self, n: Optional[int] = None) -> pl.DataFrame:
raise ValueError("Cannot synthesize DataFrame, since number of rows is unknown."
"Please specify the number of rows to synthesize.")
n = self.n_rows
synth_dict = {var.name: var.draw_series(n) for var in self.meta_vars}
synth_dict = {var.name: var.draw_series(n, seed=seed) for var in self.meta_vars}
return pl.DataFrame(synth_dict)

def __repr__(self) -> str:
Expand Down
4 changes: 3 additions & 1 deletion metasyn/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def draw(self) -> Any:
return None
return self.distribution.draw()

def draw_series(self, n: int) -> pl.Series:
def draw_series(self, n: int, seed: Optional[int]) -> pl.Series:
"""Draw a new synthetic series from the metadata.
Parameters
Expand All @@ -245,6 +245,8 @@ def draw_series(self, n: int) -> pl.Series:
polars.Series:
Polars series with the synthetic data.
"""
if seed is not None:
np.random.seed(seed)
self.distribution.draw_reset()
value_list = [self.draw() for _ in range(n)]
pl_type = self.dtype.split("(")[0]
Expand Down

0 comments on commit 874150e

Please sign in to comment.