-
Notifications
You must be signed in to change notification settings - Fork 47
Expand file tree
/
Copy pathexecution.jl
More file actions
219 lines (178 loc) · 7.37 KB
/
execution.jl
File metadata and controls
219 lines (178 loc) · 7.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
export @opencl, clfunction
## high-level @opencl interface

# keyword arguments interpreted by the `@opencl` macro itself
const MACRO_KWARGS = Symbol[:launch]

# keyword arguments forwarded to `clfunction` when compiling the kernel
const COMPILER_KWARGS = Symbol[:kernel, :name, :always_inline]

# keyword arguments forwarded to the kernel call when launching
const LAUNCH_KWARGS = Symbol[:global_size, :local_size, :queue]
"""
    @opencl [kwargs...] f(args...)

Compile the kernel function `f` for the types of the converted `args` and, unless
`launch=false` is given, launch it with `args`. The expression evaluates to the kernel
object returned by `clfunction`, so it can be called again later.

Keyword arguments precede the call and are either `key=value` pairs or a bare symbol
`key`, which is shorthand for `key=key`. They are grouped as follows:

- macro keywords (`MACRO_KWARGS`): `launch`, which must be a literal `Bool`
- compiler keywords (`COMPILER_KWARGS`), forwarded to `clfunction`
- launch keywords (`LAUNCH_KWARGS`), forwarded to the kernel call

Throws an `ArgumentError` at macro-expansion time when no call is given, when the last
argument is not a function call, or when a keyword argument is malformed or unsupported.
Launch keywords combined with `launch=false` are rejected with an error as well.
"""
macro opencl(ex...)
    # an empty invocation has no kernel call to work with; fail with a clear message
    # instead of a `BoundsError` from indexing the empty argument tuple
    isempty(ex) &&
        throw(ArgumentError("@opencl requires a function call as its last argument"))

    call = ex[end]
    kwargs = map(ex[1:end-1]) do kwarg
        if kwarg isa Symbol
            # a bare symbol `key` is shorthand for `key = key`
            :($kwarg = $kwarg)
        elseif Meta.isexpr(kwarg, :(=))
            kwarg
        else
            throw(ArgumentError("Invalid keyword argument '$kwarg'"))
        end
    end

    # destructure the kernel call
    Meta.isexpr(call, :call) ||
        throw(ArgumentError("last argument to @opencl should be a function call"))
    f = call.args[1]
    args = call.args[2:end]

    code = quote end
    vars, var_exprs = assign_args!(code, args)

    # group keyword argument
    macro_kwargs, compiler_kwargs, call_kwargs, other_kwargs =
        split_kwargs(kwargs, MACRO_KWARGS, COMPILER_KWARGS, LAUNCH_KWARGS)
    if !isempty(other_kwargs)
        key = first(first(other_kwargs).args)
        throw(ArgumentError("Unsupported keyword argument '$key'"))
    end

    # handle keyword arguments that influence the macro's behavior
    launch = true
    for kwarg in macro_kwargs
        key, val = kwarg.args
        if key == :launch
            isa(val, Bool) || throw(ArgumentError("`launch` keyword argument to @opencl should be a constant value"))
            launch = val::Bool
        else
            throw(ArgumentError("Unsupported keyword argument '$key'"))
        end
    end
    if !launch && !isempty(call_kwargs)
        error("@opencl with launch=false does not support launch-time keyword arguments; use them when calling the kernel")
    end

    # FIXME: macro hygiene wrt. escaping kwarg values (this broke with 1.5)
    #        we esc() the whole thing now, necessitating gensyms...
    @gensym f_var kernel_f kernel_args kernel_tt kernel

    # convert the arguments, call the compiler and launch the kernel
    # while keeping the original arguments alive
    push!(code.args,
        quote
            $f_var = $f
            GC.@preserve $(vars...) $f_var begin
                $kernel_f = $kernel_convert($f_var)
                $kernel_args = map($kernel_convert, ($(var_exprs...),))
                $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...}
                $kernel = $clfunction($kernel_f, $kernel_tt; $(compiler_kwargs...))
                if $launch
                    $kernel($(var_exprs...); $(call_kwargs...))
                end
                $kernel
            end
        end)

    # the generated code is wrapped in a `let` block and escaped as a whole
    return esc(quote
        let
            $code
        end
    end)
