GH-178: Allow saving PDT clone with subset of columns
onyb committed May 14, 2021
1 parent 5a492ac commit 5e91c26
Showing 13 changed files with 108 additions and 6 deletions.
14 changes: 14 additions & 0 deletions core/api.py
@@ -410,6 +410,20 @@ def save_operation():

f.write(text.lstrip())

loader = load_point_data_by_path(pdt_path, cheaper=cheaper)

if pdt_path.endswith(".ascii"):
ext = "ascii"
elif pdt_path.endswith(".parquet"):
ext = "parquet"
else:
ext = "ascii"

exclude_cols = payload["excludePredictors"]
cols = [col for col in loader.columns if col not in exclude_cols]

loader.clone(*cols, path=output_path / f"PDT.{ext}")

return Response(json.dumps({}), mimetype="application/json")


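For readers skimming the diff, the new block in save_operation reduces to the standalone sketch below. The helper name clone_pdt is hypothetical; loader, payload, pdt_path and output_path stand in for the objects the handler already has in scope.

from pathlib import Path

# Hypothetical helper isolating the new clone step; `loader` is whatever
# load_point_data_by_path returns.
def clone_pdt(loader, payload: dict, pdt_path: str, output_path: Path) -> Path:
    # Mirror the source PDT's format, defaulting to ASCII for anything
    # that is not explicitly Parquet.
    ext = "parquet" if pdt_path.endswith(".parquet") else "ascii"

    # Keep every column except the predictors the client asked to exclude.
    exclude_cols = payload["excludePredictors"]
    cols = [col for col in loader.columns if col not in exclude_cols]

    destination = output_path / f"PDT.{ext}"
    loader.clone(*cols, path=destination)
    return destination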
9 changes: 6 additions & 3 deletions core/loaders/__init__.py
@@ -2,6 +2,7 @@
import re
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Union

import pandas as pd
@@ -55,9 +56,7 @@ def units(self) -> dict:
obs = {m.group(1): m.group(2)} if m else {}

return {
"predictors": {
k: v.replace("NoUnit", "-") for k, v in predictors.items()
},
"predictors": {k: v.replace("NoUnit", "-") for k, v in predictors.items()},
"observations": obs,
"predictand": predictand,
}
@@ -71,6 +70,10 @@ def columns(self) -> List[str]:
def select(self, *args: str, series: bool = True) -> Union[pd.DataFrame, pd.Series]:
raise NotImplementedError

@abc.abstractmethod
def clone(self, *args: str, path: Path):
raise NotImplementedError

@property
def error_type(self) -> ErrorType:
"""
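The base loader now declares clone as part of its interface. A minimal sketch of what a conforming implementation looks like, assuming a hypothetical in-memory loader (the real loaders in ascii.py and parquet.py below stream chunks through their encoders instead):

from pathlib import Path
from typing import List

import pandas as pd

# Hypothetical loader used only to illustrate the clone contract.
class DataFrameLoader:
    def __init__(self, df: pd.DataFrame):
        self._df = df

    @property
    def columns(self) -> List[str]:
        return list(self._df.columns)

    def clone(self, *args: str, path: Path):
        # Persist only the requested columns to `path`.
        self._df[list(args)].to_csv(path, index=False)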
9 changes: 9 additions & 0 deletions core/loaders/ascii.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from functools import partial
from itertools import takewhile
from pathlib import Path
from typing import List, Optional, Union

import attr
@@ -86,6 +87,14 @@ def select(self, *args: str, series: bool = True) -> Union[pd.DataFrame, pd.Series]:
result = result[col]
return result

def clone(self, *args: str, path: Path):
encoder = ASCIIEncoder(path=path)
encoder.add_header(self.metadata.get("header", ""))

for chunk in self:
filtered_chunk = chunk[list(args)]
encoder.add_columns_chunk(filtered_chunk.to_dict())

def __iter__(self) -> "ASCIIDecoder":
self._current_csv_offset = 0
return self
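Typical use of the new ASCII clone, mirroring the unit test added further down; the file names are illustrative:

from pathlib import Path

from core.loaders.ascii import ASCIIDecoder

source = ASCIIDecoder(path=Path("PDT.ascii"))  # illustrative input path
keep = [col for col in source.columns if col not in ("TP", "CAPE")]
source.clone(*keep, path=Path("PDT_subset.ascii"))  # writes header plus filtered chunks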
13 changes: 13 additions & 0 deletions core/loaders/parquet.py
@@ -1,5 +1,6 @@
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Union

import attr
@@ -107,6 +108,8 @@ def columns(self) -> List[str]:
def dataframe(self) -> pd.DataFrame:
if self._dataframe is None:
self._dataframe = pd.read_parquet(self.path, engine="pyarrow")
if "__index_level_0__" in self._dataframe:
self._dataframe = self._dataframe.drop(["__index_level_0__"], axis=1)

return self._dataframe

@@ -127,6 +130,16 @@ def select(self, *args: str, series: bool = True) -> Union[pd.DataFrame, pd.Series]:
result = result[col]
return result

def clone(self, *args: str, path: Path):
encoder = ParquetPointDataTableWriter(path=path)
encoder.add_header(self.metadata.get("header", ""))

for chunk in self:
filtered_chunk = chunk[list(args)]
encoder.add_columns_chunk(filtered_chunk.to_dict())

encoder.add_footer(self.metadata.get("footer", ""))

def __iter__(self) -> "ParquetPointDataTableReader":
self._current_row_group = 0
return self
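The Parquet counterpart is used the same way, again mirroring its unit test; unlike the ASCII clone it also replays the footer metadata. File names here are illustrative:

from pathlib import Path

from core.loaders.parquet import ParquetPointDataTableReader

reader = ParquetPointDataTableReader(Path("PDT.parquet"))  # illustrative input path
keep = [col for col in reader.columns if col not in ("tp_acc", "cape_wa")]
reader.clone(*keep, path=Path("PDT_subset.parquet"))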
2 changes: 1 addition & 1 deletion tests/functional/test_processing.py
@@ -8,7 +8,7 @@
def test_alfa(client, alfa_cassette, alfa_loader, fmt, tmp_path):
path = tmp_path / f"pdt.{fmt.lower()}"
request = alfa_cassette(output_path=str(path), fmt=fmt)
response = client.post("/computation-logs", json=request)
response = client.post("/computations/start", json=request)
assert response.status_code == 200

got_loader = load_point_data_by_path(path=str(path))
14 changes: 14 additions & 0 deletions tests/unit/loaders/test_ascii.py
@@ -1,4 +1,5 @@
import numpy
from pandas.testing import assert_frame_equal

from core.loaders.ascii import ASCIIDecoder
from tests.conf import TEST_DATA_DIR
@@ -31,3 +32,16 @@ def test_alfa_units():
},
"observations": {"tp": "mm"},
}


def test_good_ascii_file_clone(tmp_path):
path = TEST_DATA_DIR / "good_ascii_file.ascii"
data = ASCIIDecoder(path=path)

exclude_cols = ["TP", "CAPE"]
cloned_path = tmp_path / "good_ascii_file.ascii"
cols = [col for col in data.columns if col not in exclude_cols]
data.clone(*cols, path=cloned_path)
cloned_data = ASCIIDecoder(path=cloned_path)

assert_frame_equal(cloned_data.dataframe, data.dataframe.drop(exclude_cols, axis=1))
25 changes: 25 additions & 0 deletions tests/unit/loaders/test_parquet.py
@@ -30,3 +30,28 @@ def test_good_parquet_file():
assert df.memory_usage(deep=True).sum() > df_pq.memory_usage(deep=True).sum()

assert_frame_equal(df_pq, df, check_dtype=False, check_categorical=False)


def test_good_parquet_file_clone(tmp_path):
path = TEST_DATA_DIR / "good_parquet.ascii"
df = ASCIIDecoder(path=path).dataframe

with NamedTemporaryFile() as f:
w = ParquetPointDataTableWriter(f.name)
w.add_columns_chunk(df.copy())
w.close()

r = ParquetPointDataTableReader(f.name)
exclude_cols = ["tp_acc", "cape_wa"]
cloned_path = tmp_path / "good_parquet.parquet"
cols = [col for col in r.columns if col not in exclude_cols]
r.clone(*cols, path=cloned_path)

cloned_data = ParquetPointDataTableReader(cloned_path)

assert_frame_equal(
cloned_data.dataframe,
df.drop(exclude_cols, axis=1),
check_dtype=False,
check_categorical=False,
)
4 changes: 4 additions & 0 deletions ui/redux/postprocessingReducer.js
@@ -56,6 +56,10 @@ export default (state = defaultState, action) => {
}
}

case 'POSTPROCESSING.SET_EXCLUDED_PREDICTORS': {
return { ...state, excludedPredictors: action.data }
}

case 'POSTPROCESSING.SET_Y_LIM': {
return { ...state, yLim: action.value }
}
5 changes: 5 additions & 0 deletions ui/workflows/C/2/levels/actions.js
@@ -7,6 +7,11 @@ export const setFields = fields => ({
data: fields,
})

export const setExcludedPredictors = items => ({
type: 'POSTPROCESSING.SET_EXCLUDED_PREDICTORS',
data: items,
})

export const onFieldsSortEnd = (
fields,
thrGridIn,
13 changes: 12 additions & 1 deletion ui/workflows/C/2/levels/component.js
@@ -25,7 +25,15 @@ const SortableItem = SortableElement(({ value, showDelete, onDelete }) => (
))

const SortableList = SortableContainer(
({ items, breakpoints, labels, fieldRanges, setFields, setBreakpoints }) => (
({
items,
breakpoints,
labels,
fieldRanges,
setFields,
setBreakpoints,
setExcludedPredictors,
}) => (
<Segment.Group raised size="mini" style={{ width: '20%' }}>
{items.map((value, index) => (
<SortableItem
@@ -41,6 +49,8 @@
.map(row => row.slice(0, -2))
const newLabels = labels.slice(0, -2)

const excludePredictors = labels.slice(-2)[0].replace('_thrL', '')
setExcludedPredictors(excludePredictors)
setBreakpoints(newLabels, matrix, fieldRanges)
}}
/>
@@ -57,6 +67,7 @@ const Levels = props => (
items={props.fields}
setFields={props.setFields}
setBreakpoints={props.setBreakpoints}
setExcludedPredictors={props.setExcludedPredictors}
fieldRanges={props.fieldRanges}
breakpoints={props.thrGridOut}
labels={props.labels}
4 changes: 3 additions & 1 deletion ui/workflows/C/2/levels/container.js
@@ -2,7 +2,7 @@ import { connect } from 'react-redux'

import Levels from './component'

import { onFieldsSortEnd, setFields } from './actions'
import { onFieldsSortEnd, setFields, setExcludedPredictors } from './actions'
import { setBreakpoints } from '../breakpoints/actions'

const mapStateToProps = state => ({
@@ -23,6 +23,8 @@

setBreakpoints: (labels, matrix, fieldRanges) =>
dispatch(setBreakpoints(labels, matrix, fieldRanges)),

setExcludedPredictors: items => dispatch(setExcludedPredictors(items)),
})

export default connect(
1 change: 1 addition & 0 deletions ui/workflows/C/2/saveOperation/component.js
@@ -190,6 +190,7 @@ class SaveOperation extends Component {
cheaper: this.props.cheaper,
mode: this.props.mode,
fieldRanges: this.props.fieldRanges,
excludePredictors: this.props.excludedPredictors,
breakpointsCSV:
this.props.mode === 'breakpoints' || this.props.mode === 'all'
? this.getBreakpointsCSV()
1 change: 1 addition & 0 deletions ui/workflows/C/2/saveOperation/container.js
@@ -10,6 +10,7 @@ const mapStateToProps = state => ({
yLim: state.postprocessing.yLim,
cheaper: state.preloader.cheaper,
breakpoints: state.postprocessing.thrGridOut,
excludedPredictors: state.postprocessing.excludedPredictors,
path: state.preloader.path,
labels: state.postprocessing.thrGridIn[0].slice(1).map(cell => cell.value),
error: state.binning.error,
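Tying the UI changes back to the API: the predictor removed in the Levels step is stored as excludedPredictors in the Redux store and sent to the backend as excludePredictors, which core/api.py then drops from the cloned PDT. A rough sketch of the request body follows; the endpoint URL and all other fields are assumed for illustration, only excludePredictors is introduced by this commit, and its exact shape (single name or list) follows whatever the Levels component stored.

import requests

# Illustrative request; the URL and the fields other than excludePredictors are assumptions.
payload = {
    "cheaper": False,
    "mode": "all",
    "fieldRanges": {},
    "excludePredictors": ["CAPE"],  # predictor(s) to drop from the cloned PDT
}
requests.post("http://localhost:8080/postprocessing/save", json=payload)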
