Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llvmcall #5046

Merged
merged 4 commits into from
Aug 12, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ export
JULIA_HOME, nothing, Main,
# intrinsics module
Intrinsics
#ccall, cglobal, abs_float, add_float, add_int, and_int, ashr_int,
#ccall, cglobal, llvmcall, abs_float, add_float, add_int, and_int, ashr_int,
#box, bswap_int, checked_fptosi, checked_fptoui, checked_sadd,
#checked_smul, checked_ssub, checked_uadd, checked_umul, checked_usub,
#checked_trunc_sint, checked_trunc_uint,
Expand Down
4 changes: 4 additions & 0 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ t_func[nan_dom_err] = (2, 2, (a, b)->a)
t_func[eval(Core.Intrinsics,:ccall)] =
(3, Inf, (fptr, rt, at, a...)->(is(rt,Type{Void}) ? Nothing :
isType(rt) ? rt.parameters[1] : Any))
t_func[eval(Core.Intrinsics,:llvmcall)] =
(3, Inf, (fptr, rt, at, a...)->(is(rt,Type{Void}) ? Nothing :
isType(rt) ? rt.parameters[1] :
isa(rt,Tuple) ? map(x->x.parameters[1],rt) : Any))
t_func[eval(Core.Intrinsics,:cglobal)] =
(1, 2, (fptr, t...)->(isempty(t) ? Ptr{Void} :
isType(t[1]) ? Ptr{t[1].parameters[1]} : Ptr))
Expand Down
206 changes: 206 additions & 0 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,212 @@ static Value *emit_cglobal(jl_value_t **args, size_t nargs, jl_codectx_t *ctx)
return mark_julia_type(res, rt);
}