end
## argument conversion
# Adaptor type used with Adapt.jl to convert arguments before they are passed to a kernel
# (see `kernel_convert`). The `adapt_storage` methods defined for this type push the
# memory objects they encounter onto `indirect_memory`.
struct KernelAdaptor
indirect_memory::Vector{cl.AbstractMemory}
end
# when converting to pointers, we need to keep track of the underlying memory type:
# the memory object is recorded on the adaptor and its pointer is handed to the kernel
function Adapt.adapt_storage(to::KernelAdaptor, buf::cl.AbstractMemory)
    device_ptr = pointer(buf)
    tracked = to.indirect_memory
    push!(tracked, buf)
    return device_ptr
end
# Host arrays are converted to their device-side counterpart; the memory object backing
# the array is recorded on the adaptor first.
function Adapt.adapt_storage(to::KernelAdaptor, arr::CLArray{T, N}) where {T, N}
    backing_memory = arr.data[].mem
    push!(to.indirect_memory, backing_memory)

    DeviceArrayType = CLDeviceArray{T, N, AS.CrossWorkgroup}
    return Base.unsafe_convert(DeviceArrayType, arr)
end
# Base.RefValue isn't GPU compatible, so provide a compatible alternative
# TODO: port improvements from CUDA.jl
"""
    CLRefValue{T} <: Ref{T}

Immutable reference holding a single value `x` of type `T`; dereference it with `r[]`.
"""
struct CLRefValue{T} <: Ref{T}
    x::T
end

# dereferencing yields the wrapped value
function Base.getindex(ref::CLRefValue)
    return ref.x
end
# Convert a `Base.RefValue` by adapting its contents and wrapping the result in a
# `CLRefValue`.
function Adapt.adapt_structure(to::KernelAdaptor, r::Base.RefValue)
    converted = adapt(to, r[])
    return CLRefValue(converted)
end
# broadcast sometimes passes a ref(type), resulting in a GPU-incompatible DataType box.
# avoid that by using a special kind of ref that knows about the boxed type: the type is
# carried as the type parameter `T`, so the struct itself has no fields.
struct CLRefType{T} <: Ref{DataType}
end

# dereferencing yields the type stored in the type parameter
function Base.getindex(::CLRefType{T}) where {T}
    return T
end
# A `Base.RefValue` holding a type is converted to a `CLRefType` parameterized on that type.
function Adapt.adapt_structure(to::KernelAdaptor,
                               r::Base.RefValue{<:Union{DataType,Type}})
    boxed_type = r[]
    return CLRefType{boxed_type}()
end
# case where type is the function being broadcasted: replace the type by a closure that
# calls it, and adapt the broadcast arguments
function Adapt.adapt_structure(to::KernelAdaptor,
                               bc::Broadcast.Broadcasted{Style, <:Any, Type{T}}) where {Style, T}
    constructor = (x...) -> T(x...)
    return Broadcast.Broadcasted{Style}(constructor, adapt(to, bc.args), bc.axes)
end

"""
    kernel_convert(x)
    kernel_convert(x, indirect_memory)

Convert `x`, an argument that is about to be passed to a kernel, to a GPU-friendly format
by adapting it with a `KernelAdaptor`. By default, the conversion does nothing and `x` is
returned as-is.

Memory objects encountered during the conversion are appended to `indirect_memory`, which
defaults to a fresh, empty vector.

Do not add methods to this function, but instead extend the underlying Adapt.jl package
and register methods for the `OpenCL.KernelAdaptor` type.
"""
function kernel_convert(arg,
                        indirect_memory::Vector{cl.AbstractMemory} = cl.AbstractMemory[])
    adaptor = KernelAdaptor(indirect_memory)
    return adapt(adaptor, arg)
