Skip to content

Commit 1909ff3

Browse files
committed
Add direct sparse gpu constructors for COO, CSC, CSR and BSR
1 parent 9528a33 commit 1909ff3

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

lib/cusparse/array.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ mutable struct CuSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC
5353
new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
5454
end
5555
end
56+
function GPUSparseMatrixCSC(colPtr::CuVector{Ti, 1}, rowVal::CuVector{Ti, 1}, nzVal::CuVector{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
57+
return CuSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
58+
end
5659
CuSparseMatrixCSC{Tv, Ti}(csc::CuSparseMatrixCSC{Tv, Ti}) where {Tv, Ti} = csc
5760

5861
SparseArrays.rowvals(g::T) where {T<:CuSparseVector} = nonzeroinds(g)
@@ -94,7 +97,9 @@ mutable struct CuSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR
9497
new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
9598
end
9699
end
97-
100+
function GPUSparseMatrixCSR(rowPtr::CuVector{Ti, 1}, colVal::CuVector{Ti, 1}, nzVal::CuVector{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
101+
return CuSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
102+
end
98103
CuSparseMatrixCSR{Tv, Ti}(csr::CuSparseMatrixCSR{Tv, Ti}) where {Tv, Ti} = csr
99104
CuSparseMatrixCSR(A::CuSparseMatrixCSR) = A
100105

@@ -147,6 +152,9 @@ mutable struct CuSparseMatrixBSR{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti}
147152
new{Tv, Ti}(rowPtr, colVal, nzVal, dims, blockDim, dir, nnz)
148153
end
149154
end
155+
function GPUSparseMatrixBSR(rowPtr::CuVector{Ti, 1}, colVal::CuVector{Ti, 1}, nzVal::CuVector{Tv, 1}, dims::NTuple{2,<:Integer}, blockDim::Integer, args...) where {Tv, Ti <: Integer}
156+
return CuSparseMatrixBSR{Tv, Ti}(rowPtr, colVal, nzVal, dims, blockDim, args...)
157+
end
150158

151159
CuSparseMatrixBSR(A::CuSparseMatrixBSR) = A
152160

@@ -177,7 +185,9 @@ mutable struct CuSparseMatrixCOO{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti}
177185
new{Tv, Ti}(rowInd,colInd,nzVal,dims,nnz)
178186
end
179187
end
180-
188+
function GPUSparseMatrixCOO(rowInd::CuVector{Ti, 1}, colInd::CuVector{Ti, 1}, nzVal::CuVector{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
189+
return CuSparseMatrixCOO{Tv, Ti}(rowInd, colInd, nzVal, dims)
190+
end
181191
CuSparseMatrixCOO(A::CuSparseMatrixCOO) = A
182192

183193
mutable struct CuSparseArrayCSR{Tv, Ti, N} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, N}

0 commit comments

Comments
 (0)