Skip to content

Commit 4b626bf

Browse files
Address review: explicitly pass {T, Ti} to COO constructor
1 parent c55adc8 commit 4b626bf

1 file changed

Lines changed: 20 additions & 39 deletions

File tree

lib/cusparse/src/array.jl

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ function Base.similar(mat::CuSparseMatrixCOO, ::Type{T}, dims::Dims{2}) where {T
300300
new_rowInd = similar(mat.rowInd)
301301
new_colInd = similar(mat.colInd)
302302
new_nzVal = similar(mat.nzVal, T)
303-
return CuSparseMatrixCOO(new_rowInd, new_colInd, new_nzVal, dims)
303+
return CuSparseMatrixCOO{T, Ti}(new_rowInd, new_colInd, new_nzVal, dims)
304304
end
305305

306306
function Base.similar(mat::CuSparseMatrixBSR{Tv, Ti}, ::Type{T}, dims::Dims{2}) where {Tv, Ti, T}
@@ -470,57 +470,39 @@ Base.getindex(A::CuSparseMatrixCSR, ::Colon, j::Integer) = CuSparseVector(sparse
470470

471471
function Base.getindex(A::CuSparseVector{Tv, Ti}, i::Integer) where {Tv, Ti}
472472
@boundscheck checkbounds(A, i)
473-
result = zero(Tv)
474-
for k in 1:nnz(A)
475-
A.iPtr[k] == i && (result = sum_duplicate(result, A.nzVal[k]))
476-
end
477-
return result
473+
ii = searchsortedfirst(A.iPtr, convert(Ti, i))
474+
(ii > nnz(A) || A.iPtr[ii] != i) && return zero(Tv)
475+
A.nzVal[ii]
478476
end
479477

480-
# Scalar getindex methods linear-scan the minor axis rather than binary-searching
481-
# and sum across matching entries. cuSPARSE formats don't guarantee sorted indices
482-
# within a major-axis slice (e.g. SpGEMM output may leave CSR columns unsorted
483-
# within a row, and COO is only guaranteed row-sorted), nor uniqueness — duplicate
484-
# (i, j) entries are permitted and their values sum, matching the convention of
485-
# Julia's `sparse()` constructor and SciPy/CuPy. For Bool we OR instead of sum,
486-
# also matching `sparse()`, since Bool + Bool doesn't stay Bool.
487-
sum_duplicate(a, b) = a + b
488-
sum_duplicate(a::Bool, b::Bool) = a | b
489-
490478
function Base.getindex(A::CuSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
491479
@boundscheck checkbounds(A, i0, i1)
492480
r1 = Int(A.colPtr[i1])
493481
r2 = Int(A.colPtr[i1+1]-1)
494-
result = zero(T)
495-
for k in r1:r2
496-
rowvals(A)[k] == i0 && (result = sum_duplicate(result, nonzeros(A)[k]))
497-
end
498-
return result
482+
(r1 > r2) && return zero(T)
483+
r1 = searchsortedfirst(rowvals(A), i0, r1, r2, Base.Order.Forward)
484+
(r1 > r2 || rowvals(A)[r1] != i0) && return zero(T)
485+
nonzeros(A)[r1]
499486
end
500487

501488
function Base.getindex(A::CuSparseMatrixCSR{T}, i0::Integer, i1::Integer) where T
502489
@boundscheck checkbounds(A, i0, i1)
503490
c1 = Int(A.rowPtr[i0])
504491
c2 = Int(A.rowPtr[i0+1]-1)
505-
result = zero(T)
506-
for k in c1:c2
507-
A.colVal[k] == i1 && (result = sum_duplicate(result, nonzeros(A)[k]))
508-
end
509-
return result
492+
(c1 > c2) && return zero(T)
493+
c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
494+
(c1 > c2 || A.colVal[c1] != i1) && return zero(T)
495+
nonzeros(A)[c1]
510496
end
511497

512498
function Base.getindex(A::CuSparseMatrixCOO{T}, i0::Integer, i1::Integer) where T
513499
@boundscheck checkbounds(A, i0, i1)
514-
# cuSPARSE only guarantees COO is sorted by row, so binary-search the row
515-
# range but linear-scan for the column.
516500
r1 = searchsortedfirst(A.rowInd, i0, Base.Order.Forward)
517501
(r1 > length(A.rowInd) || A.rowInd[r1] > i0) && return zero(T)
518-
r2 = searchsortedlast(A.rowInd, i0, Base.Order.Forward)
519-
result = zero(T)
520-
for k in r1:r2
521-
A.colInd[k] == i1 && (result = sum_duplicate(result, nonzeros(A)[k]))
522-
end
523-
return result
502+
r2 = min(searchsortedfirst(A.rowInd, i0+1, Base.Order.Forward), length(A.rowInd))
503+
c1 = searchsortedfirst(A.colInd, i1, r1, r2, Base.Order.Forward)
504+
(c1 > r2 || c1 == length(A.colInd) + 1 || A.colInd[c1] > i1) && return zero(T)
505+
nonzeros(A)[c1]
524506
end
525507

526508
function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where T
@@ -530,11 +512,10 @@ function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where
530512
block_idx = (i0_idx - 1) * A.blockDim + i1_idx - 1
531513
c1 = Int(A.rowPtr[i0_block])
532514
c2 = Int(A.rowPtr[i0_block+1]-1)
533-
result = zero(T)
534-
for k in c1:c2
535-
A.colVal[k] == i1_block && (result = sum_duplicate(result, nonzeros(A)[k+block_idx]))
536-
end
537-
return result
515+
(c1 > c2) && return zero(T)
516+
c1 = searchsortedfirst(A.colVal, i1_block, c1, c2, Base.Order.Forward)
517+
(c1 > c2 || A.colVal[c1] != i1_block) && return zero(T)
518+
nonzeros(A)[c1+block_idx]
538519
end
539520

540521
# matrix slices

0 commit comments

Comments
 (0)