
Commit c02ef8d

gbaraldi, claude, and vchuravy authored
Replace ExceptionInfo with lightweight packed UInt64 atomic (#894)
* Replace heavyweight ExceptionInfo with single atomic UInt64

The old exception handling inlined ~20 flat_store_byte instructions at every error site (bounds checks, div-by-zero, etc.), writing a 56-byte ExceptionInfo struct byte-by-byte through flat memory. This bloated register usage by ~15 VGPRs per error site, reducing occupancy even though the error paths are never taken at runtime.

Replace with a single UInt64 packed with workgroup IDs (16 bits each) and an error code (8 bits), written via one atomic CAS. Each error site now needs ~3 VGPRs instead of ~15.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent bee8fca commit c02ef8d

5 files changed

Lines changed: 133 additions & 124 deletions
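
To make the packing scheme concrete before the diffs, here is a standalone host-side sketch. It is illustrative only, not code from this commit: demo_pack and demo_unpack are hypothetical names, and only the bit layout matches src/device/exceptions.jl below.

# Hypothetical demo of the packed-exception layout: workgroup x/y/z in bits
# [63:48], [47:32], [31:16]; bits [15:8] reserved; error code in bits [7:0].
demo_pack(wg_x::UInt16, wg_y::UInt16, wg_z::UInt16, code::UInt8) =
    UInt64(wg_x) << 48 | UInt64(wg_y) << 32 | UInt64(wg_z) << 16 | UInt64(code)

demo_unpack(packed::UInt64) = (;
    wg_x = UInt16((packed >> 48) & 0xFFFF),
    wg_y = UInt16((packed >> 32) & 0xFFFF),
    wg_z = UInt16((packed >> 16) & 0xFFFF),
    code = UInt8(packed & 0xFF))

packed = demo_pack(UInt16(3), UInt16(5), UInt16(7), UInt8(2))
@assert packed == 0x0003_0005_0007_0002   # one UInt64 instead of a 56-byte struct
@assert demo_unpack(packed) == (; wg_x = 0x0003, wg_y = 0x0005, wg_z = 0x0007, code = 0x02)

Round-tripping through a plain UInt64 like this is the whole trick: the error record becomes one register-sized value, which is why each error site can drop from ~15 VGPRs to ~3.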


src/AMDGPU.jl

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ export workitemIdx, workgroupIdx, workgroupDim, gridItemDim, gridGroupDim
 export sync_workgroup, sync_workgroup_count, sync_workgroup_and, sync_workgroup_or

 struct KernelState
-    exception_info::Ptr{Device.ExceptionInfo}
+    exception_info::Ptr{UInt64}
     malloc_hc::Ptr{Cvoid}
     free_hc::Ptr{Cvoid}
     output_context::Ptr{Cvoid}

src/device/exceptions.jl

Lines changed: 78 additions & 95 deletions
@@ -1,114 +1,97 @@
-"""
-- `status::Int32`: whether exception has been thrown (0 - no, 1 - yes).
-"""
-struct ExceptionInfo
-    status::Int32
-    output_lock::Int32
+# Exception reason codes — encoded in bits [7:0] of the packed exception UInt64
+# 0 means no exception
+module ExceptionCode
+const NONE = UInt8(0)
+const UNKNOWN = UInt8(1)
+const BOUNDS_ERROR = UInt8(2)
+const DOMAIN_ERROR = UInt8(3)
+const OVERFLOW_ERROR = UInt8(4)
+const INEXACT_ERROR = UInt8(5)
+const ARGUMENT_ERROR = UInt8(6)
+const DIVIDE_ERROR = UInt8(7)
+const DIM_MISMATCH = UInt8(8)
+end

-    thread::@NamedTuple{x::UInt32, y::UInt32, z::UInt32}
-    block::@NamedTuple{x::UInt32, y::UInt32, z::UInt32}
+const EXCEPTION_REASON_STRINGS = Dict{UInt8, String}(
+    ExceptionCode.NONE => "No exception",
+    ExceptionCode.UNKNOWN => "Unknown exception",
+    ExceptionCode.BOUNDS_ERROR => "BoundsError: Out-of-bounds array access",
+    ExceptionCode.DOMAIN_ERROR => "DomainError",
+    ExceptionCode.OVERFLOW_ERROR => "OverflowError",
+    ExceptionCode.INEXACT_ERROR => "InexactError: Inexact conversion",
+    ExceptionCode.ARGUMENT_ERROR => "ArgumentError",
+    ExceptionCode.DIVIDE_ERROR => "DivideError: Integer division error",
+    ExceptionCode.DIM_MISMATCH => "DimensionMismatch",
+)

-    reason::LLVMPtr{UInt8, AS.Global}
-    reason_length::Int64
+# Packed exception format (UInt64):
+#   [63:48] workgroup_x (16 bits)
+#   [47:32] workgroup_y (16 bits)
+#   [31:16] workgroup_z (16 bits)
+#   [15:8]  reserved
+#   [7:0]   error code (non-zero = exception occurred)

-    ExceptionInfo() = new(
-        Int32(0), Int32(0),
-        (; x=UInt32(0), y=UInt32(0), z=UInt32(0)),
-        (; x=UInt32(0), y=UInt32(0), z=UInt32(0)),
-        LLVMPtr{UInt8, AS.Global}(), 0)
+@inline function pack_exception(code::UInt8)
+    wg = workgroupIdx()
+    wg_x = UInt64(wg.x % UInt16) << 48
+    wg_y = UInt64(wg.y % UInt16) << 32
+    wg_z = UInt64(wg.z % UInt16) << 16
+    return wg_x | wg_y | wg_z | UInt64(code)
 end

-@inline function Base.getproperty(ei::Ptr{ExceptionInfo}, field::Symbol)
-    if field == :status
-        unsafe_load(convert(Ptr{Int32}, ei))
-    elseif field == :output_lock
-        unsafe_load(convert(Ptr{Int32}, ei + sizeof(Int32)))
-    elseif field == :output_lock_ptr
-        reinterpret(LLVMPtr{Int32, AS.Generic}, ei + sizeof(Int32))
-    elseif field == :thread
-        offset = 2 * sizeof(Int32)
-        unsafe_load(convert(Ptr{@NamedTuple{x::UInt32, y::UInt32, z::UInt32}}, ei + offset))
-    elseif field == :block
-        offset = 2 * sizeof(Int32) + sizeof(@NamedTuple{x::UInt32, y::UInt32, z::UInt32})
-        unsafe_load(convert(Ptr{@NamedTuple{x::UInt32, y::UInt32, z::UInt32}}, ei + offset))
-    elseif field == :reason
-        offset =
-            2 * sizeof(Int32) +
-            2 * sizeof(@NamedTuple{x::UInt32, y::UInt32, z::UInt32})
-        unsafe_load(convert(Ptr{LLVMPtr{UInt8, AS.Global}}, ei + offset))
-    elseif field == :reason_length
-        offset =
-            2 * sizeof(Int32) +
-            2 * sizeof(@NamedTuple{x::UInt32, y::UInt32, z::UInt32}) +
-            sizeof(LLVMPtr{UInt8, AS.Global})
-        unsafe_load(convert(Ptr{Int64}, ei + offset))
-    else
-        getfield(ei, field)
-    end
+@inline function unpack_exception(packed::UInt64)
+    wg_x = UInt16((packed >> 48) & 0xFFFF)
+    wg_y = UInt16((packed >> 32) & 0xFFFF)
+    wg_z = UInt16((packed >> 16) & 0xFFFF)
+    code = UInt8(packed & 0xFF)
+    return (; wg_x, wg_y, wg_z, code)
 end

-@inline function Base.setproperty!(ei::Ptr{ExceptionInfo}, field::Symbol, value)
-    if field == :status
-        unsafe_store!(convert(Ptr{Int32}, ei), value)
-    elseif field == :output_lock
-        unsafe_store!(convert(Ptr{Int32}, ei + sizeof(Int32)), value)
-    elseif field == :thread
-        offset = 2 * sizeof(Int32)
-        unsafe_store!(convert(Ptr{@NamedTuple{x::UInt32, y::UInt32, z::UInt32}}, ei + offset), value)
-    elseif field == :block
-        offset = 2 * sizeof(Int32) + sizeof(@NamedTuple{x::UInt32, y::UInt32, z::UInt32})
-        unsafe_store!(convert(Ptr{@NamedTuple{x::UInt32, y::UInt32, z::UInt32}}, ei + offset), value)
-    elseif field == :reason
-        offset =
-            2 * sizeof(Int32) +
-            2 * sizeof(@NamedTuple{x::UInt32, y::UInt32, z::UInt32})
-        unsafe_store!(convert(Ptr{LLVMPtr{UInt8, AS.Global}}, ei + offset), value)
-    elseif field == :reason_length
-        offset =
-            2 * sizeof(Int32) +
-            2 * sizeof(@NamedTuple{x::UInt32, y::UInt32, z::UInt32}) +
-            sizeof(LLVMPtr{UInt8, AS.Global})
-        unsafe_store!(convert(Ptr{Int64}, ei + offset), value)
-    else
-        setfield!(ei, field, value)
-    end
-end
+# Legacy compat — ExceptionInfo is now just a UInt64
+const ExceptionInfo = UInt64

 function alloc_exception_info()
-    ei_ptr = Mem.HostBuffer(sizeof(ExceptionInfo), HIP.hipHostAllocDefault)
-    unsafe_store!(convert(Ptr{ExceptionInfo}, ei_ptr), ExceptionInfo())
+    ei_ptr = Mem.HostBuffer(sizeof(UInt64), HIP.hipHostAllocDefault)
+    unsafe_store!(convert(Ptr{UInt64}, ei_ptr), UInt64(0))
     return ei_ptr
 end

-@inline function lock_output!(ei::Ptr{ExceptionInfo})
-    # if llvm_atomic_cas(ei.output_lock_ptr, zero(Int32), one(Int32)) == zero(Int32)
-    if llvm_atomic_cas(ei.output_lock_ptr, Int32(0x0), Int32(0x1)) == Int32(0x0)
-        # Take the lock & write thread info.
-        ei.thread = workitemIdx()
-        ei.block = workgroupIdx()
-        sync_workgroup()
-        return true
-    elseif (
-        ei.output_lock == Int32(0x1) &&
-        ei.thread == workitemIdx() &&
-        ei.block == workgroupIdx()
-    )
-        # Thread already has the lock.
-        return true
-    else
-        # Other thread has the lock.
-        return false
-    end
+@inline function signal_exception!(ei::Ptr{UInt64}, code::UInt8)
+    packed = pack_exception(code)
+    # First writer wins via atomic CAS; losers are no-ops.
+    ei_llvm = reinterpret(LLVMPtr{UInt64, AS.Generic}, ei)
+    llvm_atomic_cas(ei_llvm, UInt64(0), packed)
+    endpgm()
+    return
 end

 macro gpu_throw(reason)
+    code = _reason_to_code(reason)
     quote
         ei = kernel_state().exception_info
-        if lock_output!(ei)
-            reason_ptr, reason_length = @strptr $reason
-            ei.reason = reason_ptr
-            ei.reason_length = reason_length
-        end
-        throw(nothing)
+        signal_exception!(ei, $code)
+        throw(nothing) # unreachable, but keeps Julia's type system happy
+    end
+end
+
+# Map reason strings to error codes at macro expansion time
+function _reason_to_code(reason::String)
+    if startswith(reason, "BoundsError")
+        ExceptionCode.BOUNDS_ERROR
+    elseif startswith(reason, "DomainError")
+        ExceptionCode.DOMAIN_ERROR
+    elseif startswith(reason, "OverflowError")
+        ExceptionCode.OVERFLOW_ERROR
+    elseif startswith(reason, "InexactError")
+        ExceptionCode.INEXACT_ERROR
+    elseif startswith(reason, "ArgumentError")
+        ExceptionCode.ARGUMENT_ERROR
+    elseif startswith(reason, "DivideError")
+        ExceptionCode.DIVIDE_ERROR
+    elseif startswith(reason, "DimensionMismatch")
+        ExceptionCode.DIM_MISMATCH
+    else
+        ExceptionCode.UNKNOWN
     end
 end
+_reason_to_code(reason) = ExceptionCode.UNKNOWN
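
A note on concurrency in signal_exception!: the slot is written only while it still holds zero, so when many workitems fault simultaneously, exactly one packed report survives and every later CAS is a silent no-op. The sketch below is a host-side analogy using Base.Threads.Atomic (hypothetical demo code; the device path uses llvm_atomic_cas through an LLVMPtr, as in the diff above).

using Base.Threads

# Hypothetical host-side analogy for the first-writer-wins protocol.
slot = Atomic{UInt64}(0)   # stands in for the host-pinned exception slot

function demo_signal!(slot::Atomic{UInt64}, packed::UInt64)
    # atomic_cas! stores `packed` only if the slot still holds 0.
    atomic_cas!(slot, UInt64(0), packed)
    return
end

demo_signal!(slot, 0x0001_0002_0003_0002)   # first faulting workgroup is recorded
demo_signal!(slot, 0x0009_0009_0009_0007)   # CAS fails silently: slot is non-zero
@assert slot[] == 0x0001_0002_0003_0002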

src/device/runtime.jl

Lines changed: 1 addition & 4 deletions
@@ -77,10 +77,7 @@ end

 function signal_exception()
     ei = kernel_state().exception_info
-    ei.status = Int32(0x1)
-    # Lock in case it was not locked before, to get workitem and workgroup info.
-    lock_output!(ei)
-    endpgm() # Without endpgm we'll get hardware exception.
+    signal_exception!(ei, ExceptionCode.UNKNOWN)
     return
 end

src/exception_handler.jl

Lines changed: 8 additions & 22 deletions
@@ -2,44 +2,30 @@
 const GLOBAL_EXCEPTION_INFO = Dict{UInt, Mem.HostBuffer}()

 # TODO RT_LOCK?
-function exception_info(dev::HIPDevice)::Ptr{Device.ExceptionInfo}
+function exception_info(dev::HIPDevice)::Ptr{UInt64}
     ei = get!(
         () -> Device.alloc_exception_info(),
         GLOBAL_EXCEPTION_INFO, hash(dev))
-    return convert(Ptr{Device.ExceptionInfo}, Mem.device_ptr(ei))
+    return convert(Ptr{UInt64}, Mem.device_ptr(ei))
 end

 function has_exception(dev::HIPDevice)::Bool
-    return exception_info(dev).status != 0
+    return unsafe_load(exception_info(dev)) != UInt64(0)
 end

 function reset_exception_info!(dev::HIPDevice)
-    unsafe_store!(exception_info(dev), Device.ExceptionInfo())
+    unsafe_store!(exception_info(dev), UInt64(0))
     return
 end

-function device_str_to_host(str_ptr, str_length)
-    str_length == 0 && return ""
-
-    buf = Vector{UInt8}(undef, str_length)
-    HSA.memory_copy(
-        convert(Ptr{Cvoid}, pointer(buf)),
-        reinterpret(Ptr{Cvoid}, str_ptr), str_length) |> Runtime.check
-    return String(buf)
-end
-
 function get_exception_info_string(dev::HIPDevice)
-    ei = exception_info(dev)
-    reason = device_str_to_host(ei.reason, ei.reason_length)
-
-    workitemIdx = ei.thread
-    workgroupIdx = ei.block
+    packed = unsafe_load(exception_info(dev))
+    info = Device.unpack_exception(packed)
+    reason = get(Device.EXCEPTION_REASON_STRINGS, info.code, "Unknown error code $(info.code)")

-    isempty(reason) && (reason = "Unknown reason";)
     return """GPU Kernel Exception:
         $reason
-        workitemIdx: $workitemIdx
-        workgroupIdx: $workgroupIdx"""
+        workgroupIdx: ($(info.wg_x), $(info.wg_y), $(info.wg_z))"""
 end

 function throw_if_exception(dev::HIPDevice)
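
End to end, the host side is now a load, an unpack, and a table lookup, as get_exception_info_string above shows. Below is a minimal self-contained sketch of that decode path; the names and the trimmed reason table are hypothetical, not the package internals.

# Hypothetical decode demo mirroring get_exception_info_string.
const DEMO_REASONS = Dict{UInt8, String}(
    0x02 => "BoundsError: Out-of-bounds array access",
    0x07 => "DivideError: Integer division error")

function demo_exception_string(packed::UInt64)
    wg = (Int((packed >> 48) & 0xFFFF),
          Int((packed >> 32) & 0xFFFF),
          Int((packed >> 16) & 0xFFFF))
    reason = get(DEMO_REASONS, UInt8(packed & 0xFF), "Unknown error code")
    return "GPU Kernel Exception:\n$reason\nworkgroupIdx: $wg"
end

println(demo_exception_string(0x0003_0005_0007_0002))
# GPU Kernel Exception:
# BoundsError: Out-of-bounds array access
# workgroupIdx: (3, 5, 7)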

test/device/exceptions.jl

Lines changed: 45 additions & 2 deletions
@@ -16,9 +16,52 @@ using StaticArrays
         AMDGPU.synchronize()
     catch err
         @test err isa ErrorException
+        @test occursin("GPU Kernel Exception", err.msg)
     end
-    # TODO check exception message
-    # TODO check specific exception type
+end
+
+@testset "Exception codegen" begin
+    # Kernel with multiple div() calls — each generates an error path
+    function div_kernel(X, a, b, c, d)
+        i = workitemIdx().x
+        x = div(a, b)
+        y = div(c, d)
+        X[i] = x + y
+        return
+    end
+
+    iob = IOBuffer()
+    AMDGPU.code_native(iob, div_kernel, Tuple{
+        Device.ROCDeviceArray{Int64, 1, 1},
+        Int64, Int64, Int64, Int64,
+    }; kernel=true)
+    asm = String(take!(iob))
+
+    # The new lightweight exception path should NOT generate flat_store_byte
+    # instructions for writing ExceptionInfo fields. Previously each div check
+    # inlined ~20 flat_store_byte for the 56-byte ExceptionInfo struct.
+    n_flat_store_byte = count("flat_store_byte", asm)
+    @test n_flat_store_byte == 0
+
+    # Should use global_atomic_cmpswap for the exception flag instead
+    @test occursin("global_atomic_cmpswap", asm) || occursin("flat_atomic_cmpswap", asm)
+
+    # Kernel with bounds-checked array access
+    function boundscheck_kernel(X, Y)
+        i = workitemIdx().x
+        X[i] = Y[i] + Y[i+1]
+        return
+    end
+
+    iob2 = IOBuffer()
+    AMDGPU.code_native(iob2, boundscheck_kernel, Tuple{
+        Device.ROCDeviceArray{Float64, 1, 1},
+        Device.ROCDeviceArray{Float64, 1, 1},
+    }; kernel=true)
+    asm2 = String(take!(iob2))
+
+    @test count("flat_store_byte", asm2) == 0
+    @test occursin("global_atomic_cmpswap", asm2) || occursin("flat_atomic_cmpswap", asm2)
 end

 if VERSION ≥ v"1.11-"
