Skip to content

Commit

Permalink
extractor: add explicit numerical columns to convert_columns_to_category
Browse files Browse the repository at this point in the history
  • Loading branch information
awillecke committed Jun 3, 2024
1 parent ed783fd commit 1159513
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 6 additions & 1 deletion extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def apply_tags(data, tags, base_tags=None, additional_tags=[], minimal=True):


@staticmethod
def convert_columns_to_category(data, additional_columns:list = [], excluded_columns:set = {}):
def convert_columns_to_category(data, additional_columns:list = [], excluded_columns:set = {}, numerical_columns:set = {}):
excluded_columns = set(excluded_columns).union(DEFAULT_CATEGORICALS_COLUMN_EXCLUSION_SET)

col_list = []
Expand All @@ -265,11 +265,16 @@ def convert_columns_to_category(data, additional_columns:list = [], excluded_col
if s < threshold:
col_list.append(col)

logd(f"{excluded_columns}=")
logd(f"{col_list}=")
# convert selected columns to Categorical
for col in col_list:
data[col] = data[col].astype('category')
data[col] = data[col].cat.as_ordered()

for col in numerical_columns:
data[col] = data[col].astype('float')

return data


Expand Down
2 changes: 1 addition & 1 deletion plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def read_data(self):
data_list = list(map(dask.delayed(functools.partial(read_from_file, sample=self.sample, sample_seed=self.sample_seed, filter_query=self.filter_query))
, data_set.get_file_list()))
concat_result = dask.delayed(pd.concat)(data_list)
convert_columns_result = dask.delayed(RawExtractor.convert_columns_to_category)(concat_result, excluded_columns=self.numerical_columns)
convert_columns_result = dask.delayed(RawExtractor.convert_columns_to_category)(concat_result, numerical_columns=self.numerical_columns)
logd(f'PlottingReaderFeather::read_data: {data_list=}')
logd(f'PlottingReaderFeather::read_data: {convert_columns_result=}')
# d = dask.compute(convert_columns_result)
Expand Down

0 comments on commit 1159513

Please sign in to comment.