Skip to content

Commit a82b86b

Browse files
christiangnrdpxl-th
andcommitted
shfl_down intrinsics
Co-Authored-By: Anton Smirnov <tonysmn97@gmail.com>
1 parent 42acc6a commit a82b86b

2 files changed

Lines changed: 64 additions & 0 deletions

File tree

src/intrinsics.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,32 @@ Declare memory that is local to a workgroup.
188188
"""
189189
localmemory(::Type{T}, dims) where {T} = localmemory(T, Val(dims))
190190

"""
    shfl_down(val::T, offset::Integer)::T where T

Return the `val` held by the lane whose id is `offset` higher than the id of the
calling lane.

!!! note
    Backend implementations **must** implement:
    ```
    @device_override shfl_down(val::T, offset::Integer)::T where T
    ```
    as well as the on-device functionality backing it.
"""
function shfl_down end
204+
"""
    shfl_down_types(::Backend)::Vector{DataType}

Return a vector of the `DataType`s for which `shfl_down` is supported on the
given backend. The generic fallback returns an empty vector, meaning that no
type is supported.

!!! note
    Backend implementations **must** implement this function
    only if they support `shfl_down` for any types.
"""
shfl_down_types(::Backend) = DataType[]
215+
216+
191217
"""
192218
barrier()
193219

test/intrinsics.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,32 @@ function test_subgroup_kernel(results)
4141
return
4242
end
4343

# Do NOT use this kernel as an example for your code.
# It was written assuming a single workgroup of exactly 32 work-items and is
# only valid for that launch configuration.
function shfl_down_test_kernel(a, b)
    # NOTE(review): the original author flagged this line as "not valid";
    # presumably because the workgroup-local id is used as if it were the lane
    # id, which only holds under the launch assumption above — confirm.
    idx = KI.get_local_id().x

    # Stage the inputs through workgroup-local memory before reducing.
    temp = KI.localmemory(eltype(b), 32)
    temp[idx] = a[idx]

    KI.barrier()

    # Every lane has to take part in the shuffles: each step reads the partial
    # sum a higher lane produced in the previous step, so guarding the shuffles
    # with `idx == 1` would leave lane 1 reading values that were never
    # computed.
    value = temp[idx]

    # Tree reduction: halve the distance each step (16, 8, 4, 2, 1).
    value = value + KI.shfl_down(value, 16)
    value = value + KI.shfl_down(value, 8)
    value = value + KI.shfl_down(value, 4)
    value = value + KI.shfl_down(value, 2)
    value = value + KI.shfl_down(value, 1)

    # Only the first lane ends up with the sum of all 32 inputs; it alone
    # writes the result.
    if idx == 1
        b[idx] = value
    end
    return
end
69+
4470
function intrinsics_testsuite(backend, AT)
4571
@testset "KernelIntrinsics Tests" begin
4672
@testset "Launch parameters" begin
@@ -174,6 +200,18 @@ function intrinsics_testsuite(backend, AT)
174200
@test sg_data.sub_group_local_id == expected_sg_local
175201
end
176202
end
@testset "shfl_down(::$T)" for T in KI.shfl_down_types(backend())
    # Host input: 32 values of type T drawn at random from 1:4.
    a = zeros(T, 32)
    rand!(a, (1:4))

    dev_a = AT(a)
    dev_b = AT(zeros(T, 32))

    # The kernel is only valid for one workgroup of 32 work-items.
    KI.@kernel backend() workgroupsize=32 shfl_down_test_kernel(dev_a, dev_b)

    # The kernel stores the reduced value in the first element of `b`.
    b = Array(dev_b)
    @test sum(a) ≈ b[1]
end
177215
end
178216
return nothing
179217
end

0 commit comments

Comments
 (0)