// llvmcall(ir, (rettypes...), (argtypes...), args...)
static Value *emit_llvmcall(jl_value_t **args, size_t nargs, jl_codectx_t *ctx)
{

JL_NARGSV(llvmcall, 3)
jl_value_t *rt = NULL, *at = NULL, *ir = NULL;
JL_GC_PUSH3(&ir, &rt, &at);
{
JL_TRY {
at = jl_interpret_toplevel_expr_in(ctx->module, args[3],
&jl_tupleref(ctx->sp,0),
jl_tuple_len(ctx->sp)/2);
}
JL_CATCH {
jl_rethrow_with_add("error interpreting llvmcall return type");
}
}
{
JL_TRY {
rt = jl_interpret_toplevel_expr_in(ctx->module, args[2],
&jl_tupleref(ctx->sp,0),
jl_tuple_len(ctx->sp)/2);
}
JL_CATCH {
jl_rethrow_with_add("error interpreting llvmcall argument tuple");
}
}
{
JL_TRY {
ir = jl_interpret_toplevel_expr_in(ctx->module, args[1],
&jl_tupleref(ctx->sp,0),
jl_tuple_len(ctx->sp)/2);
}
JL_CATCH {
jl_rethrow_with_add("error interpreting IR argument");
}
}
int i = 1;
if (ir == NULL) {
jl_error("Cannot statically evaluate first argument to llvmcall");
}
bool isString = jl_is_byte_string(ir);
bool isPtr = jl_is_cpointer(ir);
if (!isString && !isPtr)
{
jl_error("First argument to llvmcall must be a string or pointer to an LLVM Function");
}

JL_TYPECHK(llvmcall, type, rt);
JL_TYPECHK(llvmcall, tuple, at);
JL_TYPECHK(llvmcall, type, at);

std::stringstream ir_stream;

jl_tuple_t *stt = jl_alloc_tuple(nargs - 3);

for (size_t i = 0; i < nargs-3; ++i)
{
jl_tupleset(stt,i,expr_type(args[4+i],ctx));
}

// Generate arguments
std::string arguments;
llvm::raw_string_ostream argstream(arguments);
jl_tuple_t *tt = (jl_tuple_t*)at;
jl_value_t *rtt = rt;

size_t nargt = jl_tuple_len(tt);
Value *argvals[nargt];
std::vector<llvm::Type*> argtypes;
/*
* Semantics for arguments are as follows:
* If the argument type is immutable (including bitstype), we pass the loaded llvm value
* type. Otherwise we pass a pointer to a jl_value_t.
*/
for (size_t i = 0; i < nargt; ++i)
{
jl_value_t *tti = jl_tupleref(tt,i);
Type *t = julia_type_to_llvm(tti);
argtypes.push_back(t);
if (4+i > nargs)
{
jl_error("Missing arguments to llvmcall!");
}
jl_value_t *argi = args[4+i];
Value *arg;
bool needroot = false;
if (t == jl_pvalue_llvmt || !jl_isbits(tti)) {
arg = emit_expr(argi, ctx, true);
if (t == jl_pvalue_llvmt && arg->getType() != jl_pvalue_llvmt) {
arg = boxed(arg, ctx);
needroot = true;
}
}
else {
arg = emit_unboxed(argi, ctx);
if (jl_is_bitstype(expr_type(argi, ctx))) {
arg = emit_unbox(t, arg, tti);
}
}

#ifdef JL_GC_MARKSWEEP
// make sure args are rooted
if (t == jl_pvalue_llvmt && (needroot || might_need_root(argi))) {
make_gcroot(arg, ctx);
}
#endif
bool mightNeedTempSpace = false;
argvals[i] = julia_to_native(t,tti,arg,argi,false,i,ctx,&mightNeedTempSpace,&mightNeedTempSpace);
}

Function *f;
Type *rettype = julia_type_to_llvm(rtt);
if (isString) {
// Make sure to find a unique name
std::string ir_name;
while(true) {
std::stringstream name;
name << (ctx->f->getName().str()) << i++;
ir_name = name.str();
if(jl_Module->getFunction(ir_name) == NULL)
break;
}

bool first = true;
for (std::vector<Type *>::iterator it = argtypes.begin(); it != argtypes.end(); ++it) {
if(!first)
argstream << ",";
else
first = false;
(*it)->print(argstream);
argstream << " ";
}

std::string rstring;
llvm::raw_string_ostream rtypename(rstring);
rettype->print(rtypename);

ir_stream << "; Number of arguments: " << nargt << "\n"
<< "define "<<rtypename.str()<<" @\"" << ir_name << "\"("<<argstream.str()<<") {\n"
<< jl_string_data(ir) << "\n}";
SMDiagnostic Err = SMDiagnostic();
std::string ir_string = ir_stream.str();
Module *m = ParseAssemblyString(ir_string.data(),jl_Module,Err,jl_LLVMContext);
if (m == NULL) {
std::string message = "Failed to parse LLVM Assembly: \n";
llvm::raw_string_ostream stream(message);
Err.print("julia",stream,true);
jl_error(stream.str().c_str());
}
f = m->getFunction(ir_name);
} else {
assert(isPtr);
// Create Function skeleton
f = (llvm::Function*)jl_unbox_voidpointer(ir);
assert(f->getReturnType() == rettype);
int i = 0;
for (std::vector<Type *>::iterator it = argtypes.begin();
it != argtypes.end(); ++it, ++i)
assert(*it == f->getFunctionType()->getParamType(i));

#ifdef USE_MCJIT
if (f->getParent() != jl_Module)
{
FunctionMover mover(jl_Module,f->getParent());
f = (llvm::Function*)MapValue(f,mover.VMap,RF_None,NULL,&mover);
}
#endif

//f->dump();
#ifndef LLVM35
if (verifyFunction(*f,PrintMessageAction)) {
#else
llvm::raw_fd_ostream out(1,false);
if (verifyFunction(*f,&out))
{
#endif
f->dump();
jl_error("Malformed LLVM Function");
}
}

/*
* It might be tempting to just try to set the Always inline attribute on the function
* and hope for the best. However, this doesn't work since that would require an inlining
* pass (which is a Call Graph pass and cannot be managed by a FunctionPassManager). Instead
* We are sneaky and call the inliner directly. This however doesn't work until we've actually
* generated the entire function, so we need to store it in the context until the end of the
* function. This also has the benefit of looking exactly like we cut/pasted it in in `code_llvm`.
*/
f->setLinkage(GlobalValue::LinkOnceODRLinkage);

// the actual call
CallInst *inst = builder.CreateCall(prepare_call(f),ArrayRef<Value*>(&argvals[0],nargt));
ctx->to_inline.push_back(inst);

JL_GC_POP();

if(inst->getType() != rettype)
{
jl_error("Return type of llvmcall'ed function does not match declared return type");
}

return mark_julia_type(inst,rtt);
}

// --- code generator for ccall itself ---

// ccall(pointer, rettype, (argtypes...), args...)
Expand Down
1 change: 1 addition & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "llvm/Target/TargetMachine.h"
#else
#include "llvm/Analysis/Verifier.h"
#include "llvm/Assembly/Parser.h"
#endif
#include "llvm/DebugInfo/DIContext.h"
#if defined(LLVM_VERSION_MAJOR) && LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 4
Expand Down
4 changes: 3 additions & 1 deletion src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace JL_I {
// pointer access
pointerref, pointerset, pointertoref,
// c interface
ccall, cglobal, jl_alloca
ccall, cglobal, jl_alloca, llvmcall
};
};

