Skip to content

Commit a5b090c

Browse files
committed
Initial subgroups support
1 parent 8296740 commit a5b090c

2 files changed

Lines changed: 166 additions & 0 deletions

File tree

src/intrinsics.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,78 @@ Returns the unique group ID.
102102
"""
103103
function get_group_id end
104104

105+
"""
106+
get_sub_group_size()::UInt32
107+
108+
Returns the number of work-items in the sub-group.
109+
110+
!!! note
111+
Backend implementations **must** implement:
112+
```
113+
@device_override get_sub_group_size()::UInt32
114+
```
115+
"""
116+
function get_sub_group_size end # method-less generic; the device method comes from a backend's @device_override
117+
118+
"""
119+
get_max_sub_group_size()::UInt32
120+
121+
Returns the maximum sub-group size for sub-groups in the current workgroup.
122+
123+
!!! note
124+
Backend implementations **must** implement:
125+
```
126+
@device_override get_max_sub_group_size()::UInt32
127+
```
128+
"""
129+
function get_max_sub_group_size end # declared without methods; backends provide it via @device_override
130+
131+
"""
132+
get_num_sub_groups()::UInt32
133+
134+
Returns the number of sub-groups in the current workgroup.
135+
136+
!!! note
137+
Backend implementations **must** implement:
138+
```
139+
@device_override get_num_sub_groups()::UInt32
140+
```
141+
"""
142+
function get_num_sub_groups end # declared without methods; backends provide it via @device_override
143+
144+
"""
145+
get_sub_group_id()::UInt32
146+
147+
Returns the sub-group ID within the work-group.
148+
149+
!!! note
150+
1-based.
151+
152+
!!! note
153+
Backend implementations **must** implement:
154+
```
155+
@device_override get_sub_group_id()::UInt32
156+
```
157+
"""
158+
function get_sub_group_id end # declared without methods; backends provide it via @device_override (result documented as 1-based)
159+
160+
"""
161+
get_sub_group_local_id()::UInt32
162+
163+
Returns the work-item ID within the current sub-group.
164+
165+
!!! note
166+
1-based.
167+
168+
!!! note
169+
Backend implementations **must** implement:
170+
```
171+
@device_override get_sub_group_local_id()::UInt32
172+
```
173+
"""
174+
function get_sub_group_local_id end # declared without methods; backends provide it via @device_override (result documented as 1-based)
175+
176+
105177
"""
106178
localmemory(::Type{T}, dims)
107179
@@ -139,6 +211,29 @@ function barrier()
139211
error("Group barrier used outside kernel or not captured")
140212
end
141213

214+
"""
215+
sub_group_barrier()
216+
217+
After a `sub_group_barrier()` call, all reads and writes to global and local memory
218+
from each thread in the sub-group are visible to all other threads in the
219+
sub-group.
220+
221+
This does **not** guarantee that a write from a thread in a certain sub-group will
222+
be visible to a thread in a different sub-group.
223+
224+
!!! note
225+
`sub_group_barrier()` must be encountered by all work-items of a sub-group executing the kernel or by none at all.
226+
227+
!!! note
228+
Backend implementations **must** implement:
229+
```
230+
@device_override sub_group_barrier()
231+
```
232+
"""
233+
function sub_group_barrier()
234+
# Generic fallback: this method unconditionally throws. Per the docstring, a
# backend's `@device_override sub_group_barrier()` is what supplies a working barrier.
error("Sub-group barrier used outside kernel or not captured")
235+
end
236+
142237
"""
143238
_print(args...)
144239
@@ -220,6 +315,22 @@ kernel launch with too big a workgroup is attempted.
220315
"""
221316
function max_work_group_size end
222317

318+
"""
319+
sub_group_size(backend)::Int
320+
321+
Returns a reasonable sub-group size supported by the currently
322+
active device for the specified backend. This would typically
323+
be 32, or 64 for devices that don't support 32.
324+
325+
!!! note
326+
Backend implementations **must** implement:
327+
```
328+
sub_group_size(backend::NewBackend)::Int
329+
```
330+
As well as the on-device functionality.
331+
"""
332+
function sub_group_size end # declared without methods; each backend adds a method for its own backend type returning an Int
333+
223334
"""
224335
multiprocessor_count(backend::NewBackend)::Int
225336

test/intrinsics.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,27 @@ function test_intrinsics_kernel(results)
2323
end
2424
return
2525
end
26+
# Record stored by `test_subgroup_kernel` into each written element of `results`.
# Fields are filled positionally, in this order, from the `KI` sub-group intrinsics.
struct SubgroupData
27+
sub_group_size::UInt32 # from KI.get_sub_group_size()
28+
max_sub_group_size::UInt32 # from KI.get_max_sub_group_size()
29+
num_sub_groups::UInt32 # from KI.get_num_sub_groups()
30+
sub_group_id::UInt32 # from KI.get_sub_group_id(); documented as 1-based
31+
sub_group_local_id::UInt32 # from KI.get_sub_group_local_id(); documented as 1-based
32+
end
33+
# Test kernel: reads the `x` component of `KI.get_global_id()` and, when that index
# is within `results`, stores one `SubgroupData` built from the five sub-group
# intrinsics at that position. Returns nothing.
function test_subgroup_kernel(results)
34+
i = KI.get_global_id().x
35+
36+
# Only the upper bound is checked here; indices past the end write nothing.
if i <= length(results)
37+
# `@inbounds` relies on the guard above for the upper bound.
# NOTE(review): assumes the global id is 1-based so the lower bound holds — confirm.
@inbounds results[i] = SubgroupData(
38+
KI.get_sub_group_size(),
39+
KI.get_max_sub_group_size(),
40+
KI.get_num_sub_groups(),
41+
KI.get_sub_group_id(),
42+
KI.get_sub_group_local_id()
43+
)
44+
end
45+
return
46+
end
2647

2748
function intrinsics_testsuite(backend, AT)
2849
@testset "KernelIntrinsics Tests" begin
@@ -122,6 +143,40 @@ function intrinsics_testsuite(backend, AT)
122143
@test k_data.local_id == expected_local
123144
end
124145
end
146+
147+
@testset "Sub-groups" begin
148+
@test KI.sub_group_size(backend()) isa Int
149+
150+
# Test with small kernel
151+
sg_size = KI.sub_group_size(backend())
152+
sg_n = 2
153+
workgroupsize = sg_size * sg_n
154+
numworkgroups = 2
155+
N = workgroupsize * numworkgroups
156+
157+
results = AT(Vector{SubgroupData}(undef, N))
158+
kernel = KI.@kernel backend() launch = false test_subgroup_kernel(results)
159+
160+
kernel(results; workgroupsize, numworkgroups)
161+
KernelAbstractions.synchronize(backend())
162+
163+
host_results = Array(results)
164+
165+
# Verify results make sense
166+
for (i, sg_data) in enumerate(host_results)
167+
@test sg_data.sub_group_size == sg_size
168+
@test sg_data.max_sub_group_size == sg_size
169+
@test sg_data.num_sub_groups == sg_n
170+
171+
# Sub-group ID should be 1-based within its workgroup
172+
expected_sub_group = div(((i - 1) % workgroupsize), sg_size) + 1
173+
@test sg_data.sub_group_id == expected_sub_group
174+
175+
# Local ID should be 1-based within its sub-group
176+
expected_sg_local = ((i - 1) % sg_size) + 1
177+
@test sg_data.sub_group_local_id == expected_sg_local
178+
end
179+
end
125180
end
126181
return nothing
127182
end

0 commit comments

Comments
 (0)