Skip to content

Commit e522164

Browse files
authored
Rework WMMA (#892)
- Move to `WMMA` module, remove `WMMA_` prefixes. - Add `RowMajor`, `ColMajor` fragment loading options. - Allow loading `C` fragment from FP16 / BFP16 types.
1 parent aa88ffc commit e522164

3 files changed

Lines changed: 306 additions & 92 deletions

File tree

docs/src/api/kernel_programming.md

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,54 +75,88 @@ Currently only RDNA 3 is supported and the following types:
7575
- `FP16 ⋅ FP16 + FP32 -> FP32`;
7676
- `BFP16 ⋅ BFP16 + FP32 -> FP32`.
7777

78+
All WMMA functionality is in the `AMDGPU.Device.WMMA` submodule.
79+
The tile dimensions are fixed at 16×16×16 (`WMMA.M`, `WMMA.N`, `WMMA.K`).
80+
81+
### Layout types
82+
83+
Two layout types control how matrices are read from and written to memory:
84+
85+
- `WMMA.ColMajor` — column-major (Julia/Fortran) order: element `(row, col)` is at `ptr[col * stride + row]`.
86+
- `WMMA.RowMajor` — row-major (C) order: element `(row, col)` is at `ptr[row * stride + col]`.
87+
88+
### API
89+
7890
```@docs
79-
AMDGPU.Device.wmma_load_a
80-
AMDGPU.Device.wmma_fill_c
81-
AMDGPU.Device.wmma_store_d
82-
AMDGPU.Device.wmma_mma
91+
AMDGPU.Device.WMMA.Fragment
92+
AMDGPU.Device.WMMA.fill_c
93+
AMDGPU.Device.WMMA.load_a
94+
AMDGPU.Device.WMMA.load_b
95+
AMDGPU.Device.WMMA.load_c
96+
AMDGPU.Device.WMMA.store_d
97+
AMDGPU.Device.WMMA.mma
8398
```
8499

85-
Below is a simple example of matrix multiplication kernel using WMMA.
100+
`load_c` and `store_d` accept pointer types `Float32`, `Float16`, and `BFloat16`.
101+
When `T` is `Float16` or `BFloat16`, values are widened to `Float32` on load and
102+
narrowed back on store, so the `FragmentC_F32` accumulator type is always `Float32`
103+
regardless of the backing buffer type.
104+
105+
### Example
106+
107+
Below is a matrix multiplication kernel using WMMA with column-major inputs.
108+
Pass `WMMA.RowMajor` instead to load from row-major (C-style) buffers.
86109

87110
```@example wmma-matmul
88111
using AMDGPU
89-
using AMDGPU.Device: WMMA_M, WMMA_N, WMMA_K, wmma_fill_c, wmma_load_a, wmma_load_b, wmma_store_d, wmma_mma
112+
using AMDGPU.Device: WMMA
90113
91-
function wmma_kernel_ptr!(C, A::AbstractArray{T}, B, M::Int32, N::Int32, K::Int32) where T
92-
tile_row = (workgroupIdx().x - Int32(1)) * Int32(WMMA_M)
93-
tile_col = (workgroupIdx().y - Int32(1)) * Int32(WMMA_N)
114+
# 16×16×16 tiled matmul: each workgroup computes one WMMA tile of C.
# `layout` (WMMA.ColMajor / WMMA.RowMajor) selects how the A and B tiles are
# read; the result tile is always written back in column-major order.
function wmma_kernel!(C, A::AbstractArray{T}, B, M::Int32, N::Int32, K::Int32, layout) where T
    # Upper-left corner of this workgroup's output tile (0-based offsets).
    row0 = (workgroupIdx().x - Int32(1)) * Int32(WMMA.M)
    col0 = (workgroupIdx().y - Int32(1)) * Int32(WMMA.N)

    base_c = pointer(C)
    base_a = pointer(A)
    base_b = pointer(B)

    # Zero-initialised accumulator fragment.
    acc = WMMA.fill_c(Float32, 0f0)

    kk = Int32(0)
    while kk < K
        # Tile pointers and leading dimensions for the current K-slice,
        # chosen by dispatch on `layout`.
        pa, lda = _a_tile(base_a, layout, row0, kk, M, K, T)
        pb, ldb = _b_tile(base_b, layout, col0, kk, N, K, T)

        frag_a = WMMA.load_a(pa, lda, layout)
        frag_b = WMMA.load_b(pb, ldb, layout)
        acc = WMMA.mma(frag_a, frag_b, acc)

        kk += Int32(WMMA.K)
    end

    # C is column-major with stride M: element (row0, col0) is at col0 * M + row0.
    pc = base_c + (col0 * M + row0) * Int32(sizeof(Float32))
    WMMA.store_d(pc, acc, M, WMMA.ColMajor)
    return
end
116139
140+
# Tile pointer + stride helpers — dispatched on layout, DCE'd by the compiler.
141+
# Pointer to the (tile_row, k) tile of A (M×K) plus its leading dimension.
# Column-major: element (r, c) lives at c * M + r, so the stride is M.
function _a_tile(ptr, ::Type{WMMA.ColMajor}, tile_row, k, M, K, ::Type{T}) where T
    return ptr + (k * M + tile_row) * Int32(sizeof(T)), M
end

# Row-major: element (r, c) lives at r * K + c, so the stride is K.
function _a_tile(ptr, ::Type{WMMA.RowMajor}, tile_row, k, M, K, ::Type{T}) where T
    return ptr + (tile_row * K + k) * Int32(sizeof(T)), K
end
145+
146+
# Pointer to the (k, tile_col) tile of B (K×N) plus its leading dimension.
# Column-major: element (r, c) lives at c * K + r, so the stride is K.
function _b_tile(ptr, ::Type{WMMA.ColMajor}, tile_col, k, N, K, ::Type{T}) where T
    return ptr + (tile_col * K + k) * Int32(sizeof(T)), K
end

# Row-major: element (r, c) lives at r * N + c, so the stride is N.
function _b_tile(ptr, ::Type{WMMA.RowMajor}, tile_col, k, N, K, ::Type{T}) where T
    return ptr + (k * N + tile_col) * Int32(sizeof(T)), N
end
150+
117151
# Host-side driver: multiply two 32×32 half-precision matrices on the GPU
# and compare against a Float32 reference product computed with `*`.
M, N, K = 32, 32, 32
hA = Float16.(rand(M, K))
hB = Float16.(rand(K, N))
A, B = ROCArray(hA), ROCArray(hB)
C = ROCArray(zeros(Float32, M, N))

# One workgroup per 16×16 output tile; groupsize of 32 — presumably one
# wave32 wavefront per tile on RDNA 3 (TODO confirm against AMDGPU docs).
tiles_m, tiles_n = M ÷ WMMA.M, N ÷ WMMA.N
@roc gridsize=(tiles_m, tiles_n) groupsize=32 wmma_kernel!(
    C, A, B, Int32(M), Int32(N), Int32(K), WMMA.ColMajor)

@assert maximum(abs.(Float32.(C) .- (Float32.(A) * Float32.(B)))) < 0.1
128162
```

0 commit comments

Comments
 (0)