Skip to content

Commit 1765294

Browse files
christiangnrdpxl-th
andcommitted
shfl_down intrinsics
Co-Authored-By: Anton Smirnov <tonysmn97@gmail.com>
1 parent a5b090c commit 1765294

2 files changed

Lines changed: 62 additions & 0 deletions

File tree

src/intrinsics.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,32 @@ Declare memory that is local to a workgroup.
188188
"""
189189
localmemory(::Type{T}, dims) where {T} = localmemory(T, Val(dims))
190190

191+
"""
192+
shfl_down(val::T, offset::Integer) where T
193+
194+
Read `val` from a lane with higher id given by `offset`.
195+
196+
!!! note
197+
Backend implementations **must** implement:
198+
```
199+
@device_override shfl_down(val::T, offset::Integer) where T
200+
```
201+
As well as the on-device functionality.
202+
"""
203+
function shfl_down end
204+
205+
"""
206+
shfl_down_types(::Backend)::Vector{DataType}
207+
208+
Returns a vector of `DataType`s supported on `backend`
209+
210+
!!! note
211+
Backend implementations **must** implement this function
212+
only if they support `shfl_down` for any types.
213+
"""
214+
shfl_down_types(::Backend) = DataType[]
215+
216+
191217
"""
192218
barrier()
193219

test/intrinsics.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,25 @@ function test_subgroup_kernel(results)
4545
return
4646
end
4747

48+
function shfl_down_test_kernel(a, b, ::Val{N}) where {N}
49+
idx = KI.get_sub_group_local_id()
50+
51+
val = a[idx]
52+
53+
offset = 0x00000001
54+
while offset < N
55+
val += KI.shfl_down(val, offset)
56+
offset <<= 1
57+
end
58+
59+
KI.sub_group_barrier()
60+
61+
if idx == 1
62+
b[idx] = val
63+
end
64+
return
65+
end
66+
4867
function intrinsics_testsuite(backend, AT)
4968
@testset "KernelIntrinsics Tests" begin
5069
@testset "Launch parameters" begin
@@ -177,6 +196,23 @@ function intrinsics_testsuite(backend, AT)
177196
@test sg_data.sub_group_local_id == expected_sg_local
178197
end
179198
end
199+
@testset "shfl_down" begin
200+
@test !isempty(KI.shfl_down_types(backend()))
201+
types_to_test = setdiff(KI.shfl_down_types(backend()), [Bool])
202+
@testset "$T" for T in types_to_test
203+
N = KI.sub_group_size(backend())
204+
a = zeros(T, N)
205+
rand!(a, (0:1))
206+
207+
dev_a = AT(a)
208+
dev_b = AT(zeros(T, N))
209+
210+
KI.@kernel backend() workgroupsize = N shfl_down_test_kernel(dev_a, dev_b, Val(N))
211+
212+
b = Array(dev_b)
213+
@test sum(a) b[1]
214+
end
215+
end
180216
end
181217
return nothing
182218
end

0 commit comments

Comments
 (0)