@@ -4,7 +4,7 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
44 offset:: Int # Offset is in number of elements (not bytes).
55
66 function ROCArray {T, N, B} (:: UndefInitializer , dims:: Dims{N} ) where {T, N, B <: Mem.AbstractAMDBuffer }
7- @assert isbitstype (T) " ROCArray only supports bits types "
7+ check_eltype ( " ROCArray" , T)
88 sz:: Int64 = prod (dims) * aligned_sizeof (T)
99 ref = GPUArrays. cached_alloc ((ROCArray, AMDGPU. device (), B, sz)) do
1010 @debug " Allocate `T=$T `, `dims=$dims `: $(Base. format_bytes (sz)) "
@@ -14,12 +14,80 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
1414 end
1515
1616 function ROCArray {T, N} (buf:: DataRef{Managed{B}} , dims:: Dims{N} ; offset:: Integer = 0 ) where {T, N, B <: Mem.AbstractAMDBuffer }
17- @assert isbitstype (T) " ROCArray only supports bits types "
17+ check_eltype ( " ROCArray" , T)
1818 xs = new {T, N, B} (buf, dims, offset)
1919 return finalizer (unsafe_free!, xs)
2020 end
2121end
2222
# Report whether `fieldcount(dt)` can be evaluated for `dt`.
# Returns `true` when the call succeeds and `false` when it throws.
function hasfieldcount(@nospecialize(dt))
    return try
        fieldcount(dt)
        true
    catch
        false
    end
end
31+
# Build the single-line note stating that `T` is not a bitstype.
# `depth` sets how many times the indentation unit is repeated in front of the message.
function explain_nonisbits(@nospecialize(T), depth=0)
    indent = " "^depth
    return indent * " $T is not a bitstype\n "
end
33+
# Build an indented, human-readable explanation of why `T` is not a supported element type.
#
# Arguments:
# - `T`: the type to explain.
# - `depth`: current nesting level; sets how often the indentation unit is repeated.
# - `maxdepth`: nesting level beyond which no further explanation is produced.
#
# Returns a `String` with one line describing `T`, followed (for unions and structs) by the
# explanations of the members/fields that are not allocated inline.
function explain_eltype(@nospecialize(T), depth=0; maxdepth=10)
    depth > maxdepth && return " "

    if T isa Union
        msg = " "^depth * " $T is a union that's not allocated inline\n "
        for U in Base.uniontypes(T)
            if !Base.allocatedinline(U)
                # Forward `maxdepth` so a caller-supplied limit also bounds the recursion.
                msg *= explain_eltype(U, depth + 1; maxdepth=maxdepth)
            end
        end
    elseif Base.ismutabletype(T) &&
           Base.datatype_fieldcount(Base.unwrap_unionall(T)) != 0
        # `ismutabletype` also accepts `UnionAll`s such as `Vector`, while
        # `datatype_fieldcount` is only defined for `DataType`; unwrap before counting.
        msg = " "^depth * " $T is a mutable type\n "
    elseif hasfieldcount(T)
        msg = " "^depth * " $T is a struct that's not allocated inline\n "
        for U in fieldtypes(T)
            if !Base.allocatedinline(U)
                msg *= explain_nonisbits(U, depth + 1)
            end
        end
    else
        msg = " "^depth * " $T is not allocated inline\n "
    end
    return msg
end
58+
# ROCArray only supports element types that are allocated inline (`Base.allocatedinline`).
# These come in three forms:
# 1. plain bitstypes (`Int`, `(Float32, Float64)`, plain immutable structs, etc.).
#    These are simply stored contiguously in memory.
# 2. structs of unions (`struct Foo; x::Union{Int, Float32}; end`).
#    These are stored with a selector at the end (handled by Julia).
# 3. bitstype unions (`Union{Int, Float32}`, etc.).
#    These are stored contiguously and require a selector array (handled by us).
# In addition, "mutable singleton" types like `Symbol`, which use pointer identity,
# are supported.
68+
# Decide whether `T` is acceptable as an element type.
#
# Returns `true` when
# - `T` is allocated inline and, if its fields can be enumerated, every field type is
#   itself valid (checked recursively), or
# - `T` is a mutable `DataType` with zero fields.
# Returns `false` otherwise.
function valid_type(@nospecialize(T))
    if Base.allocatedinline(T)
        if hasfieldcount(T)
            return all(valid_type, fieldtypes(T))
        end
        return true
    elseif Base.ismutabletype(T)
        # `ismutabletype` is also true for `UnionAll`s such as `Vector`, but
        # `datatype_fieldcount` is only defined for `DataType`. Reject those types
        # instead of letting a `MethodError` escape.
        return T isa DataType && Base.datatype_fieldcount(T) == 0
    end
    return false
end
80+
81+
# Validate `T` as an element type on behalf of `name` (the label used in the message).
# Returns `nothing` when `valid_type(T)` holds; otherwise raises an error (via `error`)
# whose message includes the output of `explain_eltype(T)`.
@inline function check_eltype(name, T)
    valid_type(T) && return nothing

    explanation = explain_eltype(T)
    error("""
        $name only supports element types that are allocated inline.
        $explanation """)
end
90+
# GPUArrays storage accessor: hand back the `buf` field of the array.
function GPUArrays.storage(a::ROCArray)
    return a.buf
end
2492
2593function GPUArrays. derive (:: Type{T} , x:: ROCArray , dims:: Dims{N} , offset:: Int ) where {N, T}
@@ -190,7 +258,7 @@ function Base.unsafe_wrap(
190258 :: Type{<:ROCArray} , ptr:: Ptr{T} , dims:: NTuple{N, <:Integer} ;
191259 own:: Bool = false ,
192260) where {T,N}
193- @assert isbitstype (T) " Cannot wrap a non-bitstype pointer as a ROCArray "
261+ check_eltype ( " unsafe_wrap(CuArray, ...) " , T)
194262
195263 memtype = Mem. attributes (ptr). type
196264 B = if memtype == HIP. hipMemoryTypeUnregistered
@@ -209,7 +277,10 @@ function Base.unsafe_wrap(
209277 return ROCArray {T, N} (dref, dims)
210278end
211279
212- Base. unsafe_wrap (:: Type{ROCArray{T}} , ptr:: Ptr , dims; kwargs... ) where T =
# Convenience method for a single `Integer` length: packs it into a 1-tuple and forwards
# to `unsafe_wrap(ROCArray, ptr, dims; own)`, passing `own` through unchanged.
function Base.unsafe_wrap(::Type{<:ROCArray}, ptr::Ptr, dim::Integer; own::Bool = false)
    dims = (dim,)
    return unsafe_wrap(ROCArray, ptr, dims; own = own)
end
282+
# Method for an explicit element type (`ROCArray{T}`): converts `ptr` to `Ptr{T}` with
# `Base.unsafe_convert`, then forwards to `unsafe_wrap(ROCArray, ...)` with the same `dims`
# and keyword arguments.
function Base.unsafe_wrap(
    ::Type{ROCArray{T}}, ptr::Ptr, dims::NTuple{N, <:Integer}; kwargs...
) where {T, N}
    typed_ptr = Base.unsafe_convert(Ptr{T}, ptr)
    return unsafe_wrap(ROCArray, typed_ptr, dims; kwargs...)
end
214285
215286# # interop with CPU arrays
0 commit comments