end
## abstract kernel functionality
# Supertype of callable kernel objects. `F` is the type of the kernel function and `TT`
# a tuple type whose parameters are the kernel's argument types; together they form the
# signature `Tuple{F, TT.parameters...}` used when the kernel is called.
abstract type AbstractKernel{F, TT} end
# Calling an `AbstractKernel`: every argument is converted with `kernel_convert`, arguments
# that should not be passed are dropped, the kernel state is prepended, and the result is
# forwarded to `clcall` together with `call_kwargs`.
#
# This is a generated function: the statements before the final `quote` only decide which
# argument expressions and types appear in the emitted call; the `quote` block is the code
# that is emitted as the method body.
@inline @generated function (kernel::AbstractKernel{F,TT})(args...;
call_kwargs...) where {F,TT}
sig = Tuple{F, TT.parameters...} # Base.signature_type with a function type
# one expression per signature entry: the kernel function itself, followed by each
# argument converted with `kernel_convert`. `indirect_memory` names the vector that is
# created at the start of the emitted code below.
args = (:(kernel.f), (:(kernel_convert(args[$i], indirect_memory)) for i in 1:length(args))...)
# filter out ghost arguments that shouldn't be passed
predicate = dt -> isghosttype(dt) || Core.Compiler.isconstType(dt)
# `to_pass[i]` is true when the i-th signature entry is kept
to_pass = map(!predicate, sig.parameters)
# keep only the types and expressions of the entries that are passed
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
call_args = Union{Expr,Symbol}[x[1] for x in zip(args, to_pass) if x[2]]
# add the kernel state as the first argument
pushfirst!(call_t, KernelState)
pushfirst!(call_args, :(kernel.state))
# replace non-isbits arguments (they should be unused, or compilation would have failed)
for (i,dt) in enumerate(call_t)
if !isbitstype(dt)
# pass a null pointer typed as `Ptr{Any}` in place of the argument
call_t[i] = Ptr{Any}
call_args[i] = :C_NULL
end
end
# finalize types
call_tt = Base.to_tuple_type(call_t)
quote
# collects the memory objects recorded by `kernel_convert` for this call
indirect_memory = cl.AbstractMemory[]
# add exception info buffer to indirect memory
# XXX: this is too expensive
if kernel.state.exception_info != C_NULL
ctx = cl.context()
# only added when an entry exists for the current context
if haskey(exception_infos, ctx)
push!(indirect_memory, exception_infos[ctx])
end
end
clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, call_kwargs...)
end
end
## host-side kernels
# Kernel object constructed by `clfunction`; calling it goes through the generated call
# method defined for `AbstractKernel`.
struct HostKernel{F,TT} <: AbstractKernel{F,TT}
# the function instance passed to `clfunction`
f::F
# the object returned by compilation; passed to `clcall` when the kernel is called
fun::cl.Kernel
# kernel state, passed as the first argument on every call
state::KernelState
end
## host-side API
# lock held by `clfunction` while compiling and while looking up or creating the kernel
# instance
const clfunction_lock = ReentrantLock()
"""
    clfunction(f, tt=Tuple{}; kwargs...)

Compile the function `f` for the argument types given by the tuple type `tt`, using the
current context and device, and return a callable `HostKernel{F,tt}`. The keyword
arguments are forwarded to `compiler_config`.

Kernel objects are cached: calling `clfunction` again with the same compiled function, `f`
and `tt` returns the instance created earlier. Compilation and the cache lookup happen
while holding `clfunction_lock`.
"""
function clfunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
    ctx = cl.context()
    dev = cl.device()

    Base.@lock clfunction_lock begin
        # compile the function
        compiled = GPUCompiler.cached_compilation(
            compiler_cache(ctx),
            methodinstance(F, tt),
            compiler_config(dev; kwargs...)::OpenCLCompilerConfig,
            compile, link)

        # create a callable object that captures the function instance. we don't need to
        # think about world age here, as GPUCompiler already does and will return a
        # different object
        key = hash(compiled, hash(f, hash(tt)))
        instance = get!(_kernel_instances, key) do
            # first request for this key: create the kernel state object and the kernel
            exception_info = create_exceptions!(ctx, dev)
            HostKernel{F,tt}(f, compiled, KernelState(exception_info))
        end
        return instance::HostKernel{F,tt}
    end
end
# cache of kernel instances, keyed by the hash that `clfunction` computes from the compiled
# function, the function instance and the argument tuple type
const _kernel_instances = Dict{UInt, Any}()