@@ -7,6 +7,8 @@ using ..POCL: device, clconvert, clfunction
77import KernelAbstractions as KA
88import KernelAbstractions. KernelIntrinsics as KI
99
10+ import SPIRVIntrinsics
11+
1012import StaticArrays
1113
1214import Adapt
@@ -174,10 +176,36 @@ end
174176function KI. max_work_group_size (:: POCLBackend ):: Int
175177 return Int (device (). max_work_group_size)
176178end
179+ function KI. sub_group_size (:: POCLBackend ):: Int
180+ sg_sizes = cl. device (). sub_group_sizes
181+ if 32 in sg_sizes
182+ return 32
183+ elseif 64 in sg_sizes
184+ return 64
185+ elseif 16 in sg_sizes
186+ return 16
187+ else
188+ return 1
189+ end
190+ end
177191function KI. multiprocessor_count (:: POCLBackend ):: Int
178192 return Int (device (). max_compute_units)
179193end
180194
195+ function KI. shfl_down_types (:: POCLBackend )
196+ res = copy (SPIRVIntrinsics. gentypes)
197+
198+ backend_extensions = cl. device (). extensions
199+ if " cl_khr_fp64" ∉ backend_extensions
200+ res = setdiff (res, [Float64])
201+ end
202+ if " cl_khr_fp16" ∉ backend_extensions
203+ res = setdiff (res, [Float16])
204+ end
205+
206+ return res
207+ end
208+
181209# # Indexing Functions
182210
183211@device_override @inline function KI. get_local_id ()
204232 return (; x = Int (get_global_size (1 )), y = Int (get_global_size (2 )), z = Int (get_global_size (3 )))
205233end
206234
235+ @device_override KI. get_sub_group_size () = get_sub_group_size ()
236+
237+ @device_override KI. get_max_sub_group_size () = get_max_sub_group_size ()
238+
239+ @device_override KI. get_num_sub_groups () = get_num_sub_groups ()
240+
241+ @device_override KI. get_sub_group_id () = get_sub_group_id ()
242+
243+ @device_override KI. get_sub_group_local_id () = get_sub_group_local_id ()
244+
207245@device_override @inline function KA. __validindex (ctx)
208246 if KA. __dynamic_checkbounds (ctx)
209247 I = @inbounds KA. expand (KA. __iterspace (ctx), get_group_id (1 ), get_local_id (1 ))
232270 work_group_barrier (POCL. LOCAL_MEM_FENCE | POCL. GLOBAL_MEM_FENCE)
233271end
234272
273+ @device_override @inline function KI. sub_group_barrier ()
274+ sub_group_barrier (POCL. LOCAL_MEM_FENCE | POCL. GLOBAL_MEM_FENCE)
275+ end
276+
277+ @device_override function KI. shfl_down (val:: T , offset:: Integer ) where {T}
278+ sub_group_shuffle (val, get_sub_group_local_id () + offset)
279+ end
280+
235281@device_override @inline function KI. _print (args... )
236282 POCL. _print (args... )
237283end
0 commit comments