Skip to content

Commit

Permalink
Reland "[CUDA][HIP] Fix overloading resolution in global var init" (l…
Browse files Browse the repository at this point in the history
…lvm#65606)

https://reviews.llvm.org/D158247 caused regressions for HIP on Windows
and was reverted.

A reduced test case is:

```
typedef void (__stdcall* funcTy)();
void invoke(funcTy f);

static void __stdcall callee() noexcept {
}

void foo() {
   invoke(callee);
}
```

It is due to clang missing handling host/device attributes for calling
convention at a few places

This patch fixes that.
  • Loading branch information
yxsamliu authored Sep 8, 2023
1 parent 4d2536c commit 9b77638
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 93 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,4 @@ pythonenv*
/clang/utils/analyzer/projects/*/RefScanBuildResults
# automodapi puts generated documentation files here.
/lldb/docs/python_api/
/Debug/
46 changes: 37 additions & 9 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,14 @@ class Sema final {
}
} DelayedDiagnostics;

enum CUDAFunctionTarget {
CFT_Device,
CFT_Global,
CFT_Host,
CFT_HostDevice,
CFT_InvalidTarget
};

/// A RAII object to temporarily push a declaration context.
class ContextRAII {
private:
Expand Down Expand Up @@ -4753,8 +4761,13 @@ class Sema final {
bool isValidPointerAttrType(QualType T, bool RefOkay = false);

bool CheckRegparmAttr(const ParsedAttr &attr, unsigned &value);

/// Check validaty of calling convention attribute \p attr. If \p FD
/// is not null pointer, use \p FD to determine the CUDA/HIP host/device
/// target. Otherwise, it is specified by \p CFT.
bool CheckCallingConvAttr(const ParsedAttr &attr, CallingConv &CC,
const FunctionDecl *FD = nullptr);
const FunctionDecl *FD = nullptr,
CUDAFunctionTarget CFT = CFT_InvalidTarget);
bool CheckAttrTarget(const ParsedAttr &CurrAttr);
bool CheckAttrNoArgs(const ParsedAttr &CurrAttr);
bool checkStringLiteralArgumentAttr(const AttributeCommonInfo &CI,
Expand Down Expand Up @@ -13266,14 +13279,6 @@ class Sema final {
void checkTypeSupport(QualType Ty, SourceLocation Loc,
ValueDecl *D = nullptr);

enum CUDAFunctionTarget {
CFT_Device,
CFT_Global,
CFT_Host,
CFT_HostDevice,
CFT_InvalidTarget
};

/// Determines whether the given function is a CUDA device/host/kernel/etc.
/// function.
///
Expand All @@ -13292,6 +13297,29 @@ class Sema final {
/// Determines whether the given variable is emitted on host or device side.
CUDAVariableTarget IdentifyCUDATarget(const VarDecl *D);

/// Defines kinds of CUDA global host/device context where a function may be
/// called.
enum CUDATargetContextKind {
CTCK_Unknown, /// Unknown context
CTCK_InitGlobalVar, /// Function called during global variable
/// initialization
};

/// Define the current global CUDA host/device context where a function may be
/// called. Only used when a function is called outside of any functions.
struct CUDATargetContext {
CUDAFunctionTarget Target = CFT_HostDevice;
CUDATargetContextKind Kind = CTCK_Unknown;
Decl *D = nullptr;
} CurCUDATargetCtx;

struct CUDATargetContextRAII {
Sema &S;
CUDATargetContext SavedCtx;
CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, Decl *D);
~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
};

/// Gets the CUDA target for the current context.
CUDAFunctionTarget CurrentCUDATarget() {
return IdentifyCUDATarget(dyn_cast<FunctionDecl>(CurContext));
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2571,6 +2571,7 @@ Decl *Parser::ParseDeclarationAfterDeclaratorAndAttributes(
}
}

Sema::CUDATargetContextRAII X(Actions, Sema::CTCK_InitGlobalVar, ThisDecl);
switch (TheInitKind) {
// Parse declarator '=' initializer.
case InitKind::Equal: {
Expand Down
24 changes: 21 additions & 3 deletions clang/lib/Sema/SemaCUDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,37 @@ Sema::IdentifyCUDATarget(const ParsedAttributesView &Attrs) {
}

template <typename A>
static bool hasAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) {
static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
return D->hasAttrs() && llvm::any_of(D->getAttrs(), [&](Attr *Attribute) {
return isa<A>(Attribute) &&
!(IgnoreImplicitAttr && Attribute->isImplicit());
});
}

Sema::CUDATargetContextRAII::CUDATargetContextRAII(Sema &S_,
CUDATargetContextKind K,
Decl *D)
: S(S_) {
SavedCtx = S.CurCUDATargetCtx;
assert(K == CTCK_InitGlobalVar);
auto *VD = dyn_cast_or_null<VarDecl>(D);
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
auto Target = CFT_Host;
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
Target = CFT_Device;
S.CurCUDATargetCtx = {Target, K, VD};
}
}

/// IdentifyCUDATarget - Determine the CUDA compilation target for this function
Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
bool IgnoreImplicitHDAttr) {
// Code that lives outside a function is run on the host.
// Code that lives outside a function gets the target from CurCUDATargetCtx.
if (D == nullptr)
return CFT_Host;
return CurCUDATargetCtx.Target;

if (D->hasAttr<CUDAInvalidTargetAttr>())
return CFT_InvalidTarget;
Expand Down
9 changes: 6 additions & 3 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5132,7 +5132,8 @@ static void handleCallConvAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
// Diagnostic is emitted elsewhere: here we store the (valid) AL
// in the Decl node for syntactic reasoning, e.g., pretty-printing.
CallingConv CC;
if (S.CheckCallingConvAttr(AL, CC, /*FD*/nullptr))
if (S.CheckCallingConvAttr(AL, CC, /*FD*/ nullptr,
S.IdentifyCUDATarget(dyn_cast<FunctionDecl>(D))))
return;

