-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add consise dataframe * use right file * minimal routine * expand test --------- Co-authored-by: J.R. Angevaare <[email protected]>
- Loading branch information
1 parent
85613e3
commit a8086fc
Showing
4 changed files
with
139 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from tqdm.notebook import tqdm | ||
import pandas as pd | ||
import numpy as np | ||
import optim_esm_tools as oet | ||
import typing as ty | ||
|
||
|
||
class ConciseDataFrame: | ||
delimiter = ', ' | ||
merge_postfix = '(s)' | ||
|
||
def __init__( | ||
self, | ||
df: pd.DataFrame, | ||
group: ty.Iterable = None, | ||
tqdm: bool = False, | ||
match_overlap: bool = True, | ||
min_frac_overlap: float = 0.33, | ||
): | ||
# important to sort by tips == True first! As in match_rows there is a line that assumes | ||
# that all tipping rows are already merged! | ||
self.df = df.copy().sort_values( | ||
by=['tips', 'institution_id', 'source_id', 'experiment_id'], ascending=False | ||
) | ||
self.group = group or set(self.df.colums) - set( | ||
[ | ||
'institution_id', | ||
'source_id', | ||
'experiment_id', | ||
] | ||
) | ||
self.match_overlap = match_overlap | ||
self.tqdm = tqdm | ||
self.min_frac_overlap = min_frac_overlap | ||
|
||
def concise(self) -> pd.DataFrame: | ||
rows = [row.to_dict() for _, row in self.df.iterrows()] | ||
matched_rows = self.match_rows(rows) | ||
combined_rows = [self.combine_rows(r, self.delimiter) for r in matched_rows] | ||
df_ret = pd.DataFrame(combined_rows) | ||
return self.rename_columns_with_plural(df_ret) | ||
|
||
def rename_columns_with_plural(self, df: pd.DataFrame) -> pd.DataFrame: | ||
"""Add postfix to columns from the dataframe""" | ||
rename_dict = {k: f'{k}{self.merge_postfix}' for k in self.group} | ||
return df.rename(columns=rename_dict) | ||
|
||
@staticmethod | ||
def combine_rows(rows: ty.Mapping, delimiter: str) -> ty.Dict[str, str]: | ||
ret = {} | ||
for k in rows[0].keys(): | ||
val = sorted(list(set(r[k] for r in rows))) | ||
if len(val) == 1: | ||
ret[k] = val[0] | ||
else: | ||
ret[k] = delimiter.join([str(v) for v in val]) | ||
return ret | ||
|
||
def match_rows(self, rows: ty.Mapping) -> ty.List[ty.Mapping]: | ||
groups = [] | ||
for row in oet.utils.tqdm(rows, desc='rows', disable=not self.tqdm): | ||
if any(row in g for g in groups): | ||
continue | ||
|
||
groups.append([row]) | ||
for other_row in rows: | ||
if row == other_row: | ||
continue | ||
for k, v in row.items(): | ||
if k in self.group: | ||
continue | ||
if other_row.get(k) != v: | ||
break | ||
else: | ||
if (not self.match_overlap) or ( | ||
any( | ||
self.overlaps_enough(r['path'], other_row['path']) | ||
for r in groups[-1] | ||
if r['tips'] | ||
) | ||
): | ||
groups[-1].append(other_row) | ||
return groups | ||
|
||
@staticmethod | ||
def overlaps_percent(ds1, ds2, use_field='global_mask'): | ||
arr1 = ds1[use_field].values | ||
arr2 = ds2[use_field].values | ||
return np.sum(arr1 & arr2) / min(np.sum(arr1), np.sum(arr2)) | ||
|
||
def overlaps_enough(self, path1, path2): | ||
return ( | ||
self.overlaps_percent(oet.load_glob(path1), oet.load_glob(path2)) | ||
>= self.min_frac_overlap | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import tempfile | ||
import optim_esm_tools as oet | ||
import pandas as pd | ||
import os | ||
from unittest import TestCase | ||
import numpy as np | ||
|
||
|
||
class TestConsiseDataFrame(TestCase): | ||
def test_merge_two(self, nx=4, ny=3, is_match=(True, True)): | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
kw = dict(len_x=nx, len_y=ny, len_time=2, add_nans=False) | ||
names = list('abcdefg'[: len(is_match)]) | ||
paths = [os.path.join(temp_dir, f'{x}.nc') for x in names] | ||
for path in paths: | ||
ds = oet._test_utils.minimal_xr_ds(**kw) | ||
print(ds['var'].shape, ds['var'].dims) | ||
ds['global_mask'] = ( | ||
oet.config.config['analyze']['lon_lat_dim'].split(','), | ||
np.ones((nx, ny), bool), | ||
) | ||
ds.to_netcdf(path) | ||
_same = ['same'] * len(names) | ||
data_frame = pd.DataFrame( | ||
dict( | ||
path=paths, | ||
names=names, | ||
tips=[True] * len(is_match), | ||
institution_id=_same, | ||
source_id=_same, | ||
experiment_id=_same, | ||
is_match=is_match, | ||
) | ||
) | ||
concise_df = oet.analyze.concise_dataframe.ConciseDataFrame( | ||
data_frame, group=('path', 'names') | ||
).concise() | ||
assert len(concise_df) == len(np.unique(is_match)) | ||
|
||
def test_merge_three(self): | ||
return self.test_merge_two(is_match=(True, True, False)) |