Skip to content

Commit

Permalink
fix(SUPPORTED_STIMATORS): Convert indices and indptr to int32 explici…
Browse files Browse the repository at this point in the history
…tly on scipy csr_matrix
  • Loading branch information
Gnpd committed Nov 7, 2024
1 parent 237c3e4 commit aee14ba
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions openmodels/serializers/sklearn_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,18 +168,15 @@ def _convert_to_serializable_types(value: Any) -> Any:
return SklearnSerializer._array_to_list(value)
if isinstance(value, _csr.csr_matrix):
# Convert indices and indptr to int32 explicitly
csr_value = csr_matrix(
(
value.data,
value.indices.astype(np.int32),
value.indptr.astype(np.int32),
),
shape=value.shape,
)
csr_value = csr_matrix(value)
serialized_sparse_matrix = {
"data": SklearnSerializer._array_to_list(csr_value.data),
"indptr": SklearnSerializer._array_to_list(csr_value.indptr),
"indices": SklearnSerializer._array_to_list(csr_value.indices),
"indptr": SklearnSerializer._array_to_list(
csr_value.indptr.astype(np.int32)
),
"indices": SklearnSerializer._array_to_list(
csr_value.indices.astype(np.int32)
),
"shape": SklearnSerializer._array_to_list(csr_value.shape),
}
return serialized_sparse_matrix
Expand Down

0 comments on commit aee14ba

Please sign in to comment.