Skip to content

Commit c55adc8

Browse files
Added Base.similar methods for CuSparseMatrixCOO and BSR
1 parent 5e4118b commit c55adc8

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

lib/cusparse/src/array.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,22 @@ Base.similar(Mat::CuSparseMatrixCOO, dims::Tuple{Int, Int}) = similar(Mat, dims.
295295

296296
Base.similar(Mat::CuSparseArrayCSR) = CuSparseArrayCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))
297297

298+
299+
function Base.similar(mat::CuSparseMatrixCOO, ::Type{T}, dims::Dims{2}) where {T}
300+
new_rowInd = similar(mat.rowInd)
301+
new_colInd = similar(mat.colInd)
302+
new_nzVal = similar(mat.nzVal, T)
303+
return CuSparseMatrixCOO(new_rowInd, new_colInd, new_nzVal, dims)
304+
end
305+
306+
function Base.similar(mat::CuSparseMatrixBSR{Tv, Ti}, ::Type{T}, dims::Dims{2}) where {Tv, Ti, T}
307+
new_rowPtr = similar(mat.rowPtr)
308+
new_colVal = similar(mat.colVal)
309+
new_nzVal = similar(mat.nzVal, T)
310+
311+
return CuSparseMatrixBSR{T, Ti}(new_rowPtr, new_colVal, new_nzVal, dims, mat.blockDim, mat.dir, mat.nnzb)
312+
end
313+
298314
## array interface
299315

300316
Base.length(g::CuSparseVector) = g.len

0 commit comments

Comments
 (0)