Expand Down Expand Up @@ -820,6 +820,7 @@ static Value *emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
switch (f) {
case ccall: return emit_ccall(args, nargs, ctx);
case cglobal: return emit_cglobal(args, nargs, ctx);
case llvmcall: return emit_llvmcall(args, nargs, ctx);

HANDLE(box,2) return generic_box(args[1], args[2], ctx);
HANDLE(unbox,2) return generic_unbox(args[1], args[2], ctx);
Expand Down Expand Up @@ -1438,4 +1439,5 @@ extern "C" void jl_init_intrinsic_functions(void)
ADD_I(nan_dom_err);
ADD_I(ccall); ADD_I(cglobal);
ADD_I(jl_alloca);
ADD_I(llvmcall);
}
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ TESTS = all core keywordargs numbers strings unicode collections hashing \
git pkg resolve suitesparse complex version pollfd mpfr broadcast \
socket floatapprox priorityqueue readdlm regex float16 combinatorics \
sysinfo rounding ranges mod2pi euler show lineedit \
replcompletions backtrace repl test goto
replcompletions backtrace repl test goto llvmcall

default: all

Expand Down
33 changes: 33 additions & 0 deletions test/llvmcall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using Base.llvmcall

function add1234(x::(Int32,Int32,Int32,Int32))
llvmcall("""%3 = add <4 x i32> %1, %0
ret <4 x i32> %3""",(Int32,Int32,Int32,Int32),
((Int32,Int32,Int32,Int32),(Int32,Int32,Int32,Int32)),
(int32(1),int32(2),int32(3),int32(4)),
x)
end

function add1234(x::NTuple{4,Int64})
llvmcall("""%3 = add <4 x i64> %1, %0
ret <4 x i64> %3""",NTuple{4,Int64},
(NTuple{4,Int64},NTuple{4,Int64}),
(int64(1),int64(2),int64(3),int64(4)),
x)
end

@test add1234(map(int32,(2,3,4,5))) === map(int32,(3,5,7,9))
@test add1234(map(int64,(2,3,4,5))) === map(int64,(3,5,7,9))

# Test whether llvmcall escapes the function name correctly
baremodule PlusTest
using Base.llvmcall
using Base.Test
using Base

function +(x::Int32, y::Int32)
llvmcall("""%3 = add i32 %1, %0
ret i32 %3""", Int32, (Int32, Int32), x, y)
end
@test int32(1)+int32(2)==int32(3)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ testnames = [
"resolve", "pollfd", "mpfr", "broadcast", "complex", "socket",
"floatapprox", "readdlm", "regex", "float16", "combinatorics",
"sysinfo", "rounding", "ranges", "mod2pi", "euler", "show",
"lineedit", "replcompletions", "repl", "test", "examples", "goto"
"lineedit", "replcompletions", "repl", "test", "examples", "goto", "llvmcall"
]
@unix_only push!(testnames, "unicode")

Expand Down