Skip to content

Commit

Permalink
remove sorting by shapes, revert test_networks.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksei-grovety committed Dec 2, 2022
1 parent a73dd43 commit 30a4503
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 64 deletions.
43 changes: 0 additions & 43 deletions src/contrib/ethosu/cascader/pareto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,6 @@ std::vector<Plan> ParetoCullPlans(std::vector<Plan> plans, size_t max_plans,

std::sort(plans.begin(), plans.end(), [](const Plan& a, const Plan& b) -> bool {
if (a->GetMemoryUsage() == b->GetMemoryUsage()) {
if (a->GetCycles() == b->GetCycles()) {
// In the case of equal metrics compare stripe shapes.
if (auto result = CompareStripeShapes(a, b)) {
return result.value();
}
}
return a->GetCycles() < b->GetCycles();
}
return a->GetMemoryUsage() < b->GetMemoryUsage();
Expand Down Expand Up @@ -132,29 +126,6 @@ std::vector<Proposal> ParetoCullProposals(std::vector<Proposal> proposals, size_

std::sort(proposals.begin(), proposals.end(), [](const Proposal& a, const Proposal& b) -> bool {
if (a->GetMemoryUsage() == b->GetMemoryUsage()) {
if (a->GetCycles() == b->GetCycles()) {
auto plans_a = a->GetPlans();
auto plans_b = b->GetPlans();
const auto plans_size = std::min(plans_a.size(), plans_b.size());
auto comparison_function = [](const Plan& plan_a, const Plan& plan_b) -> bool {
if (plan_a->GetMemoryUsage() == plan_b->GetMemoryUsage()) {
return plan_a->GetCycles() < plan_b->GetCycles();
}
return plan_a->GetMemoryUsage() < plan_b->GetMemoryUsage();
};
// After ParetoCullPlans, there should be no variants with the same metrics, so the plans
// are sorted by metrics.
std::sort(plans_a.begin(), plans_a.end(), comparison_function);
std::sort(plans_b.begin(), plans_b.end(), comparison_function);

for (size_t i = 0; i < plans_size; i++) {
// In the case of equal metrics compare stripe shapes, if the plans have the same stripe
// shapes, then move on to the next pair.
if (auto result = CompareStripeShapes(plans_a.at(i), plans_b.at(i))) {
return result.value();
}
}
}
return a->GetCycles() < b->GetCycles();
}
return a->GetMemoryUsage() < b->GetMemoryUsage();
Expand All @@ -180,20 +151,6 @@ std::vector<Proposal> ParetoCullProposals(std::vector<Proposal> proposals, size_
return ThinVector(optimal_proposals, max_proposals);
}

std::optional<bool> CompareStripeShapes(const Plan& plan_a, const Plan& plan_b) {
const auto stripe_configs_a = plan_a->GetOutputConfig()->GetStripeConfigs();
const auto stripe_configs_b = plan_b->GetOutputConfig()->GetStripeConfigs();
const auto stripe_configs_size = std::min(stripe_configs_a.size(), stripe_configs_b.size());
for (size_t i = 0; i < stripe_configs_size; i++) {
const auto stripe_shape_a = stripe_configs_a.at(i)->GetShape();
const auto stripe_shape_b = stripe_configs_b.at(i)->GetShape();
if (stripe_shape_a != stripe_shape_b) {
return stripe_shape_a > stripe_shape_b;
}
}
return std::nullopt;
}

TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.GetParetoFrontier")
.set_body_typed([](Array<Array<FloatImm>> tcosts) {
std::vector<std::array<float, 2>> costs;
Expand Down
10 changes: 0 additions & 10 deletions src/contrib/ethosu/cascader/pareto.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

#include <algorithm>
#include <array>
#include <optional>
#include <vector>

namespace tvm {
Expand Down Expand Up @@ -73,15 +72,6 @@ std::vector<Plan> ParetoCullPlans(std::vector<Plan> plans, size_t max_plans,
std::vector<Proposal> ParetoCullProposals(std::vector<Proposal> proposals, size_t max_proposals,
bool disable_pareto_metric);

/*!
* \brief Compare shapes of plan stripes.
* \param plan_a The first plan.
* \param plan_b The second plan.
* \return std::nullopt if stripe shapes are equals, true if plan_a stripe shape is bigger otherwise
* false.
*/
std::optional<bool> CompareStripeShapes(const Plan& plan_a, const Plan& plan_b);

} // namespace cascader
} // namespace ethosu
} // namespace contrib
Expand Down
22 changes: 11 additions & 11 deletions tests/python/contrib/test_ethosu/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,17 @@ def test_networks_with_usmp_and_cascader_wo_striping(accel_type, model_url, work
"accel_type, model_url, workspace_size",
[
# Checks the same test case multiple times to make sure its not flaky
("ethos-u55-256", MOBILENET_V1_URL, 1005312),
("ethos-u55-256", MOBILENET_V1_URL, 1005312),
("ethos-u55-256", MOBILENET_V1_URL, 1005312),
("ethos-u55-256", MOBILENET_V1_URL, 1005312),
("ethos-u55-256", MOBILENET_V1_URL, 1005312),
("ethos-u55-256", MOBILENET_V1_URL, 1010000),
("ethos-u55-256", MOBILENET_V1_URL, 1010000),
("ethos-u55-256", MOBILENET_V1_URL, 1010000),
("ethos-u55-256", MOBILENET_V1_URL, 1010000),
("ethos-u55-256", MOBILENET_V1_URL, 1010000),
# Checks the same test case multiple times to make sure its not flaky
("ethos-u55-256", MOBILENET_V2_URL, 1162448),
("ethos-u55-256", MOBILENET_V2_URL, 1162448),
("ethos-u55-256", MOBILENET_V2_URL, 1162448),
("ethos-u55-256", MOBILENET_V2_URL, 1162448),
("ethos-u55-256", MOBILENET_V2_URL, 1162448),
("ethos-u55-256", MOBILENET_V2_URL, 1400000),
("ethos-u55-256", MOBILENET_V2_URL, 1400000),
("ethos-u55-256", MOBILENET_V2_URL, 1400000),
("ethos-u55-256", MOBILENET_V2_URL, 1400000),
("ethos-u55-256", MOBILENET_V2_URL, 1400000),
],
)
def test_networks_with_usmp_and_cascader_with_striping(accel_type, model_url, workspace_size):
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_networks_with_usmp_and_cascader_with_striping(accel_type, model_url, wo
allocated_pool_info = list(
dict(compiled_models[0].executor_factory.executor_codegen_metadata.pool_inputs).values()
)[0]
assert allocated_pool_info.allocated_size == workspace_size
assert allocated_pool_info.allocated_size <= workspace_size


if __name__ == "__main__":
Expand Down

0 comments on commit 30a4503

Please sign in to comment.