@@ -75,54 +75,88 @@ Currently only RDNA 3 is supported and following types:
7575- ` FP16 ⋅ FP16 + FP32 -> FP32 ` ;
7676- ` BFP16 ⋅ BFP16 + FP32 -> FP32 ` .
7777
78+ All WMMA functionality is in the ` AMDGPU.Device.WMMA ` submodule.
79+ The tile dimensions are fixed at 16×16×16 (` WMMA.M ` , ` WMMA.N ` , ` WMMA.K ` ).
80+
81+ ### Layout types
82+
83+ Two layout types control how matrices are read from and written to memory:
84+
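85+ - ` WMMA.ColMajor ` — column-major (Julia/Fortran) order: element ` (row, col) ` is at ` ptr[col * stride + row] ` .
86+ - ` WMMA.RowMajor ` — row-major (C) order: element ` (row, col) ` is at ` ptr[row * stride + col] ` .
87+
88+ In both formulas ` row ` and ` col ` are zero-based offsets from the tile's first element, and ` stride ` is the leading dimension of the backing matrix (its number of rows for column-major, its number of columns for row-major).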
87+
88+ ### API
89+
7890``` @docs
79- AMDGPU.Device.wmma_load_a
80- AMDGPU.Device.wmma_fill_c
81- AMDGPU.Device.wmma_store_d
82- AMDGPU.Device.wmma_mma
91+ AMDGPU.Device.WMMA.Fragment
92+ AMDGPU.Device.WMMA.fill_c
93+ AMDGPU.Device.WMMA.load_a
94+ AMDGPU.Device.WMMA.load_b
95+ AMDGPU.Device.WMMA.load_c
96+ AMDGPU.Device.WMMA.store_d
97+ AMDGPU.Device.WMMA.mma
8398```
8499
85- Below is a simple example of matrix multiplication kernel using WMMA.
100+ ` load_c ` and ` store_d ` accept pointers whose element type ` T ` is ` Float32 ` , ` Float16 ` , or ` BFloat16 ` .
101+ When ` T ` is ` Float16 ` or ` BFloat16 ` , values are widened to ` Float32 ` on load and
102+ narrowed back on store, so the ` FragmentC_F32 ` accumulator type is always ` Float32 `
103+ regardless of the backing buffer type.
104+
105+ ### Example
106+
107+ Below is a matrix multiplication kernel using WMMA with column-major inputs.
108+ Pass ` WMMA.RowMajor ` instead to load ` A ` and ` B ` from row-major (C-style) buffers;
108+ in this example the result ` C ` is stored column-major in either case.
86109
87110``` @example wmma-matmul
88111using AMDGPU
89- using AMDGPU.Device: WMMA_M, WMMA_N, WMMA_K, wmma_fill_c, wmma_load_a, wmma_load_b, wmma_store_d, wmma_mma
112+ using AMDGPU.Device: WMMA
90113
# Example kernel: each workgroup computes one WMMA.M × WMMA.N tile of C = A * B.
# Lines marked `-` are the previous `wmma_*` API; lines marked `+` use the `WMMA`
# submodule and take an extra `layout` argument (`WMMA.ColMajor` or `WMMA.RowMajor`)
# that selects how the A and B tiles are addressed and loaded.
#
# Arguments (new version):
#   C, A, B  - device arrays; only their base pointers are used inside the kernel.
#   M, N, K  - matrix dimensions as Int32 (A is M×K, B is K×N, C is M×N in the driver below).
#   layout   - layout type forwarded to `_a_tile`, `_b_tile`, `WMMA.load_a` and `WMMA.load_b`.
91- function wmma_kernel_ptr !(C, A::AbstractArray{T}, B, M::Int32, N::Int32, K::Int32) where T
92- tile_row = (workgroupIdx().x - Int32(1)) * Int32(WMMA_M )
93- tile_col = (workgroupIdx().y - Int32(1)) * Int32(WMMA_N )
114+ function wmma_kernel !(C, A::AbstractArray{T}, B, M::Int32, N::Int32, K::Int32, layout ) where T
# Zero-based element offsets of this workgroup's output tile (workgroupIdx is 1-based).
115+ tile_row = (workgroupIdx().x - Int32(1)) * Int32(WMMA.M )
116+ tile_col = (workgroupIdx().y - Int32(1)) * Int32(WMMA.N )
94117
95118 C_ptr = pointer(C)
96119 A_ptr = pointer(A)
97120 B_ptr = pointer(B)
98121
# The accumulator fragment starts at zero and is Float32 for the whole K loop.
99- c_frag = wmma_fill_c (Float32, 0f0)
122+ c_frag = WMMA.fill_c (Float32, 0f0)
100123 k = Int32(0)
101124 while k < K
102- a_ptr = A_ptr + (k * M + tile_row) * Int32(sizeof(T) )
103- b_ptr = B_ptr + ( tile_col * K + k) * Int32(sizeof(T) )
# `_a_tile` / `_b_tile` return the tile's start pointer and the stride to pass to the loads.
125+ a_ptr, a_stride = _a_tile( A_ptr, layout, tile_row, k, M, K, T )
126+ b_ptr, b_stride = _b_tile( B_ptr, layout, tile_col, k, N, K, T )
104127
105- a_frag = wmma_load_a (a_ptr, M )
106- b_frag = wmma_load_b (b_ptr, K )
107- c_frag = wmma_mma (a_frag, b_frag, c_frag)
128+ a_frag = WMMA.load_a (a_ptr, a_stride, layout )
129+ b_frag = WMMA.load_b (b_ptr, b_stride, layout )
130+ c_frag = WMMA.mma (a_frag, b_frag, c_frag)
108131
# Advance to the next tile along the shared K dimension.
109- k += Int32(WMMA_K )
132+ k += Int32(WMMA.K )
110133 end
111134
# C is addressed and stored column-major with stride M, independent of `layout`;
# `layout` only affects how A and B are read.
112135 c_ptr = C_ptr + (tile_col * M + tile_row) * Int32(sizeof(Float32))
113- wmma_store_d (c_ptr, c_frag, M)
136+ WMMA.store_d (c_ptr, c_frag, M, WMMA.ColMajor )
114137 return
115138end
116139
140+ # Tile pointer + stride helpers — dispatched on layout, DCE'd by the compiler.
# Each method returns a tuple `(tile_ptr, stride)`: `ptr` advanced to the first element
# of the tile at K-offset `k`, plus the leading dimension that the kernel forwards to
# `WMMA.load_a` / `WMMA.load_b`. The method is chosen by the layout type argument.
# `tile_row`, `tile_col` and `k` are zero-based element offsets; the element offset is
# scaled by `sizeof(T)` (assumes `ptr + n` advances by `n` bytes — TODO confirm for the
# device pointer type).
#
# A tile, column-major: element offset k * M + tile_row, stride M.
141+ _a_tile(ptr, ::Type{WMMA.ColMajor}, tile_row, k, M, K, ::Type{T}) where T =
142+ ptr + (k * M + tile_row) * Int32(sizeof(T)), M
# A tile, row-major: element offset tile_row * K + k, stride K.
143+ _a_tile(ptr, ::Type{WMMA.RowMajor}, tile_row, k, M, K, ::Type{T}) where T =
144+ ptr + (tile_row * K + k) * Int32(sizeof(T)), K
145+
# B tile, column-major: element offset tile_col * K + k, stride K.
146+ _b_tile(ptr, ::Type{WMMA.ColMajor}, tile_col, k, N, K, ::Type{T}) where T =
147+ ptr + (tile_col * K + k) * Int32(sizeof(T)), K
# B tile, row-major: element offset k * N + tile_col, stride N.
148+ _b_tile(ptr, ::Type{WMMA.RowMajor}, tile_col, k, N, K, ::Type{T}) where T =
149+ ptr + (k * N + tile_col) * Int32(sizeof(T)), N
150+
# Host-side driver: multiply a 32×32 Float16 A by a 32×32 Float16 B into a Float32 C.
117151M, N, K = 32, 32, 32
118152A_host = Float16.(rand(M, K))
119153B_host = Float16.(rand(K, N))
# Copy the inputs to the device; C is zero-initialised Float32 storage for the result.
120154A, B = ROCArray(A_host), ROCArray(B_host)
121155C = ROCArray(zeros(Float32, M, N))
122156
# One workgroup per output tile: the grid is (M ÷ tile rows, N ÷ tile cols).
# groupsize=32 is presumably one wave32 wavefront per tile — confirm against the WMMA docs.
# The arrays are ordinary column-major Julia arrays, hence `WMMA.ColMajor` in the new call.
123- tiles_m, tiles_n = M ÷ WMMA_M , N ÷ WMMA_N
124- @roc gridsize=(tiles_m, tiles_n) groupsize=32 wmma_kernel_ptr !(
125- C, A, B, Int32(M), Int32(N), Int32(K))
157+ tiles_m, tiles_n = M ÷ WMMA.M , N ÷ WMMA.N
158+ @roc gridsize=(tiles_m, tiles_n) groupsize=32 wmma_kernel !(
159+ C, A, B, Int32(M), Int32(N), Int32(K), WMMA.ColMajor )
126160
# Compare against a Float32 reference product; the 0.1 tolerance is loose, presumably to
# absorb rounding differences in accumulation order — confirm if tightening it.
127161@assert maximum(abs.(Float32.(C) .- (Float32.(A) * Float32.(B)))) < 0.1
128162```
0 commit comments