Skip to content

Commit

Permalink
add hpi config (#547)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored Jan 14, 2025
1 parent aeb5b59 commit 3e2cb33
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
26 changes: 26 additions & 0 deletions paddlets/models/anomaly/dl/anomaly_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,32 @@ def save(self,
model_meta.update(data_info)
if model_name is not None:
model_meta['Global'] = {'model_name': model_name}
dynamic_shape = list(model_meta["input_data"][
"observed_cov_numeric"])[-2:]
if dynamic_shape != [-1, -1]:
paddle_shapes = [[1] + dynamic_shape,
[1] + dynamic_shape,
[8] + dynamic_shape]
tensorrt_shapes = paddle_shapes
else:
shapes = [[1, 64, 1], [1, 96, 5]]
paddle_shapes = shapes + [[8, 192, 20]]
tensorrt_shapes = shapes + [[8, 96, 20]]
hpi_config = {
'backend_configs': {
'paddle_infer': {
'trt_dynamic_shapes': {
'observed_cov_numeric': paddle_shapes
}
},
'tensorrt': {
'dynamic_shapes': {
'observed_cov_numeric': tensorrt_shapes
}
}
}
}
model_meta['Hpi'] = hpi_config
model_meta = convert_and_remove_types(model_meta)
yaml.dump(model_meta, f)
except Exception as e:
Expand Down
40 changes: 40 additions & 0 deletions paddlets/models/classify/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,46 @@ def save(self,
model_meta.update(data_info)
if model_name is not None:
model_meta['Global'] = {'model_name': model_name}
dynamic_shape = list(model_meta["input_data"]["features"])[
-2:]
pad_mask_shape = list(model_meta["input_data"]["pad_mask"])[
-1:]
if dynamic_shape != [-1, -1]:
paddle_shapes = [[1] + dynamic_shape,
[1] + dynamic_shape,
[8] + dynamic_shape]
tensorrt_shapes = paddle_shapes
else:
shapes = [[1, 64, 1], [1, 96, 5]]
paddle_shapes = shapes + [[8, 192, 20]]
tensorrt_shapes = shapes + [[8, 96, 20]]
if pad_mask_shape != [-1]:
pad_mask_paddle_shapes = [[1] + pad_mask_shape,
[1] + pad_mask_shape,
[8] + pad_mask_shape]
pad_mask_tensorrt_shapes = pad_mask_paddle_shapes
else:
pad_mask_shapes = [[1, 64], [1, 96]]
pad_mask_paddle_shapes = pad_mask_shapes + [[8, 192]]
pad_mask_tensorrt_shapes = pad_mask_shapes + [[8, 96]]

hpi_config = {
'backend_configs': {
'paddle_infer': {
'trt_dynamic_shapes': {
'features': paddle_shapes,
'pad_mask': pad_mask_paddle_shapes
}
},
'tensorrt': {
'dynamic_shapes': {
'features': tensorrt_shapes,
'pad_mask': pad_mask_tensorrt_shapes
}
}
}
}
model_meta['Hpi'] = hpi_config
model_meta = convert_and_remove_types(model_meta)
yaml.dump(model_meta, f)
except Exception as e:
Expand Down
45 changes: 45 additions & 0 deletions paddlets/models/forecasting/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,51 @@ def save(self,
model_meta.update(data_info)
if model_name is not None:
model_meta['Global'] = {'model_name': model_name}
dynamic_shape = list(model_meta["input_data"][
"past_target"])[-2:]
if dynamic_shape != [-1, -1]:
paddle_shapes = [[1] + dynamic_shape,
[1] + dynamic_shape,
[8] + dynamic_shape]
tensorrt_shapes = paddle_shapes
else:
shapes = [[1, 64, 1], [1, 96, 5]]
paddle_shapes = shapes + [[8, 192, 20]]
tensorrt_shapes = shapes + [[8, 96, 20]]

hpi_config = {
'backend_configs': {
'paddle_infer': {
'trt_dynamic_shapes': {
'past_target': paddle_shapes
}
},
'tensorrt': {
'dynamic_shapes': {
'past_target': tensorrt_shapes
}
}
}
}
if "known_cov_numeric" in model_meta["input_data"]:
known_cov_numeric_dynamic_shape = list(model_meta[
"input_data"]["known_cov_numeric"])[-2:]
if known_cov_numeric_dynamic_shape != [-1, -1]:
known_cov_numeric_shape = [
[1] + known_cov_numeric_dynamic_shape,
[1] + known_cov_numeric_dynamic_shape,
[8] + known_cov_numeric_dynamic_shape
]
else:
known_cov_numeric_shape = [[1, 64, 4], [1, 96, 10],
[8, 192, 30]]
hpi_config["backend_configs"]["paddle_infer"][
"trt_dynamic_shapes"][
"known_cov_numeric"] = known_cov_numeric_shape
hpi_config["backend_configs"]["tensorrt"][
"dynamic_shapes"][
"known_cov_numeric"] = known_cov_numeric_shape
model_meta['Hpi'] = hpi_config
model_meta = convert_and_remove_types(model_meta)
yaml.dump(model_meta, f)
except Exception as e:
Expand Down

0 comments on commit 3e2cb33

Please sign in to comment.