@@ -300,7 +300,7 @@ function Base.similar(mat::CuSparseMatrixCOO, ::Type{T}, dims::Dims{2}) where {T
300300 new_rowInd = similar (mat. rowInd)
301301 new_colInd = similar (mat. colInd)
302302 new_nzVal = similar (mat. nzVal, T)
303- return CuSparseMatrixCOO (new_rowInd, new_colInd, new_nzVal, dims)
303+ return CuSparseMatrixCOO {T, Ti} (new_rowInd, new_colInd, new_nzVal, dims)
304304end
305305
306306function Base. similar (mat:: CuSparseMatrixBSR{Tv, Ti} , :: Type{T} , dims:: Dims{2} ) where {Tv, Ti, T}
@@ -470,57 +470,39 @@ Base.getindex(A::CuSparseMatrixCSR, ::Colon, j::Integer) = CuSparseVector(sparse
470470
471471function Base. getindex (A:: CuSparseVector{Tv, Ti} , i:: Integer ) where {Tv, Ti}
472472 @boundscheck checkbounds (A, i)
473- result = zero (Tv)
474- for k in 1 : nnz (A)
475- A. iPtr[k] == i && (result = sum_duplicate (result, A. nzVal[k]))
476- end
477- return result
473+ ii = searchsortedfirst (A. iPtr, convert (Ti, i))
474+ (ii > nnz (A) || A. iPtr[ii] != i) && return zero (Tv)
475+ A. nzVal[ii]
478476end
479477
480- # Scalar getindex methods linear-scan the minor axis rather than binary-searching
481- # and sum across matching entries. cuSPARSE formats don't guarantee sorted indices
482- # within a major-axis slice (e.g. SpGEMM output may leave CSR columns unsorted
483- # within a row, and COO is only guaranteed row-sorted), nor uniqueness — duplicate
484- # (i, j) entries are permitted and their values sum, matching the convention of
485- # Julia's `sparse()` constructor and SciPy/CuPy. For Bool we OR instead of sum,
486- # also matching `sparse()`, since Bool + Bool doesn't stay Bool.
487- sum_duplicate (a, b) = a + b
488- sum_duplicate (a:: Bool , b:: Bool ) = a | b
489-
490478function Base. getindex (A:: CuSparseMatrixCSC{T} , i0:: Integer , i1:: Integer ) where T
491479 @boundscheck checkbounds (A, i0, i1)
492480 r1 = Int (A. colPtr[i1])
493481 r2 = Int (A. colPtr[i1+ 1 ]- 1 )
494- result = zero (T)
495- for k in r1: r2
496- rowvals (A)[k] == i0 && (result = sum_duplicate (result, nonzeros (A)[k]))
497- end
498- return result
482+ (r1 > r2) && return zero (T)
483+ r1 = searchsortedfirst (rowvals (A), i0, r1, r2, Base. Order. Forward)
484+ (r1 > r2 || rowvals (A)[r1] != i0) && return zero (T)
485+ nonzeros (A)[r1]
499486end
500487
501488function Base. getindex (A:: CuSparseMatrixCSR{T} , i0:: Integer , i1:: Integer ) where T
502489 @boundscheck checkbounds (A, i0, i1)
503490 c1 = Int (A. rowPtr[i0])
504491 c2 = Int (A. rowPtr[i0+ 1 ]- 1 )
505- result = zero (T)
506- for k in c1: c2
507- A. colVal[k] == i1 && (result = sum_duplicate (result, nonzeros (A)[k]))
508- end
509- return result
492+ (c1 > c2) && return zero (T)
493+ c1 = searchsortedfirst (A. colVal, i1, c1, c2, Base. Order. Forward)
494+ (c1 > c2 || A. colVal[c1] != i1) && return zero (T)
495+ nonzeros (A)[c1]
510496end
511497
512498function Base. getindex (A:: CuSparseMatrixCOO{T} , i0:: Integer , i1:: Integer ) where T
513499 @boundscheck checkbounds (A, i0, i1)
514- # cuSPARSE only guarantees COO is sorted by row, so binary-search the row
515- # range but linear-scan for the column.
516500 r1 = searchsortedfirst (A. rowInd, i0, Base. Order. Forward)
517501 (r1 > length (A. rowInd) || A. rowInd[r1] > i0) && return zero (T)
518- r2 = searchsortedlast (A. rowInd, i0, Base. Order. Forward)
519- result = zero (T)
520- for k in r1: r2
521- A. colInd[k] == i1 && (result = sum_duplicate (result, nonzeros (A)[k]))
522- end
523- return result
502+ r2 = min (searchsortedfirst (A. rowInd, i0+ 1 , Base. Order. Forward), length (A. rowInd))
503+ c1 = searchsortedfirst (A. colInd, i1, r1, r2, Base. Order. Forward)
504+ (c1 > r2 || c1 == length (A. colInd) + 1 || A. colInd[c1] > i1) && return zero (T)
505+ nonzeros (A)[c1]
524506end
525507
526508function Base. getindex (A:: CuSparseMatrixBSR{T} , i0:: Integer , i1:: Integer ) where T
@@ -530,11 +512,10 @@ function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where
530512 block_idx = (i0_idx - 1 ) * A. blockDim + i1_idx - 1
531513 c1 = Int (A. rowPtr[i0_block])
532514 c2 = Int (A. rowPtr[i0_block+ 1 ]- 1 )
533- result = zero (T)
534- for k in c1: c2
535- A. colVal[k] == i1_block && (result = sum_duplicate (result, nonzeros (A)[k+ block_idx]))
536- end
537- return result
515+ (c1 > c2) && return zero (T)
516+ c1 = searchsortedfirst (A. colVal, i1_block, c1, c2, Base. Order. Forward)
517+ (c1 > c2 || A. colVal[c1] != i1_block) && return zero (T)
518+ nonzeros (A)[c1+ block_idx]
538519end
539520
540521# matrix slices
0 commit comments