Skip to content

Commit 7fbbe55

Browse files
committed
Support emitting the kernel body as a generated function (`@kernel generated=true`)
1 parent 1f84b17 commit 7fbbe55

2 files changed

Lines changed: 15 additions & 4 deletions

File tree

src/KernelAbstractions.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ synchronize(backend)
5050
```
5151
"""
5252
macro kernel(expr)
53-
return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false)
53+
return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false, #=unsafe_indices=# false, #=generated=# false)
5454
end
5555

5656
"""
@@ -69,11 +69,12 @@ This allows for two different configurations:
6969
"""
7070
macro kernel(ex...)
7171
if length(ex) == 1
72-
return __kernel(ex[1], true, false, false)
72+
return __kernel(ex[1], true, false, false, false)
7373
else
7474
generate_cpu = true
7575
unsafe_indices = false
7676
force_inbounds = false
77+
generated = false
7778
for i in 1:(length(ex) - 1)
7879
if ex[i] isa Expr && ex[i].head == :(=) &&
7980
ex[i].args[1] == :cpu && ex[i].args[2] isa Bool
@@ -84,17 +85,21 @@ macro kernel(ex...)
8485
elseif ex[i] isa Expr && ex[i].head == :(=) &&
8586
ex[i].args[1] == :unsafe_indices && ex[i].args[2] isa Bool
8687
unsafe_indices = ex[i].args[2]
88+
elseif ex[i] isa Expr && ex[i].head == :(=) &&
89+
ex[i].args[1] == :generated && ex[i].args[2] isa Bool
90+
generated = ex[i].args[2]
8791
else
8892
error(
8993
"Configuration should be of form:\n" *
9094
"* `cpu=false`\n" *
9195
"* `inbounds=true`\n" *
9296
"* `unsafe_indices=true`\n" *
97+
"* `generated=true`\n" *
9398
"got `", ex[i], "`",
9499
)
95100
end
96101
end
97-
return __kernel(ex[end], generate_cpu, force_inbounds, unsafe_indices)
102+
return __kernel(ex[end], generate_cpu, force_inbounds, unsafe_indices, generated)
98103
end
99104
end
100105

src/macros.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function find_return(stmt)
1010
end
1111

1212
# XXX: Proper errors
13-
function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false)
13+
function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indices = false, generated = false)
1414
def = splitdef(expr)
1515
name = def[:name]
1616
args = def[:args]
@@ -41,12 +41,18 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false, unsafe_indi
4141
def_cpu = deepcopy(def)
4242
def_cpu[:name] = cpu_name
4343
transform_cpu!(def_cpu, constargs, force_inbounds)
44+
if generated
45+
def_cpu[:body] = Expr(:if, Expr(:generated), Expr(:copyast, QuoteNode(def_cpu[:body])), Expr(:meta, :generated_only))
46+
end
4447
cpu_function = combinedef(def_cpu)
4548
end
4649

4750
def_gpu = deepcopy(def)
4851
def_gpu[:name] = gpu_name = Symbol(:gpu_, name)
4952
transform_gpu!(def_gpu, constargs, force_inbounds, unsafe_indices)
53+
if generated
54+
def_gpu[:body] = Expr(:if, Expr(:generated), Expr(:copyast, QuoteNode(def_gpu[:body])), Expr(:meta, :generated_only))
55+
end
5056
gpu_function = combinedef(def_gpu)
5157

5258
# create constructor functions

0 commit comments

Comments
 (0)