Skip to content

Commit

Permalink
.get() before returning for uvm compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
rishic3 committed Jan 15, 2025
1 parent 96288e8 commit 79a2781
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions python/src/spark_rapids_ml/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,19 +1169,19 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
yield pd.DataFrame(
data=[
{
"embedding_": list(embedding[start:end]),
"indices": list(indices),
"indptr": list(indptr),
"data": list(data),
"embedding_": list(embedding[start:end].get()),
"indices": list(indices.get()),
"indptr": list(indptr.get()),
"data": list(data.get()),
"shape": [end - start, dimension],
}
]
)
else:
yield pd.DataFrame(
{
"embedding_": list(embedding[start:end]),
"raw_data_": list(raw_data[start:end]),
"embedding_": list(embedding[start:end].get()),
"raw_data_": list(raw_data[start:end].get()),
}
)

Expand Down

0 comments on commit 79a2781

Please sign in to comment.