if (!isa<ObjCMethodDecl>(D)) {
Expand Down Expand Up @@ -5317,7 +5318,8 @@ static void handleNoRandomizeLayoutAttr(Sema &S, Decl *D,
}

bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
const FunctionDecl *FD) {
const FunctionDecl *FD,
CUDAFunctionTarget CFT) {
if (Attrs.isInvalid())
return true;

Expand Down Expand Up @@ -5416,7 +5418,8 @@ bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
// on their host/device attributes.
if (LangOpts.CUDA) {
auto *Aux = Context.getAuxTargetInfo();
auto CudaTarget = IdentifyCUDATarget(FD);
assert(FD || CFT != CFT_InvalidTarget);
auto CudaTarget = FD ? IdentifyCUDATarget(FD) : CFT;
bool CheckHost = false, CheckDevice = false;
switch (CudaTarget) {
case CFT_HostDevice:
Expand Down
45 changes: 24 additions & 21 deletions clang/lib/Sema/SemaOverload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6697,17 +6697,19 @@ void Sema::AddOverloadCandidate(
}

// (CUDA B.1): Check for invalid calls between targets.
if (getLangOpts().CUDA)
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
// Skip the check for callers that are implicit members, because in this
// case we may not yet know what the member's target is; the target is
// inferred for the member automatically, based on the bases and fields of
// the class.
if (!Caller->isImplicit() && !IsAllowedCUDACall(Caller, Function)) {
Candidate.Viable = false;
Candidate.FailureKind = ovl_fail_bad_target;
return;
}
if (getLangOpts().CUDA) {
const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
// Skip the check for callers that are implicit members, because in this
// case we may not yet know what the member's target is; the target is
// inferred for the member automatically, based on the bases and fields of
// the class.
if (!(Caller && Caller->isImplicit()) &&
!IsAllowedCUDACall(Caller, Function)) {
Candidate.Viable = false;
Candidate.FailureKind = ovl_fail_bad_target;
return;
}
}

if (Function->getTrailingRequiresClause()) {
ConstraintSatisfaction Satisfaction;
Expand Down Expand Up @@ -7219,12 +7221,11 @@ Sema::AddMethodCandidate(CXXMethodDecl *Method, DeclAccessPair FoundDecl,

// (CUDA B.1): Check for invalid calls between targets.
if (getLangOpts().CUDA)
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
if (!IsAllowedCUDACall(Caller, Method)) {
Candidate.Viable = false;
Candidate.FailureKind = ovl_fail_bad_target;
return;
}
if (!IsAllowedCUDACall(getCurFunctionDecl(/*AllowLambda=*/true), Method)) {
Candidate.Viable = false;
Candidate.FailureKind = ovl_fail_bad_target;
return;
}

if (Method->getTrailingRequiresClause()) {
ConstraintSatisfaction Satisfaction;
Expand Down Expand Up @@ -12495,10 +12496,12 @@ class AddressOfFunctionResolver {
return false;

if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) {
if (S.getLangOpts().CUDA)
if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true))
if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl))
return false;
if (S.getLangOpts().CUDA) {
FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
if (!(Caller && Caller->isImplicit()) &&
!S.IsAllowedCUDACall(Caller, FunDecl))
return false;
}
if (FunDecl->isMultiVersion()) {
const auto *TA = FunDecl->getAttr<TargetAttr>();
if (TA && !TA->isDefaultVersion())
Expand Down
60 changes: 35 additions & 25 deletions clang/lib/Sema/SemaType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,14 @@ enum TypeAttrLocation {
TAL_DeclName
};

static void processTypeAttrs(TypeProcessingState &state, QualType &type,
TypeAttrLocation TAL,
const ParsedAttributesView &attrs);
static void
processTypeAttrs(TypeProcessingState &state, QualType &type,
TypeAttrLocation TAL, const ParsedAttributesView &attrs,
Sema::CUDAFunctionTarget CFT = Sema::CFT_HostDevice);

static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
QualType &type);
QualType &type,
Sema::CUDAFunctionTarget CFT);

static bool handleMSPointerTypeQualifierAttr(TypeProcessingState &state,
ParsedAttr &attr, QualType &type);
Expand Down Expand Up @@ -617,7 +619,8 @@ static void distributeFunctionTypeAttr(TypeProcessingState &state,
/// distributed, false if no location was found.
static bool distributeFunctionTypeAttrToInnermost(
TypeProcessingState &state, ParsedAttr &attr,
ParsedAttributesView &attrList, QualType &declSpecType) {
ParsedAttributesView &attrList, QualType &declSpecType,
Sema::CUDAFunctionTarget CFT) {
Declarator &declarator = state.getDeclarator();

// Put it on the innermost function chunk, if there is one.
Expand All @@ -629,19 +632,20 @@ static bool distributeFunctionTypeAttrToInnermost(
return true;
}

return handleFunctionTypeAttr(state, attr, declSpecType);
return handleFunctionTypeAttr(state, attr, declSpecType, CFT);
}

/// A function type attribute was written in the decl spec. Try to
/// apply it somewhere.
static void distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
ParsedAttr &attr,
QualType &declSpecType) {
static void
distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
ParsedAttr &attr, QualType &declSpecType,
Sema::CUDAFunctionTarget CFT) {
state.saveDeclSpecAttrs();

// Try to distribute to the innermost.
if (distributeFunctionTypeAttrToInnermost(
state, attr, state.getCurrentAttributes(), declSpecType))
state, attr, state.getCurrentAttributes(), declSpecType, CFT))
return;

// If that failed, diagnose the bad attribute when the declarator is
Expand All @@ -653,14 +657,14 @@ static void distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
/// Try to apply it somewhere.
/// `Attrs` is the attribute list containing the declaration (either of the
/// declarator or the declaration).
static void distributeFunctionTypeAttrFromDeclarator(TypeProcessingState &state,
ParsedAttr &attr,
QualType &declSpecType) {
static void distributeFunctionTypeAttrFromDeclarator(
TypeProcessingState &state, ParsedAttr &attr, QualType &declSpecType,
Sema::CUDAFunctionTarget CFT) {
Declarator &declarator = state.getDeclarator();

// Try to distribute to the innermost.
if (distributeFunctionTypeAttrToInnermost(
state, attr, declarator.getAttributes(), declSpecType))
state, attr, declarator.getAttributes(), declSpecType, CFT))
return;

// If that failed, diagnose the bad attribute when the declarator is
Expand All @@ -682,7 +686,8 @@ static void distributeFunctionTypeAttrFromDeclarator(TypeProcessingState &state,
/// `Attrs` is the attribute list containing the declaration (either of the
/// declarator or the declaration).
static void distributeTypeAttrsFromDeclarator(TypeProcessingState &state,
QualType &declSpecType) {
QualType &declSpecType,
Sema::CUDAFunctionTarget CFT) {
// The called functions in this loop actually remove things from the current
// list, so iterating over the existing list isn't possible. Instead, make a
// non-owning copy and iterate over that.
Expand All @@ -699,7 +704,7 @@ static void distributeTypeAttrsFromDeclarator(TypeProcessingState &state,
break;

FUNCTION_TYPE_ATTRS_CASELIST:
distributeFunctionTypeAttrFromDeclarator(state, attr, declSpecType);
distributeFunctionTypeAttrFromDeclarator(state, attr, declSpecType, CFT);
break;

MS_TYPE_ATTRS_CASELIST:
Expand Down Expand Up @@ -3544,7 +3549,8 @@ static QualType GetDeclSpecTypeForDeclarator(TypeProcessingState &state,
// Note: We don't need to distribute declaration attributes (i.e.
// D.getDeclarationAttributes()) because those are always C++11 attributes,
// and those don't get distributed.
distributeTypeAttrsFromDeclarator(state, T);
distributeTypeAttrsFromDeclarator(
state, T, SemaRef.IdentifyCUDATarget(D.getAttributes()));

// Find the deduced type in this type. Look in the trailing return type if we
// have one, otherwise in the DeclSpec type.
Expand Down Expand Up @@ -4055,7 +4061,8 @@ static CallingConv getCCForDeclaratorChunk(
// function type. We'll diagnose the failure to apply them in
// handleFunctionTypeAttr.
CallingConv CC;
if (!S.CheckCallingConvAttr(AL, CC) &&
if (!S.CheckCallingConvAttr(AL, CC, /*FunctionDecl=*/nullptr,
S.IdentifyCUDATarget(D.getAttributes())) &&
(!FTI.isVariadic || supportsVariadicCall(CC))) {
return CC;
}
Expand Down Expand Up @@ -5727,7 +5734,8 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
}

// See if there are any attributes on this declarator chunk.
processTypeAttrs(state, T, TAL_DeclChunk, DeclType.getAttrs());
processTypeAttrs(state, T, TAL_DeclChunk, DeclType.getAttrs(),
S.IdentifyCUDATarget(D.getAttributes()));

if (DeclType.Kind != DeclaratorChunk::Paren) {
if (ExpectNoDerefChunk && !IsNoDerefableChunk(DeclType))
Expand Down Expand Up @@ -7801,7 +7809,8 @@ static bool checkMutualExclusion(TypeProcessingState &state,
/// Process an individual function attribute. Returns true to
/// indicate that the attribute was handled, false if it wasn't.
static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
QualType &type) {
QualType &type,
Sema::CUDAFunctionTarget CFT) {
Sema &S = state.getSema();

FunctionTypeUnwrapper unwrapped(S, type);
Expand Down Expand Up @@ -8032,7 +8041,7 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,

// Otherwise, a calling convention.
CallingConv CC;
if (S.CheckCallingConvAttr(attr, CC))
if (S.CheckCallingConvAttr(attr, CC, /*FunctionDecl=*/nullptr, CFT))
return true;

const FunctionType *fn = unwrapped.get();
Expand Down Expand Up @@ -8584,7 +8593,8 @@ static void HandleLifetimeBoundAttr(TypeProcessingState &State,

static void processTypeAttrs(TypeProcessingState &state, QualType &type,
TypeAttrLocation TAL,
const ParsedAttributesView &attrs) {
const ParsedAttributesView &attrs,
Sema::CUDAFunctionTarget CFT) {

state.setParsedNoDeref(false);
if (attrs.empty())
Expand Down Expand Up @@ -8826,7 +8836,7 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
// appertain to and hence should not use the "distribution" logic below.
if (attr.isStandardAttributeSyntax() ||
attr.isRegularKeywordAttribute()) {
if (!handleFunctionTypeAttr(state, attr, type)) {
if (!handleFunctionTypeAttr(state, attr, type, CFT)) {
diagnoseBadTypeAttribute(state.getSema(), attr, type);
attr.setInvalid();
}
Expand All @@ -8836,10 +8846,10 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
// Never process function type attributes as part of the
// declaration-specifiers.
if (TAL == TAL_DeclSpec)
distributeFunctionTypeAttrFromDeclSpec(state, attr, type);
distributeFunctionTypeAttrFromDeclSpec(state, attr, type, CFT);

// Otherwise, handle the possible delays.
else if (!handleFunctionTypeAttr(state, attr, type))
else if (!handleFunctionTypeAttr(state, attr, type, CFT))
distributeFunctionTypeAttr(state, attr, type);
break;
case ParsedAttr::AT_AcquireHandle: {
Expand Down
Loading

0 comments on commit 9b77638

Please sign in to comment.