diff --git a/linopy/common.py b/linopy/common.py index 4e28fe63..a573182e 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -260,7 +260,7 @@ def check_has_nulls(df: pd.DataFrame, name: str): raise ValueError(f"{name} contains nan's in field(s) {fields}") -def infer_schema_polars(ds: pl.DataFrame) -> dict: +def infer_schema_polars(ds: Dataset, overwrites: dict[str, pl.DataType]) -> dict: """ Infer the schema for a Polars DataFrame based on the data types of its columns. @@ -272,7 +272,9 @@ def infer_schema_polars(ds: pl.DataFrame) -> dict: """ schema = {} for col_name, array in ds.items(): - if np.issubdtype(array.dtype, np.integer): + if col_name in overwrites: + schema[col_name] = overwrites[col_name] + elif np.issubdtype(array.dtype, np.integer): schema[col_name] = pl.Int32 if os.name == "nt" else pl.Int64 elif np.issubdtype(array.dtype, np.floating): schema[col_name] = pl.Float64 @@ -301,10 +303,10 @@ def to_polars(ds: Dataset, **kwargs) -> pl.DataFrame: DataFrame constructor. """ data = broadcast(ds)[0] - return pl.DataFrame({k: v.values.reshape(-1) for k, v in data.items()}, **kwargs) + return pl.LazyFrame({k: v.values.reshape(-1) for k, v in data.items()}, **kwargs) -def check_has_nulls_polars(df: pl.DataFrame, name: str = "") -> None: +def check_has_nulls_polars(df: pl.LazyFrame, name: str = "") -> None: """ Checks if the given DataFrame contains any null values and raises a ValueError if it does. @@ -316,7 +318,7 @@ def check_has_nulls_polars(df: pl.DataFrame, name: str = "") -> None: ValueError: If the DataFrame contains null values, a ValueError is raised with a message indicating the name of the constraint and the fields containing null values. 
""" - has_nulls = df.select(pl.col("*").is_null().any()) + has_nulls = df.select(pl.col("*").is_null().any()).collect() null_columns = [col for col in has_nulls.columns if has_nulls[col][0]] if null_columns: raise ValueError(f"{name} contains nan's in field(s) {null_columns}") @@ -345,7 +347,7 @@ def filter_nulls_polars(df: pl.DataFrame) -> pl.DataFrame: return df.filter(cond) -def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame: +def group_terms_polars(df: pl.LazyFrame) -> pl.LazyFrame: """ Groups terms in a polars DataFrame. diff --git a/linopy/constraints.py b/linopy/constraints.py index 0d7740b3..3985c909 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -578,18 +578,19 @@ def to_polars(self): check_has_nulls_polars(long, name=f"{self.type} {self.name}") short = ds[[k for k in ds if "_term" not in ds[k].dims]] - schema = infer_schema_polars(short) - schema["sign"] = pl.Enum(["=", "<=", ">="]) + schema = infer_schema_polars( + short, overwrites={"sign": pl.Enum(["=", "<=", ">="])} + ) short = to_polars(short, schema=schema) short = filter_nulls_polars(short) check_has_nulls_polars(short, name=f"{self.type} {self.name}") - df = pl.concat([short, long], how="diagonal").sort(["labels", "rhs"]) + lf = pl.concat([short, long], how="diagonal").sort(["labels", "rhs"]) # delete subsequent non-null rhs (happens is all vars per label are -1) - is_non_null = df["rhs"].is_not_null() + is_non_null = pl.col("rhs").is_not_null() prev_non_is_null = is_non_null.shift(1).fill_null(False) - df = df.filter(is_non_null & ~prev_non_is_null | ~is_non_null) - return df[["labels", "coeffs", "vars", "sign", "rhs"]] + lf = lf.filter(is_non_null & ~prev_non_is_null | ~is_non_null) + return lf.select(pl.col(["labels", "coeffs", "vars", "sign", "rhs"])) sel = conwrap(Dataset.sel) diff --git a/linopy/expressions.py b/linopy/expressions.py index 124af6be..08b541ef 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1268,9 +1268,9 @@ def mask_func(data): 
check_has_nulls(df, name=self.type) return df - def to_polars(self) -> pl.DataFrame: + def to_polars(self) -> pl.LazyFrame: """ - Convert the expression to a polars DataFrame. + Convert the expression to a polars LazyFrame. The resulting DataFrame represents a long table format of the all non-masked expressions with non-zero coefficients. It contains the @@ -1278,13 +1278,13 @@ def to_polars(self) -> pl.DataFrame: Returns ------- - df : polars.DataFrame + lf : polars.LazyFrame """ - df = to_polars(self.data) - df = filter_nulls_polars(df) - df = group_terms_polars(df) - check_has_nulls_polars(df, name=self.type) - return df + lf = to_polars(self.data) + lf = filter_nulls_polars(lf) + lf = group_terms_polars(lf) + check_has_nulls_polars(lf, name=self.type) + return lf # Wrapped function which would convert variable to dataarray assign = exprwrap(Dataset.assign) @@ -1480,7 +1480,7 @@ def mask_func(data): check_has_nulls(df, name=self.type) return df - def to_polars(self, **kwargs): + def to_polars(self, **kwargs) -> pl.LazyFrame: """ Convert the expression to a polars DataFrame. @@ -1490,17 +1490,17 @@ def to_polars(self, **kwargs): Returns ------- - df : polars.DataFrame + lf : polars.LazyFrame """ vars = self.data.vars.assign_coords( {FACTOR_DIM: ["vars1", "vars2"]} ).to_dataset(FACTOR_DIM) ds = self.data.drop_vars("vars").assign(vars) - df = to_polars(ds, **kwargs) - df = filter_nulls_polars(df) - df = group_terms_polars(df) - check_has_nulls_polars(df, name=self.type) - return df + lf = to_polars(ds, **kwargs) + lf = filter_nulls_polars(lf) + lf = group_terms_polars(lf) + check_has_nulls_polars(lf, name=self.type) + return lf def to_matrix(self): """ diff --git a/linopy/io.py b/linopy/io.py index 779d3636..9582fb32 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -3,6 +3,7 @@ """ Module containing all import/export functionalities. 
""" + import logging import shutil import time @@ -12,6 +13,8 @@ import numpy as np import pandas as pd import polars as pl +import pyarrow as pa +import pyarrow.csv import xarray as xr from numpy import ones_like, zeros_like from scipy.sparse import tril, triu @@ -278,20 +281,41 @@ def to_lp_file(m, fn, integer_label): logger.info(f" Writing time: {round(time.time()-start, 2)}s") -def objective_write_linear_terms_polars(f, df): +def write_lazyframe(f, lf): + lf = lf.fill_null("") + + def to_pyarrow_schema(schema): + return pa.schema( + (k, pl.datatypes.py_type_to_arrow_type(pl.datatypes.dtype_to_py_type(v))) + for k, v in schema.items() + ) + + writer = pa.csv.CSVWriter( + f, + to_pyarrow_schema(lf.schema), + write_options=pa.csv.WriteOptions( + include_header=False, delimiter=";", quoting_style="none" + ), + ) + + def write_batch(batch): + writer.write(batch.to_arrow()) + return pl.DataFrame() + + lf.map_batches(write_batch, schema={}, streamable=True).collect() + + +def objective_write_linear_terms_polars(f, lf): cols = [ pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), pl.col("coeffs").cast(pl.String), pl.lit(" x"), pl.col("vars").cast(pl.String), ] - df = df.select(pl.concat_str(cols, ignore_nulls=True)) - df.write_csv( - f, separator=" ", null_value="", quote_style="never", include_header=False - ) + write_lazyframe(f, lf.select(pl.concat_str(cols, ignore_nulls=True))) -def objective_write_quadratic_terms_polars(f, df): +def objective_write_quadratic_terms_polars(f, lf): cols = [ pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), pl.col("coeffs").mul(2).cast(pl.String), @@ -301,10 +325,7 @@ def objective_write_quadratic_terms_polars(f, df): pl.col("vars2").cast(pl.String), ] f.write(b"+ [\n") - df = df.select(pl.concat_str(cols, ignore_nulls=True)) - df.write_csv( - f, separator=" ", null_value="", quote_style="never", include_header=False - ) + write_lazyframe(f, lf.select(pl.concat_str(cols, ignore_nulls=True))) 
f.write(b"] / 2\n") @@ -317,13 +338,13 @@ def objective_to_file_polars(m, f, log=False): sense = m.objective.sense f.write(f"{sense}\n\nobj:\n\n".encode("utf-8")) - df = m.objective.to_polars() + lf = m.objective.to_polars() if m.is_linear: - objective_write_linear_terms_polars(f, df) + objective_write_linear_terms_polars(f, lf) elif m.is_quadratic: - lins = df.filter(pl.col("vars1").eq(-1) | pl.col("vars2").eq(-1)) + lins = lf.filter(pl.col("vars1").eq(-1) | pl.col("vars2").eq(-1)) lins = lins.with_columns( pl.when(pl.col("vars1").eq(-1)) .then(pl.col("vars2")) @@ -332,7 +353,7 @@ def objective_to_file_polars(m, f, log=False): ) objective_write_linear_terms_polars(f, lins) - quads = df.filter(pl.col("vars1").ne(-1) & pl.col("vars2").ne(-1)) + quads = lf.filter(pl.col("vars1").ne(-1) & pl.col("vars2").ne(-1)) objective_write_quadratic_terms_polars(f, quads) @@ -353,7 +374,7 @@ def bounds_to_file_polars(m, f, log=False): ) for name in names: - df = m.variables[name].to_polars() + lf = m.variables[name].to_polars() columns = [ pl.when(pl.col("lower") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), @@ -365,11 +386,7 @@ def bounds_to_file_polars(m, f, log=False): pl.col("upper").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + write_lazyframe(f, lf.select(pl.concat_str(columns, ignore_nulls=True))) def binaries_to_file_polars(m, f, log=False): @@ -389,18 +406,14 @@ def binaries_to_file_polars(m, f, log=False): ) for name in names: - df = m.variables[name].to_polars() + lf = m.variables[name].to_polars() columns = [ pl.lit("x"), pl.col("labels").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + write_lazyframe(f, 
lf.select(pl.concat_str(columns, ignore_nulls=True))) def integers_to_file_polars(m, f, log=False, integer_label="general"): @@ -420,18 +433,14 @@ def integers_to_file_polars(m, f, log=False, integer_label="general"): ) for name in names: - df = m.variables[name].to_polars() + lf = m.variables[name].to_polars() columns = [ pl.lit("x"), pl.col("labels").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + write_lazyframe(f, lf.select(pl.concat_str(columns, ignore_nulls=True))) def constraints_to_file_polars(m, f, log=False, lazy=False): @@ -447,14 +456,14 @@ def constraints_to_file_polars(m, f, log=False, lazy=False): colour=TQDM_COLOR, ) - # to make this even faster, we can use polars expression + # to make this even faster, we could create a custom polars expression plugin # https://docs.pola.rs/user-guide/expressions/plugins/#output-data-types for name in names: - df = m.constraints[name].to_polars() + lf = m.constraints[name].to_polars() - # df = df.lazy() + # lf = lf.lazy() # filter out repeated label values - df = df.with_columns( + lf = lf.with_columns( pl.when(pl.col("labels").is_first_distinct()) .then(pl.col("labels")) .otherwise(pl.lit(None)) @@ -462,28 +471,19 @@ def constraints_to_file_polars(m, f, log=False, lazy=False): ) columns = [ - pl.when(pl.col("labels").is_not_null()).then(pl.lit("c")).alias("c"), + pl.when(pl.col("labels").is_not_null()).then(pl.lit("c")), pl.col("labels").cast(pl.String), - pl.when(pl.col("labels").is_not_null()).then(pl.lit(":\n")).alias(":"), + pl.when(pl.col("labels").is_not_null()).then(pl.lit(": ")), pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")), pl.col("coeffs").cast(pl.String), - pl.when(pl.col("vars").is_not_null()).then(pl.lit(" x")).alias("x"), + pl.when(pl.col("vars").is_not_null()).then(pl.lit(" x")), pl.col("vars").cast(pl.String), - "sign", + 
pl.col("sign"), pl.lit(" "), pl.col("rhs").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) - - # in the future, we could use lazy dataframes when they support appending - # tp existent files - # formatted = df.lazy().select(pl.concat_str(columns, ignore_nulls=True)) - # formatted.sink_csv(f, **kwargs) + write_lazyframe(f, lf.select(pl.concat_str(columns, ignore_nulls=True))) def to_lp_file_polars(m, fn, integer_label="general"): diff --git a/linopy/variables.py b/linopy/variables.py index 8c00e97a..03afd0ba 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -159,7 +159,6 @@ def __init__(self, data: Dataset, model: Any, name: str): self._model = model def __getitem__(self, selector) -> Union["Variable", "ScalarVariable"]: - keys = selector if isinstance(selector, tuple) else (selector,) if all(map(pd.api.types.is_scalar, keys)): warn( @@ -799,10 +798,10 @@ def to_polars(self) -> pl.DataFrame: ------- pl.DataFrame """ - df = to_polars(self.data) - df = filter_nulls_polars(df) - check_has_nulls_polars(df, name=f"{self.type} {self.name}") - return df + lf = to_polars(self.data) + lf = filter_nulls_polars(lf) + check_has_nulls_polars(lf, name=f"{self.type} {self.name}") + return lf def sum(self, dim=None, **kwargs): """ @@ -1020,7 +1019,6 @@ def __init__(self, obj): self.object = obj def __getitem__(self, keys) -> "ScalarVariable": - keys = keys if isinstance(keys, tuple) else (keys,) object = self.object diff --git a/pyproject.toml b/pyproject.toml index 116a62d8..7a7d00af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "xarray>=2023.9.0", "dask>=0.18.0", "polars", + "pyarrow", "tqdm", "deprecation", ] diff --git a/test/test_constraint.py b/test/test_constraint.py index 7866fc94..93b12bac 100644 --- a/test/test_constraint.py +++ b/test/test_constraint.py @@ 
-335,7 +335,7 @@ def test_constraint_flat(c): def test_constraint_to_polars(c): - assert isinstance(c.to_polars(), pl.DataFrame) + assert isinstance(c.to_polars(), pl.LazyFrame) def test_constraint_assignment_with_anonymous_constraints(m, x, y): diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 74c31fe7..07c519be 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -529,8 +529,8 @@ def test_linear_expression_to_polars(v): coeff = np.arange(1, 21) # use non-zero coefficients expr = coeff * v df = expr.to_polars() - assert isinstance(df, pl.DataFrame) - assert (df["coeffs"].to_numpy() == coeff).all() + assert isinstance(df, pl.LazyFrame) + assert (df.collect()["coeffs"].to_numpy() == coeff).all() def test_linear_expression_where(v): diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index 2d70e621..c46907df 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -221,10 +221,10 @@ def test_quadratic_expression_flat(x, y): def test_linear_expression_to_polars(x, y): expr = x * y + x + 5 df = expr.to_polars() - assert isinstance(df, pl.DataFrame) + assert isinstance(df, pl.LazyFrame) assert "vars1" in df.columns assert "vars2" in df.columns - assert len(df) == expr.nterm * 2 + assert len(df.collect()) == expr.nterm * 2 def test_quadratic_expression_to_matrix(model, x, y): diff --git a/test/test_variable.py b/test/test_variable.py index df6fee28..b461ac2b 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -260,8 +260,8 @@ def test_variable_flat(x): def test_variable_polars(x): result = x.to_polars() - assert isinstance(result, pl.DataFrame) - assert len(result) == x.size + assert isinstance(result, pl.LazyFrame) + assert len(result.collect()) == x.size def test_variable_sanitize(x):