diff --git a/examples/basic_example.py b/examples/basic_example.py index aa781bd5..f1b9bc6b 100644 --- a/examples/basic_example.py +++ b/examples/basic_example.py @@ -25,4 +25,4 @@ mf_out = MetaFrame.load_json(gmf_path) # create a fake dataset -df_syn = mf_out.synthesize(10) +df_syn = mf_out.synthesize(10, seed=1234) diff --git a/examples/gmf_files/example_gmf_simple.json b/examples/gmf_files/example_gmf_simple.json index b7cf5001..66dfa9ce 100644 --- a/examples/gmf_files/example_gmf_simple.json +++ b/examples/gmf_files/example_gmf_simple.json @@ -4,9 +4,9 @@ "provenance": { "created by": { "name": "metasyn", - "version": "1.0.2.dev34+gd68929e" + "version": "1.1.0" }, - "creation time": "2024-10-01T09:57:15.595769" + "creation time": "2024-12-18T14:54:05.300334" }, "vars": [ { diff --git a/metasyn/metaframe.py b/metasyn/metaframe.py index 14712009..3cfe9890 100644 --- a/metasyn/metaframe.py +++ b/metasyn/metaframe.py @@ -449,7 +449,7 @@ def load_toml(cls, fp: Union[pathlib.Path, str], meta_vars = [MetaVar.from_dict(d) for d in self_dict["vars"]] return cls(meta_vars, n_rows) - def synthesize(self, n: Optional[int] = None) -> pl.DataFrame: + def synthesize(self, n: Optional[int] = None, seed: Optional[int] = None) -> pl.DataFrame: """Create a synthetic Polars dataframe. Parameters @@ -467,7 +467,7 @@ def synthesize(self, n: Optional[int] = None) -> pl.DataFrame: raise ValueError("Cannot synthesize DataFrame, since number of rows is unknown." "Please specify the number of rows to synthesize.") n = self.n_rows - synth_dict = {var.name: var.draw_series(n) for var in self.meta_vars} + synth_dict = {var.name: var.draw_series(n, seed=seed) for var in self.meta_vars} return pl.DataFrame(synth_dict) def __repr__(self) -> str: diff --git a/metasyn/var.py b/metasyn/var.py index 645231c6..fd71e721 100644 --- a/metasyn/var.py +++ b/metasyn/var.py @@ -232,7 +232,7 @@ def draw(self) -> Any: return None return self.distribution.draw() - def draw_series(self, n: int) -> pl.Series: + def draw_series(self, n: int, seed: Optional[int]) -> pl.Series: """Draw a new synthetic series from the metadata. Parameters @@ -245,6 +245,8 @@ def draw_series(self, n: int) -> pl.Series: polars.Series: Polars series with the synthetic data. """ + if seed is not None: + np.random.seed(seed) self.distribution.draw_reset() value_list = [self.draw() for _ in range(n)] pl_type = self.dtype.split("(")[0]