Skip to content

Commit

Permalink
[ONNX] Code to dump all nets into ONNX for trtexec eval
Browse files Browse the repository at this point in the history
  • Loading branch information
MadFunMaker committed Apr 5, 2022
1 parent a81fb3e commit ed16b4e
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions demo/workloads/torch_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"mobilenet_v2": mobilenet_v2,
"dcgan": DCGAN,
"yolov3": YoloV3,

# Multiple input models
# Note that it requires some code changes when evaluating performance
"gpt2":get_gpt2_model,
Expand Down Expand Up @@ -89,7 +89,7 @@ def load_torch_model_from_code(name, batch_size):
model = NETWORK_TO_TORCH_MODEL[name]() # .cuda()

model.eval()


# print(f"Input data: {input_shape}")
scripted_model = torch.jit.trace(model.cpu(), input_data).eval()
Expand Down Expand Up @@ -138,13 +138,13 @@ def export_onnx_network_from_torch(name, batch_size):
model = NETWORK_TO_TORCH_MODEL[name]() # .cuda()

model.eval()

torch.onnx.export(model, # model being run
input_data, # model input (or a tuple for multiple inputs)
f"{name}.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for
do_constant_folding=True, # whether to execute constant folding for
)


Expand All @@ -157,4 +157,8 @@ def crop_network_from_torch(name, batch_size, post_dfs_order):


if __name__ == "__main__":
export_onnx_network_from_torch("resnext50_32x4d", 1)
export_onnx_network_from_torch("resnext50_32x4d", 1)
export_onnx_network_from_torch("bert_full", 1)
export_onnx_network_from_torch("resnet50_3d", 1)
export_onnx_network_from_torch("nasneta", 1)
export_onnx_network_from_torch("dcgan", 1)

0 comments on commit ed16b4e

Please sign in to comment.