Skip to content

Commit

Permalink
fix(sklearn_serializer): convert sparse matrix components to the corr…
Browse files Browse the repository at this point in the history
…ect type
  • Loading branch information
raulmarindev committed Nov 7, 2024
1 parent 4fc3454 commit 85b230d
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions openmodels/serializers/sklearn_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,14 @@ def _convert_to_sklearn_types(value: Any, attr_type: str = "none") -> Any:
# Base case: if attr_type is not a list, convert value based on attr_type
if isinstance(attr_type, str):
if attr_type == "csr_matrix":
# Convert indices and indptr to int32 for SVM compatibility
return csr_matrix(
(value["data"], value["indices"], value["indptr"]),
shape=value["shape"],
(
np.array(value["data"], dtype=np.float64),
np.array(value["indices"], dtype=np.int32),
np.array(value["indptr"], dtype=np.int32),
),
shape=tuple(value["shape"]),
)
elif attr_type == "ndarray":
return np.array(value)
Expand All @@ -209,8 +214,7 @@ def _convert_to_sklearn_types(value: Any, attr_type: str = "none") -> Any:
return float(value)
elif attr_type == "str":
return str(value)
# Add other types as needed
return value # Return as-is if no specific conversion is needed
return value

# Recursive case: if attr_type is a list, process each element in value
elif isinstance(attr_type, list) and isinstance(value, list):
Expand Down

0 comments on commit 85b230d

Please sign in to comment.