Skip to content

Commit aa89f72

Browse files
committed
Support KA groupreduce API
1 parent 27062bc commit aa89f72

4 files changed

Lines changed: 49 additions & 22 deletions

File tree

src/ROCKernels.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,8 @@ end
166166
# TODO
167167
end
168168

169+
@device_override @inline function KA.__shfl_down(val, offset)
170+
AMDGPU.Device.shfl_down(val, offset)
171+
end
172+
169173
end

src/device/gcn/wavefront.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,20 @@ end
5757
"""
5858
wfred(op::Function, val::T) where T -> T
5959
60-
Performs a wavefront-wide reduction on `val` in each lane, and returns the
61-
result. A limited subset of functions are available to be passed as `op`. When
62-
`op` is one of `(+, max, min, &, |, ⊻)`, `T` may be
63-
`<:Union{Cint, Clong, Cuint, Culong}`. When `op` is one of `(+, max, min)`,
64-
`T` may also be `<:Union{Float32, Float64}`.
60+
Performs a wavefront-wide reduction on `val` in each lane, and returns the result.
61+
A limited subset of functions are available to be passed as `op`.
62+
When `op` is one of `(+, max, min, &, |, ⊻)`, `T` may be `<:Union{Cint, Clong, Cuint, Culong}`.
63+
When `op` is one of `(+, max, min)`, `T` may also be `<:Union{Float32, Float64}`.
6564
"""
6665
wfred
6766

6867
"""
6968
wfscan(op::Function, val::T) where T -> T
7069
7170
Performs a wavefront-wide scan on `val` in each lane, and returns the
72-
result. A limited subset of functions are available to be passed as `op`. When
73-
`op` is one of `(+, max, min, &, |, ⊻)`, `T` may be
74-
`<:Union{Cint, Clong, Cuint, Culong}`. When `op` is one of `(+, max, min)`,
75-
`T` may also be `<:Union{Float32, Float64}`.
71+
result. A limited subset of functions are available to be passed as `op`.
72+
When `op` is one of `(+, max, min, &, |, ⊻)`, `T` may be `<:Union{Cint, Clong, Cuint, Culong}`.
73+
When `op` is one of `(+, max, min)`, `T` may also be `<:Union{Float32, Float64}`.
7674
"""
7775
wfscan
7876

t.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using AMDGPU
2+
using KernelAbstractions
3+
import KernelAbstractions as KA
4+
5+
@kernel function ker!(y, x, neutral, op)
6+
i = @index(Global)
7+
val = i > length(x) ? neutral : x[i]
8+
9+
res = KA.@groupreduce(:warp, op, val, neutral)
10+
if i == 1
11+
y[1] = res
12+
end
13+
end
14+
15+
function main()
16+
n = 256
17+
x = ROCArray(ones(Int64, n))
18+
y = ROCArray(zeros(Int64, 1))
19+
20+
ker!(ROCBackend(), n)(y, x, 0, +; ndrange=length(x))
21+
@show y
22+
return
23+
end
24+
main()

test/runtests.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,24 +110,25 @@ AMDGPU.versioninfo()
110110

111111
@info "Test suite info"
112112
data = String["$np" "$(AMDGPU.device())" join(TARGET_TESTS, ", ");]
113-
PrettyTables.pretty_table(data; header=[
114-
"Workers", "Device", "Tests"], crop=:none)
113+
PrettyTables.pretty_table(data; header=["Workers", "Device", "Tests"], crop=:none)
115114

116115
runtests(AMDGPU; nworkers=np, nworker_threads=1, testitem_timeout=60 * 30) do ti
116+
return ti.name == "kernelabstractions"
117+
117118
for tt in TARGET_TESTS
118119
startswith(ti.name, tt) && return true
119120
end
120121
return false
121122
end
122123

123-
if "core" in TARGET_TESTS && Sys.islinux()
124-
@info "Testing `Hostcalls` on the main thread."
125-
@testset "Hostcalls" begin
126-
include("device/hostcall.jl")
127-
128-
# TODO 1.11 fails
129-
if VERSION < v"1.11-"
130-
include("device/output.jl")
131-
end
132-
end
133-
end
124+
# if "core" in TARGET_TESTS && Sys.islinux()
125+
# @info "Testing `Hostcalls` on the main thread."
126+
# @testset "Hostcalls" begin
127+
# include("device/hostcall.jl")
128+
129+
# # TODO 1.11 fails
130+
# if VERSION < v"1.11-"
131+
# include("device/output.jl")
132+
# end
133+
# end
134+
# end

0 commit comments

Comments
 (0)