File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -188,6 +188,32 @@ Declare memory that is local to a workgroup.
188188"""
function localmemory(::Type{T}, dims) where {T}
    # Wrap `dims` in `Val` and forward to the `localmemory(T, Val(dims))` call form.
    return localmemory(T, Val(dims))
end
190190
"""
    shfl_down(val::T, offset::Integer) where T

Read `val` from the lane whose id is `offset` higher than that of the calling lane.

!!! note
    Backend implementations **must** implement:
    ```
    @device_override shfl_down(val::T, offset::Integer) where T
    ```
    As well as the on-device functionality.
"""
function shfl_down end
204+
"""
    shfl_down_types(backend::Backend)::Vector{DataType}

Return the `DataType`s for which `shfl_down` is supported on `backend`.

The fallback method returns an empty vector, i.e. no type is reported as supported.

!!! note
    Backend implementations **must** implement this function
    only if they support `shfl_down` for any types.
"""
shfl_down_types(::Backend) = DataType[]
215+
216+
191217"""
192218 barrier()
193219
Original file line number Diff line number Diff line change @@ -45,6 +45,25 @@ function test_subgroup_kernel(results)
4545 return
4646end
4747
# Test kernel: sums the values of `a` across the lanes of a sub-group using
# `KI.shfl_down` and lets the lane with local id 1 store its total in `b[1]`.
# `N` is the number of lanes taking part, passed as a `Val` type parameter.
function shfl_down_test_kernel(a, b, ::Val{N}) where {N}
    idx = KI.get_sub_group_local_id()

    # Every lane starts from its own element of `a`.
    val = a[idx]

    # Offsets run 1, 2, 4, ... while below `N`; each step adds the running value
    # held `offset` lanes further up.
    # NOTE(review): lane 1 ends with the full sum presumably only when `N` is a
    # power of two; confirm for the supported backends.
    offset = 0x00000001
    while offset < N
        val += KI.shfl_down(val, offset)
        offset <<= 1
    end

    KI.sub_group_barrier()

    # Only the lane with local id 1 writes a result.
    if idx == 1
        b[idx] = val
    end
    return
end
66+
4867function intrinsics_testsuite (backend, AT)
4968 @testset " KernelIntrinsics Tests" begin
5069 @testset " Launch parameters" begin
@@ -177,6 +196,23 @@ function intrinsics_testsuite(backend, AT)
177196 @test sg_data. sub_group_local_id == expected_sg_local
178197 end
179198 end
# Reduce one sub-group of random 0/1 values on the device with `shfl_down` and compare
# the result against a host-side `sum`, for every type the backend reports.
@testset "shfl_down" begin
    supported = KI.shfl_down_types(backend())
    @test !isempty(supported)
    # `Bool` is left out of the reduction check.
    types_to_test = setdiff(supported, [Bool])
    @testset "$T" for T in types_to_test
        N = KI.sub_group_size(backend())
        a = zeros(T, N)
        rand!(a, (0:1))  # fill with random zeros and ones

        dev_a = AT(a)
        dev_b = AT(zeros(T, N))

        KI.@kernel backend() workgroupsize = N shfl_down_test_kernel(dev_a, dev_b, Val(N))

        # The kernel stores the reduced value in `b[1]`.
        b = Array(dev_b)
        @test sum(a) ≈ b[1]
    end
end
180216 end
181217 return nothing
182218end
You can’t perform that action at this time.
0 commit comments