-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathabstract_interpretation.jl
88 lines (76 loc) · 3.08 KB
/
abstract_interpretation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# AbstractInterpretation -- `TapirInterpreter` below is a Julia `AbstractInterpreter`. We
# use it in conjunction with the contexts above to decide what should and should not be
# inlined. Enzyme and Diffractor employ similar strategies.
# The most important part of this code is `inlining_policy`; the rest is copy-pasted
# boilerplate, largely taken from https://github.com/JuliaLang/julia/blob/2fe4190b3d26b4eee52b2b1b1054ddd6e38a941e/test/compiler/newinterp.jl#L11
# Code cache owned by a `TapirInterpreter`: an identity-keyed map from `MethodInstance` to
# `CodeInstance`. It stores no world-age information of its own.
struct TICache
    dict::IdDict{Core.MethodInstance, Core.CodeInstance}
end

# Build a cache with no entries.
function TICache()
    empty_dict = IdDict{Core.MethodInstance, Core.CodeInstance}()
    return TICache(empty_dict)
end
# Custom abstract interpreter parametrised by the AD context type `C`. `C` is the context
# that `inlining_policy` hands to `is_primitive` when deciding whether a call may be inlined.
struct TapirInterpreter{C} <: CC.AbstractInterpreter
    meta                                  # additional information (untyped)
    world::UInt                           # world age used for all cache lookups
    inf_params::CC.InferenceParams
    opt_params::CC.OptimizationParams
    inf_cache::Vector{CC.InferenceResult} # returned by `CC.get_inference_cache`
    code_cache::TICache                   # MethodInstance => CodeInstance store
    oc_cache::Dict{Any, Any}              # created empty by the constructor
    function TapirInterpreter(
        ::Type{C};
        meta=nothing,
        world::UInt=Base.get_world_counter(),
        inf_params::CC.InferenceParams=CC.InferenceParams(),
        opt_params::CC.OptimizationParams=CC.OptimizationParams(),
        inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[],
        code_cache::TICache=TICache(),
    ) where {C}
        oc_cache = Dict{Any, Any}()
        return new{C}(
            meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache
        )
    end
end
# Convenience constructor: an interpreter for the default AD context.
function TapirInterpreter()
    return TapirInterpreter(DefaultCtx)
end
# Short alias used when dispatching the `Core.Compiler` hooks on this interpreter.
const PInterp = TapirInterpreter
# `Core.Compiler` interface hooks: each one hands back a value stored on the interpreter.
function CC.InferenceParams(interp::PInterp)
    return interp.inf_params
end
function CC.OptimizationParams(interp::PInterp)
    return interp.opt_params
end
function CC.get_world_counter(interp::PInterp)
    return interp.world
end
function CC.get_inference_cache(interp::PInterp)
    return interp.inf_cache
end
# Present the interpreter's `TICache` to `Core.Compiler` as a view restricted to the
# interpreter's own world age.
function CC.code_cache(interp::PInterp)
    valid_worlds = CC.WorldRange(interp.world)
    return CC.WorldView(interp.code_cache, valid_worlds)
end
# Look up `mi` in the underlying cache, falling back to `default` when absent. The view's
# world range is not consulted.
function CC.get(wvc::CC.WorldView{TICache}, mi::Core.MethodInstance, default)
    dict = wvc.cache.dict
    return get(dict, mi, default)
end
# Fetch the `CodeInstance` cached for `mi`; throws `KeyError` if there is none.
CC.getindex(wvc::CC.WorldView{TICache}, mi::Core.MethodInstance) = wvc.cache.dict[mi]
# Report whether a `CodeInstance` is cached for `mi` (world range ignored).
function CC.haskey(wvc::CC.WorldView{TICache}, mi::Core.MethodInstance)
    return haskey(wvc.cache.dict, mi)
end
# Store `ci` as the cached `CodeInstance` for `mi`, replacing any previous entry.
# Returns the underlying `IdDict`.
function CC.setindex!(
    wvc::CC.WorldView{TICache}, ci::Core.CodeInstance, mi::Core.MethodInstance
)
    dict = wvc.cache.dict
    dict[mi] = ci
    return dict
end
# `_type` maps an inference lattice element (as found in `argtypes`) to a plain Julia type
# so that argument types can be assembled into a `Tuple` signature. Anything not handled
# below (e.g. a plain type) is passed through unchanged.
_type(x) = x
# A constant: the type of its value.
_type(x::CC.Const) = _typeof(x.val)
# A partially-known struct still records its full type.
_type(x::CC.PartialStruct) = x.typ
# A `Conditional` is a `Bool`-valued result. Its `thentype`/`elsetype` describe how some
# *other* slot is refined on each branch, not the value itself, so the value's type is
# `Bool`. Previously this returned `Union{thentype, elsetype}`, a wrong signature.
_type(::CC.Conditional) = Bool
# Opaque closures carry their closure type in `typ`. Without this method the element
# would reach `Tuple{...}` unchanged and throw a `TypeError`.
_type(x::CC.PartialOpaque) = x.typ
# A partially-known type variable is, as a runtime value, a `TypeVar`.
_type(::CC.PartialTypeVar) = TypeVar
# Inlining hook for `TapirInterpreter`. Calls that `is_primitive` flags for the context `C`
# are never inlined, so they remain visible as calls. Every other call is decided by the
# default `AbstractInterpreter` policy.
function CC.inlining_policy(
    interp::TapirInterpreter{C},
    @nospecialize(src),
    @nospecialize(info::CC.CallInfo),
    stmt_flag::UInt8,
    mi::Core.MethodInstance,
    argtypes::Vector{Any},
) where {C}
    # Turn the lattice elements in `argtypes` into plain types and build the call signature.
    sig = Tuple{map(_type, argtypes)...}
    if is_primitive(C, sig)
        # Returning `nothing` keeps the primitive from being inlined away.
        return nothing
    end
    # AD has no interest in non-primitives: defer to the stock heuristics.
    return @invoke CC.inlining_policy(
        interp::CC.AbstractInterpreter,
        src::Any,
        info::CC.CallInfo,
        stmt_flag::UInt8,
        mi::Core.MethodInstance,
        argtypes::Vector{Any},
    )
end
# Recover the context type parameter `C` of an interpreter.
function context_type(::PInterp{C}) where {C}
    return C
end