Skip to content

Commit

Permalink
rename_train_result
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored and TingquanGao committed Nov 14, 2024
1 parent 1cbc425 commit d339e2c
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions paddlets/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ def check_train_valid_continuity(train_data: TSDataset,
pd.to_timedelta(train_index.freq))
elif isinstance(train_index, pd.RangeIndex):
if isinstance(valid_index, pd.RangeIndex):
continuious = (
valid_index[0] - train_index[-1] == train_index.step)
continuious = (valid_index[0] - train_index[-1] == train_index.step)
else:
raise_log("Unsupport data index format")

Expand Down Expand Up @@ -324,8 +323,7 @@ def get_tsdataset_max_len(dataset: TSDataset) -> int:
return len(all_index)


def repr_results_to_tsdataset(reprs: np.array,
dataset: TSDataset) -> TSDataset:
def repr_results_to_tsdataset(reprs: np.array, dataset: TSDataset) -> TSDataset:
"""
Convert representation model output to a TSDataset
Expand Down Expand Up @@ -455,9 +453,8 @@ def build_ts_infer_input(tsdataset: TSDataset,
#build sample base on DataAdapter
data_adapter = DataAdapter()
if json_data['model_type'] == 'forecasting':
raise_if_not(
tsdataset.get_target() is not None,
"The target of tsdataset can not be None for forecasting!")
raise_if_not(tsdataset.get_target() is not None,
"The target of tsdataset can not be None for forecasting!")
size_keys = ['in_chunk_len', 'out_chunk_len', 'skip_chunk_len']
for key in size_keys:
raise_if_not(
Expand All @@ -480,8 +477,7 @@ def build_ts_infer_input(tsdataset: TSDataset,
raise_if_not(
key in json_data['size'],
f"The {key} in json_data['size'] can not be None for anomaly!")
dataset = data_adapter.to_sample_dataset(tsdataset,
**json_data['size'])
dataset = data_adapter.to_sample_dataset(tsdataset, **json_data['size'])
else:
raise_log(ValueError(f"Invalid model_type: {json_data['model_type']}"))

Expand Down Expand Up @@ -522,7 +518,7 @@ def convert_and_remove_types(data):

def update_train_results(save_path, score, model_name="", done_flag=True):

train_results_path = os.path.join(save_path, "train_results.json")
train_results_path = os.path.join(save_path, "train_result.json")
save_model_tag = ["pdparams", "pdopt", "pdstates", "pdema"]
save_inference_tag = [
"inference_config", "pdmodel", "pdiparams", "pdiparams.info"
Expand Down

0 comments on commit d339e2c

Please sign in to comment.