Skip to content

Commit 404370b

Browse files
vchuravyclaude
andauthored
Support Symbols on the GPU (#888)
--------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent b7fb0b0 commit 404370b

3 files changed

Lines changed: 113 additions & 19 deletions

File tree

src/array.jl

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
44
offset::Int # Offset is in number of elements (not bytes).
55

66
function ROCArray{T, N, B}(::UndefInitializer, dims::Dims{N}) where {T, N, B <: Mem.AbstractAMDBuffer}
7-
@assert isbitstype(T) "ROCArray only supports bits types"
7+
check_eltype("ROCArray", T)
88
sz::Int64 = prod(dims) * aligned_sizeof(T)
99
ref = GPUArrays.cached_alloc((ROCArray, AMDGPU.device(), B, sz)) do
1010
@debug "Allocate `T=$T`, `dims=$dims`: $(Base.format_bytes(sz))"
@@ -14,12 +14,80 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
1414
end
1515

1616
function ROCArray{T, N}(buf::DataRef{Managed{B}}, dims::Dims{N}; offset::Integer = 0) where {T, N, B <: Mem.AbstractAMDBuffer}
17-
@assert isbitstype(T) "ROCArray only supports bits types"
17+
check_eltype("ROCArray", T)
1818
xs = new{T, N, B}(buf, dims, offset)
1919
return finalizer(unsafe_free!, xs)
2020
end
2121
end
2222

23+
function hasfieldcount(@nospecialize(dt))
24+
try
25+
fieldcount(dt)
26+
catch
27+
return false
28+
end
29+
return true
30+
end
31+
32+
explain_nonisbits(@nospecialize(T), depth=0) = " "^depth * "$T is not a bitstype\n"
33+
34+
function explain_eltype(@nospecialize(T), depth=0; maxdepth=10)
35+
depth > maxdepth && return ""
36+
37+
if T isa Union
38+
msg = " "^depth * "$T is a union that's not allocated inline\n"
39+
for U in Base.uniontypes(T)
40+
if !Base.allocatedinline(U)
41+
msg *= explain_eltype(U, depth+1)
42+
end
43+
end
44+
elseif Base.ismutabletype(T) && Base.datatype_fieldcount(T) != 0
45+
msg = " "^depth * "$T is a mutable type\n"
46+
elseif hasfieldcount(T)
47+
msg = " "^depth * "$T is a struct that's not allocated inline\n"
48+
for U in fieldtypes(T)
49+
if !Base.allocatedinline(U)
50+
msg *= explain_nonisbits(U, depth+1)
51+
end
52+
end
53+
else
54+
msg = " "^depth * "$T is not allocated inline\n"
55+
end
56+
return msg
57+
end
58+
59+
# ROCArray only supports element types that are allocated inline (`Base.allocatedinline`).
60+
# These come in three forms:
61+
# 1. plain bitstypes (`Int`, `(Float32, Float64)`, plain immutable structs, etc).
62+
# these are simply stored contiguously in memory.
63+
# 2. structs of unions (`struct Foo; x::Union{Int, Float32}; end`)
64+
# these are stored with a selector at the end (handled by Julia).
65+
# 3. bitstype unions (`Union{Int, Float32}`, etc)
66+
# these are stored contiguously and require a selector array (handled by us)
67+
# As well as "mutable singleton" types like `Symbol` that use pointer-identity
68+
69+
function valid_type(@nospecialize(T))
70+
if Base.allocatedinline(T)
71+
if hasfieldcount(T)
72+
return all(valid_type, fieldtypes(T))
73+
end
74+
return true
75+
elseif Base.ismutabletype(T)
76+
return Base.datatype_fieldcount(T) == 0
77+
end
78+
return false
79+
end
80+
81+
82+
@inline function check_eltype(name, T)
83+
if !valid_type(T)
84+
explanation = explain_eltype(T)
85+
error("""
86+
$name only supports element types that are allocated inline.
87+
$explanation""")
88+
end
89+
end
90+
2391
GPUArrays.storage(a::ROCArray) = a.buf
2492

2593
function GPUArrays.derive(::Type{T}, x::ROCArray, dims::Dims{N}, offset::Int) where {N, T}
@@ -190,7 +258,7 @@ function Base.unsafe_wrap(
190258
::Type{<:ROCArray}, ptr::Ptr{T}, dims::NTuple{N, <:Integer};
191259
own::Bool = false,
192260
) where {T,N}
193-
@assert isbitstype(T) "Cannot wrap a non-bitstype pointer as a ROCArray"
261+
check_eltype("unsafe_wrap(CuArray, ...)", T)
194262

195263
memtype = Mem.attributes(ptr).type
196264
B = if memtype == HIP.hipMemoryTypeUnregistered
@@ -209,7 +277,10 @@ function Base.unsafe_wrap(
209277
return ROCArray{T, N}(dref, dims)
210278
end
211279

212-
Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T =
280+
Base.unsafe_wrap(::Type{<:ROCArray}, ptr::Ptr, dim::Integer; own::Bool=false) =
281+
unsafe_wrap(ROCArray, ptr, (dim,); own)
282+
283+
Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims::NTuple{N, <:Integer}; kwargs...) where {T, N} =
213284
unsafe_wrap(ROCArray, Base.unsafe_convert(Ptr{T}, ptr), dims; kwargs...)
214285

215286
## interop with CPU arrays

src/runtime/hip-execution.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
# In contrast to `Base.RefValue` we just need a container for both pass-by-ref (Symbol),
2+
# and pass-by-value (immutable structs).
3+
mutable struct ArgBox{T}
4+
const val::T
5+
end
6+
7+
function Base.unsafe_convert(P::Union{Type{Ptr{T}}, Type{Ptr{Cvoid}}}, b::ArgBox{T})::P where {T}
8+
return pointer_from_objref(b)
9+
end
10+
111
"""
212
(ker::HIPKernel)(args::Vararg{Any, N}; kwargs...)
313
@@ -29,15 +39,6 @@ end
2939
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
3040
call_args = Union{Expr,Symbol}[x[1] for x in zip(args, to_pass) if x[2]]
3141

32-
# replace non-isbits arguments (they should be unused, or compilation would have failed)
33-
# alternatively, allow `launch` with non-isbits arguments.
34-
for (i,dt) in enumerate(call_t)
35-
if !isbitstype(dt)
36-
call_t[i] = Ptr{Any}
37-
call_args[i] = :C_NULL
38-
end
39-
end
40-
4142
# add the kernel state
4243
pushfirst!(call_t, AMDGPU.KernelState)
4344
pushfirst!(call_args, :(AMDGPU.KernelState(
@@ -87,17 +88,12 @@ function roccall(fun::F, tt::Type{T}, args::Vararg{Any, N}; kwargs...) where {F,
8788
end
8889

8990
@inline @generated function pack_arguments(f::Function, args...)
90-
for arg in args
91-
isbitstype(arg) || throw(ArgumentError(
92-
"Arguments to kernel should be bitstype, instead `$(arg)` was given."))
93-
end
94-
9591
ex = quote end
9692

9793
arg_refs = Vector{Symbol}(undef, length(args))
9894
for i in 1:length(args)
9995
arg_refs[i] = gensym()
100-
push!(ex.args, :($(arg_refs[i]) = Base.RefValue(args[$i])))
96+
push!(ex.args, :($(arg_refs[i]) = $ArgBox(args[$i])))
10197
end
10298

10399
arg_ptrs = [

test/core/rocarray_base.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,20 @@ end
160160
y .= @view(xd[1:3, :, :])
161161
@test Array(y) @view(x[1:3, :, :])
162162
end
163+
164+
@testset "Symbols" begin
165+
# symbols and tuples thereof
166+
let a = ROCArray([:a])
167+
b = unsafe_wrap(ROCArray, pointer(a), 1)
168+
@test typeof(b) <: ROCArray{Symbol,1}
169+
@test size(b) == (1,)
170+
end
171+
let a = ROCArray([(:a,:b)])
172+
b = unsafe_wrap(ROCArray, pointer(a), 1)
173+
@test typeof(b) <: ROCArray{Tuple{Symbol,Symbol},1}
174+
@test size(b) == (1,)
175+
end
176+
end
163177
end
164178

165179
@testset "unsafe_free" begin
@@ -224,4 +238,17 @@ end
224238
@test Array(dtarget) == target
225239
end
226240

241+
@testset "Symbols" begin
242+
function pass_symbol(x, name)
243+
i = name == :var ? 1 : 2
244+
x[i] = true
245+
return nothing
246+
end
247+
x = ROCArray([false, false])
248+
@roc pass_symbol(x, :var)
249+
@test Array(x) == [true, false]
250+
@roc pass_symbol(x, :not_var)
251+
@test Array(x) == [true, true]
252+
end
253+
227254
end

0 commit comments

Comments
 (0)