Skip to content

Commit

Permalink
[Target][rocm] Replace rocm arch parsing from int to string (#15088)
Browse files Browse the repository at this point in the history
* fix rocm arch parse issue,

* recover some code from target_kind.cc

* code reformat

* format

---------

Co-authored-by: leiwang1999 <[email protected]>
  • Loading branch information
LeiWang1999 and LeiWang1999 authored Jun 15, 2023
1 parent 90b5acc commit 02136b3
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,24 @@ static int ExtractIntWithPrefix(const std::string& str, const std::string& prefi
return result;
}

/*!
* \brief Extract a string from the string with the given prefix.
* For example, when `str` is "sm_20" and `prefix` is "sm_".
* This function first checks if `str` starts with `prefix`,
* then return the integer 20 after the `prefix`
* \param str The string to be extracted
* \param prefix The prefix to be checked
* \return A string, the extracted string. "" if the check fails
*/
std::string ExtractStringWithPrefix(const std::string& str, const std::string& prefix) {
if (str.find(prefix) != 0) return "";
std::size_t pos = prefix.length();
while (pos < str.length() && (std::isdigit(str[pos]) || std::isalpha(str[pos]))) {
++pos;
}
return str.substr(prefix.length(), pos - prefix.length());
}

/*!
* \brief Using TVM DeviceAPI to detect the device flag
* \param device The device to be detected
Expand Down Expand Up @@ -206,20 +224,20 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
TargetJSON UpdateROCmAttrs(TargetJSON target) {
CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
// Update -mcpu=gfx
int arch;
std::string arch;
if (target.count("mcpu")) {
String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractIntWithPrefix(mcpu, "gfx");
ICHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
arch = ExtractStringWithPrefix(mcpu, "gfx");
ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
} else {
TVMRetValue val;
if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &val)) {
LOG(WARNING) << "Unable to detect ROCm compute arch, default to \"-mcpu=gfx900\" instead";
arch = 900;
arch = "900";
} else {
arch = val.operator int();
arch = val.operator std::string();
}
target.Set("mcpu", String("gfx") + std::to_string(arch));
target.Set("mcpu", String("gfx") + arch);
}
// Update -mattr before ROCm 3.5:
// Before ROCm 3.5 we needed code object v2, starting
Expand Down

0 comments on commit 02136b3

Please sign in to comment.