Skip to content

Commit

Permalink
add hpi config
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 committed Jan 7, 2025
1 parent 91885bc commit 31b0874
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
18 changes: 18 additions & 0 deletions paddlets/models/anomaly/dl/anomaly_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,24 @@ def save(self,
model_meta.update(data_info)
if model_name is not None:
model_meta['Global'] = {'model_name': model_name}
shapes = [[1, 64, 1], [1, 96, 5]]
paddle_shapes = shapes + [[8, 192, 20]]
tensorrt_shapes = shapes + [[8, 96, 20]]
hpi_config = {
'backend_configs': {
'paddle_infer': {
'trt_dynamic_shapes': {
'observed_cov_numeric': paddle_shapes
}
},
'tensorrt': {
'dynamic_shapes': {
'observed_cov_numeric': tensorrt_shapes
}
}
}
}
model_meta['Hpi'] = hpi_config
model_meta = convert_and_remove_types(model_meta)
yaml.dump(model_meta, f)
except Exception as e:
Expand Down
19 changes: 19 additions & 0 deletions paddlets/models/classify/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,25 @@ def save(self,
model_meta.update(data_info)
if model_name is not None:
model_meta['Global'] = {'model_name': model_name}
shapes = [[1, 64, 1], [1, 96, 5]]
paddle_shapes = shapes + [[8, 192, 20]]
tensorrt_shapes = shapes + [[8, 96, 20]]

hpi_config = {
'backend_configs': {
'paddle_infer': {
'trt_dynamic_shapes': {
'features': paddle_shapes
}
},
'tensorrt': {
'dynamic_shapes': {
'features': tensorrt_shapes
}
}
}
}
model_meta['Hpi'] = hpi_config
model_meta = convert_and_remove_types(model_meta)
yaml.dump(model_meta, f)
except Exception as e:
Expand Down
29 changes: 29 additions & 0 deletions paddlets/models/forecasting/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def save(self,
if network_model:
self._network.eval()
input_spec = build_network_input_spec(model_meta)
print("******\n" * 10, model_meta)
try:
if not os.path.os.path.exists(abs_root_path):
os.makedirs(abs_root_path)
Expand All @@ -186,6 +187,34 @@ def save(self,
model_meta.update(data_info)
if model_name is not None:
model_meta['Global'] = {'model_name': model_name}
shapes = [[1, 64, 1], [1, 96, 5]]
paddle_shapes = shapes + [[8, 192, 20]]
tensorrt_shapes = shapes + [[8, 96, 20]]

hpi_config = {
'backend_configs': {
'paddle_infer': {
'trt_dynamic_shapes': {
'past_target': paddle_shapes
}
},
'tensorrt': {
'dynamic_shapes': {
'past_target': tensorrt_shapes
}
}
}
}
if "known_cov_numeric" in model_meta["input_data"]:
known_cov_numeric_shape = [[1, 64, 4], [1, 96, 10],
[8, 192, 30]]
hpi_config["backend_configs"]["paddle_infer"][
"trt_dynamic_shapes"][
"known_cov_numeric"] = known_cov_numeric_shape
hpi_config["backend_configs"]["tensorrt"][
"dynamic_shapes"][
"known_cov_numeric"] = known_cov_numeric_shape
model_meta['Hpi'] = hpi_config
model_meta = convert_and_remove_types(model_meta)
yaml.dump(model_meta, f)
except Exception as e:
Expand Down

0 comments on commit 31b0874

Please sign in to comment.