I would like to complete the answer given by @seberg. If you want to set nonzero values to zero, you should update the structure of the CSR matrix as well, rather than only overwriting its .data attribute.
The current behavior of that answer's code is:
>>> import scipy.sparse
>>> import numpy as np
>>> A = scipy.sparse.csr_matrix([[0,1,0], [2,0,3], [0,0,0], [4,0,0]])
>>> A.toarray()
array([[0, 1, 0],
       [2, 0, 3],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> csr_row_set_nz_to_val(A, 1)
>>> A.toarray()
array([[0, 1, 0],
       [0, 0, 0],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> A.data
array([1, 0, 0, 4], dtype=int64)
>>> A.indices
array([1, 0, 2, 0], dtype=int32)
>>> A.indptr
array([0, 1, 3, 3, 4], dtype=int32)
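To make the problem visible, here is a minimal sketch that compares the number of stored entries with the number of entries that are really nonzero. It assumes the original helper simply overwrites the slice of .data belonging to the row, which is what the output above suggests; the variable names `stored` and `actual` are only illustrative.

```python
import numpy as np
import scipy.sparse

A = scipy.sparse.csr_matrix([[0, 1, 0], [2, 0, 3], [0, 0, 0], [4, 0, 0]])

# Assumed behavior of the original helper: overwrite the stored values of row 1 in place.
A.data[A.indptr[1]:A.indptr[2]] = 0

stored = A.nnz              # number of stored entries, explicit zeros included
actual = A.count_nonzero()  # number of entries whose value is really nonzero
print(stored, actual)
```

The two counts differ because the zeroed entries are still part of the sparsity pattern.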
Because we are dealing with sparse matrices, we do not want explicit zeros stored in the A.data
array. I think csr_row_set_nz_to_val should be modified
as follows:
def csr_row_set_nz_to_val(csr, row, value=0):
    """Set all stored elements of one row of a CSR matrix (elements currently
    in the sparsity pattern) to the given value. Mostly useful to set them to 0.
    """
    if not isinstance(csr, scipy.sparse.csr_matrix):
        raise ValueError("Matrix given must be of CSR format.")
    start, stop = csr.indptr[row], csr.indptr[row+1]  # bounds of this row in data/indices
    if value == 0:
        csr.data = np.delete(csr.data, range(start, stop))  # drop nnz values
        csr.indices = np.delete(csr.indices, range(start, stop))  # drop nnz column indices
        csr.indptr[(row+1):] -= stop - start  # shift the pointers of the following rows
    else:
        csr.data[start:stop] = value  # replace nnz values by another nnz value
With this version, we get instead:
>>> A = scipy.sparse.csr_matrix([[0,1,0], [2,0,3], [0,0,0], [4,0,0]])
>>> csr_row_set_nz_to_val(A, 1)
>>> A.toarray()
array([[0, 1, 0],
       [0, 0, 0],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> A.data
array([1, 4], dtype=int64)
>>> A.indices
array([1, 0], dtype=int32)
>>> A.indptr
array([0, 1, 1, 1, 2], dtype=int32)
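As a possible alternative to editing data, indices and indptr by hand, SciPy's CSR matrices provide an eliminate_zeros() method that removes stored zeros in place. The sketch below is not part of the function above; the helper name `csr_zero_row` is hypothetical. Note that eliminate_zeros() drops every explicit zero in the matrix, not only those in the chosen row.

```python
import numpy as np
import scipy.sparse


def csr_zero_row(csr, row):
    """Zero out one row of a CSR matrix, then drop the stored zeros.

    Hypothetical helper name, used here only for illustration.
    """
    start, stop = csr.indptr[row], csr.indptr[row + 1]
    csr.data[start:stop] = 0  # leaves explicit zeros in the sparsity pattern
    csr.eliminate_zeros()     # removes all explicit zeros and updates indices/indptr in place


A = scipy.sparse.csr_matrix([[0, 1, 0], [2, 0, 3], [0, 0, 0], [4, 0, 0]])
csr_zero_row(A, 1)
print(A.data, A.indices, A.indptr)
```

On this example the resulting data, indices and indptr arrays should match the ones shown above.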