Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOCS] Few fixes for broken Adreno docs #17518

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
==========================================================
**Author**: Siva Rama Krishna

This article is a step-by-step tutorial to deploy pretrained Keras resnet50 model on Adreno™.
This article is a step-by-step tutorial to deploy pretrained PyTorch resnet50 model on Adreno™.

Besides that, you should have TVM built for Android.
See the following instructions on how to build it and setup RPC environment.
Expand Down Expand Up @@ -71,16 +71,27 @@
)

#######################################################################
# Make a Keras Resnet50 Model
# ---------------------------
# Make a PyTorch Resnet50 Model
# -----------------------------

from tensorflow.keras.applications.resnet50 import ResNet50
import torch
import torchvision.models as models

tmp_path = utils.tempdir()
model_file_name = tmp_path.relpath("resnet50.h5")
# Load the ResNet50 model pre-trained on ImageNet
model = models.resnet50(pretrained=True)

model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000)
model.save(model_file_name)
# Set the model to evaluation mode
model.eval()

# Define the input shape
dummy_input = torch.randn(1, 3, 224, 224)

# Trace the model
traced_model = torch.jit.trace(model, dummy_input)

# Save the traced model
model_file_name = "resnet50_traced.pt"
traced_model.save(model_file_name)


#######################################################################
Expand All @@ -89,7 +100,10 @@
# Convert a model from any framework to a tvm relay module.
# tvmc.load supports models from any framework (like tensorflow saves_model, onnx, tflite ..etc) and auto detects the filetype.

tvmc_model = tvmc.load(model_file_name)
input_shape = (1, 3, 224, 224) # Batch size, channels, height, width

# Load the TorchScript model with TVMC
tvmc_model = tvmc.load(model_file_name, shape_dict={"input": input_shape}, model_format="pytorch")

print(tvmc_model.mod)

Expand Down Expand Up @@ -158,7 +172,7 @@
# Altrernatively, we can save the compilation output and save it as a TVMCPackage.
# This way avoids loading of compiled module without compiling again.
target = target + ", clml"
pkg_path = tmp_path.relpath("keras-resnet50.tar")
pkg_path = tmp_path.relpath("torch-resnet50.tar")
tvmc.compile(
tvmc_model,
target=target,
Expand Down
Loading