Remove current usage of #include <regex>
Lunderberg committed Feb 8, 2024
1 parent 75a546d commit ed1ac32
Showing 6 changed files with 148 additions and 47 deletions.
9 changes: 2 additions & 7 deletions src/ir/transform.cc
@@ -35,6 +35,7 @@
#include <unordered_set>

#include "../runtime/object_internal.h"
#include "../support/regex.h"

namespace tvm {
namespace transform {
@@ -538,17 +539,11 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex,
.str();

auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule {
-    const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
-    CHECK(regex_match_func)
-        << "RuntimeError: "
-        << "The PackedFunc 'tvm.support.regex_match' has not been registered. "
-        << "This can occur if the TVM Python library has not yet been imported.";

IRModule subset;

for (const auto& [gvar, func] : mod->functions) {
std::string name = gvar->name_hint;
-      if ((*regex_match_func)(func_name_regex, name)) {
+      if (tvm::support::regex_match(name, func_name_regex)) {
subset->Add(gvar, func);
}
}
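The helper call replaces the per-pass registry lookup deleted above; the lookup and its error message now live in one place (src/support/regex.cc, added below). A minimal sketch of the same filtering in isolation, assuming the TVM source tree and that the embedding process has already imported the TVM Python library (the helper forwards to the `tvm.support.regex_match` PackedFunc):

#include <string>
#include <vector>

#include "../support/regex.h"  // same include the file adds above

// Hypothetical standalone illustration: keep only the names matching a
// caller-supplied pattern, as ApplyPassToFunction does for an IRModule.
std::vector<std::string> SelectMatchingNames(const std::vector<std::string>& names,
                                             const std::string& func_name_regex) {
  std::vector<std::string> selected;
  for (const auto& name : names) {
    // Argument order matches the helper: string first, pattern second.
    if (tvm::support::regex_match(name, func_name_regex)) {
      selected.push_back(name);
    }
  }
  return selected;
}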
1 change: 0 additions & 1 deletion src/relay/backend/contrib/dnnl/codegen.cc
@@ -31,7 +31,6 @@

#include <fstream>
#include <numeric>
-#include <regex>
#include <sstream>

#include "../../utils.h"
18 changes: 9 additions & 9 deletions src/relay/backend/contrib/dnnl/query_layout.cc
@@ -31,10 +31,10 @@

#include <fstream>
#include <numeric>
-#include <regex>
#include <sstream>

#include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
#include "../../../../support/regex.h"
#include "../../utils.h"
#include "dnnl.hpp"
namespace tvm {
@@ -173,12 +173,12 @@ dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
}

void check_shapes(const std::vector<std::string> shapes) {
-  std::regex valid_pat("(\\d*)(,(\\d*))*");
-  bool checked = std::regex_match(shapes[0], valid_pat);
+  std::string valid_pat("(\\d*)(,(\\d*))*");
+  bool checked = tvm::support::regex_match(shapes[0], valid_pat);
for (size_t i = 1; i < shapes.size() - 1; i++) {
-    checked &= std::regex_match(shapes[i], valid_pat);
+    checked &= tvm::support::regex_match(shapes[i], valid_pat);
}
-  checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
+  checked &= tvm::support::regex_match(shapes[shapes.size() - 1], "\\d*");
if (!checked) {
LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
}
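For reference, `(\\d*)(,(\\d*))*` accepts comma-separated digit runs, while the last entry (the group count `G`) must match the bare `\\d*` pattern on its own. A hypothetical invocation that would pass the check above:

// Illustrative values only, in the order {weight_shape, out_shape, paddings,
// strides, dilates, G}; every entry but the last may contain commas.
check_shapes({"64,3,7,7", "1,64,112,112", "3,3", "2,2", "1,1", "1"});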
@@ -194,8 +194,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
std::string weight_shape, std::string out_shape,
std::string paddings, std::string strides,
std::string dilates, std::string G, std::string dtype) {
-  check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
-  check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
+  check_layout(tvm::support::regex_match(data_layout, "NC(D?)(H?)W"), true);
+  check_layout(tvm::support::regex_match(kernel_layout, "(G?)OI(D?)(H?)W"), true);
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
@@ -278,8 +278,8 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
std::string paddings, std::string output_paddings,
std::string strides, std::string dilates,
std::string G, std::string dtype) {
-  check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
-  check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
+  check_layout(tvm::support::regex_match(data_layout, "NC(D?)(H?)W"), true);
+  check_layout(tvm::support::regex_match(kernel_layout, "(G?)((IO)|(OI))(D?)(H?)W"), true);
check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
62 changes: 32 additions & 30 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -30,6 +30,7 @@
#include <string>
#include <vector>

#include "../../../support/regex.h"
#include "../json/json_node.h"
#include "../json/json_runtime.h"

@@ -194,53 +195,54 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

// Define RegExp.
-    std::regex bias_add_pat(".*_bias.*");
-    std::regex relu_pat(".*_relu.*");
-    std::regex tanh_pat(".*_tanh.*");
-    std::regex sigmoid_pat(".*_sigmoid.*");
-    std::regex clip_pat(".*_clip.*");
-    std::regex gelu_pat(".*_gelu.*");
-    std::regex swish_pat(".*_swish.*");
-    std::regex sum_pat(".*_sum.*");
-    std::regex mish_pat(".*_mish.*");
+    std::string bias_add_pat(".*_bias.*");
+    std::string relu_pat(".*_relu.*");
+    std::string tanh_pat(".*_tanh.*");
+    std::string sigmoid_pat(".*_sigmoid.*");
+    std::string clip_pat(".*_clip.*");
+    std::string gelu_pat(".*_gelu.*");
+    std::string swish_pat(".*_swish.*");
+    std::string sum_pat(".*_sum.*");
+    std::string mish_pat(".*_mish.*");

// parsing of name to extract attributes
auto op_name = nodes_[nid].GetOpName();

// Parsing post-ops.
dnnl::post_ops ops;
-    if (std::regex_match(op_name, sum_pat)) {
+    if (tvm::support::regex_match(op_name, sum_pat)) {
ops.append_sum(1.f);
}
-    if (std::regex_match(op_name, relu_pat)) {
+    if (tvm::support::regex_match(op_name, relu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
}
-    if (std::regex_match(op_name, tanh_pat)) {
+    if (tvm::support::regex_match(op_name, tanh_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
}
-    if (std::regex_match(op_name, clip_pat)) {
+    if (tvm::support::regex_match(op_name, clip_pat)) {
float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
}
-    if (std::regex_match(op_name, sigmoid_pat)) {
+    if (tvm::support::regex_match(op_name, sigmoid_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
-    if (std::regex_match(op_name, swish_pat)) {
+    if (tvm::support::regex_match(op_name, swish_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
}
-    if (std::regex_match(op_name, gelu_pat)) {
+    if (tvm::support::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
-    if (std::regex_match(op_name, mish_pat)) {
+    if (tvm::support::regex_match(op_name, mish_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
}
if (ops.len() != 0) {
attr.set_post_ops(ops);
}

// Parsing bias_add.
-    *bias_tr = std::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
+    *bias_tr =
+        tvm::support::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};

return attr;
}
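Each `if` above tests the composite op name independently, so one fused name can enable several post-ops at once. A sketch with a hypothetical fused name (not taken from this commit), assuming the same helper:

#include <string>

#include "../../../support/regex.h"  // same include the file adds above

void SketchPostOpMatches() {
  // Hypothetical composite op name produced by DNNL pattern-based fusion.
  const std::string op_name = "dnnl.conv2d_bias_sum_relu";

  bool has_sum = tvm::support::regex_match(op_name, ".*_sum.*");    // true -> append_sum(1.f)
  bool has_relu = tvm::support::regex_match(op_name, ".*_relu.*");  // true -> eltwise_relu post-op
  bool has_bias = tvm::support::regex_match(op_name, ".*_bias.*");  // true -> bias from input 2
  (void)has_sum; (void)has_relu; (void)has_bias;
}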
@@ -253,31 +255,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
tensor_registry_ = TensorRegistry(engine_, io_eid_set);

-    std::regex conv_pat(".*conv[1-3]d.*");
-    std::regex deconv_pat(".*deconv[1-3]d.*");
-    std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
-    std::regex dense_pat(".*dense.*");
-    std::regex max_pool_pat(".*max_pool[1-3]d");
-    std::regex avg_pool_pat(".*avg_pool[1-3]d");
+    std::string conv_pat(".*conv[1-3]d.*");
+    std::string deconv_pat(".*deconv[1-3]d.*");
+    std::string conv_transpose_pat(".*conv[1-3]d_transpose.*");
+    std::string dense_pat(".*dense.*");
+    std::string max_pool_pat(".*max_pool[1-3]d");
+    std::string avg_pool_pat(".*avg_pool[1-3]d");

// Build subgraph engine.
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
ICHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
-        if (std::regex_match(op_name, deconv_pat) ||
-            std::regex_match(op_name, conv_transpose_pat)) {
+        if (tvm::support::regex_match(op_name, deconv_pat) ||
+            tvm::support::regex_match(op_name, conv_transpose_pat)) {
Deconvolution(nid);
-        } else if (std::regex_match(op_name, conv_pat)) {
+        } else if (tvm::support::regex_match(op_name, conv_pat)) {
Convolution(nid);
-        } else if (std::regex_match(op_name, dense_pat)) {
+        } else if (tvm::support::regex_match(op_name, dense_pat)) {
Dense(nid);
} else if ("nn.batch_norm" == op_name) {
BatchNorm(nid);
-        } else if (std::regex_match(op_name, max_pool_pat)) {
+        } else if (tvm::support::regex_match(op_name, max_pool_pat)) {
Pooling(nid, dnnl::algorithm::pooling_max);
-        } else if (std::regex_match(op_name, avg_pool_pat)) {
+        } else if (tvm::support::regex_match(op_name, avg_pool_pat)) {
Pooling(nid, dnnl::algorithm::pooling_avg);
} else if (elt_name2algo.count(op_name)) {
Eltwise(nid);
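Note the dispatch order above: the deconvolution and transpose patterns are tested before `.*conv[1-3]d.*`, which matters because a transposed-convolution name also matches the plain pattern. A small sketch with a hypothetical op name:

#include <string>

#include "../../../support/regex.h"

void SketchDispatchOrder() {
  const std::string name = "dnnl.conv2d_transpose_bias";  // hypothetical op name
  // Both patterns match, so the transpose branch must be checked first
  // for this name to reach Deconvolution(nid) rather than Convolution(nid).
  bool plain = tvm::support::regex_match(name, ".*conv[1-3]d.*");                // true
  bool transposed = tvm::support::regex_match(name, ".*conv[1-3]d_transpose.*"); // true
  (void)plain; (void)transposed;
}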
41 changes: 41 additions & 0 deletions src/support/regex.cc
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/support/regex.cc
* \brief Exposes calls to python's `re` library.
*/

#include "./regex.h"

#include <tvm/runtime/registry.h>

namespace tvm {
namespace support {

bool regex_match(const std::string& match_against, const std::string& regex_pattern) {
  const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
  CHECK(regex_match_func) << "RuntimeError: "
                          << "The PackedFunc 'tvm.support.regex_match' has not been registered. "
                          << "This can occur if the TVM Python library has not yet been imported.";
  return (*regex_match_func)(regex_pattern, match_against);
}

} // namespace support
} // namespace tvm
64 changes: 64 additions & 0 deletions src/support/regex.h
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file regex.h
* \brief Exposes calls to python's `re` library.
*/
#ifndef TVM_SUPPORT_REGEX_H_
#define TVM_SUPPORT_REGEX_H_

#include <string>

namespace tvm {
namespace support {

/*! \brief Check if a pattern matches a regular expression
 *
 * This function should be used instead of `std::regex` within C++
 * call sites, to avoid ABI incompatibilities with pytorch.
 *
 * Currently, the pytorch wheels available through pip install use
 * the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
 * use the pre-C++11 ABI, this would cause breakages with
 * dynamically-linked LLVM environments.
 *
 * Use of the `<regex>` header in TVM should be avoided, as its
 * implementation is not supported by gcc's dual ABI. This ABI
 * incompatibility results in runtime errors either when `std::regex`
 * is called from TVM, or when `std::regex` is called from pytorch,
 * depending on which library was loaded first. This restriction can
 * be removed when a version of pytorch compiled using
 * `-DUSE_CXX11_ABI=1` is available from PyPI.
 *
 * [0] https://github.com/pytorch/pytorch/issues/51039
 *
 * \param match_against The string against which to match the regular expression
 *
 * \param regex_pattern The regular expression
 *
 * \returns match_result True if `match_against` matches the pattern
 *    defined by `regex_pattern`, and False otherwise.
 */
bool regex_match(const std::string& match_against, const std::string& regex_pattern);

} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_REGEX_H_
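A minimal usage sketch, assuming the hosting process has already imported the TVM Python library (otherwise the `CHECK` in regex.cc fires). Pattern syntax follows Python's `re` module, since the call is forwarded to the registered PackedFunc:

#include <string>

#include "../../src/support/regex.h"  // hypothetical path; adjust to your build setup

// Hypothetical predicate in the style of the DNNL runtime checks above.
bool IsConvLike(const std::string& op_name) {
  return tvm::support::regex_match(op_name, ".*conv[1-3]d.*");
}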
