Skip to content

Commit

Permalink
Merge pull request #3 from SongY123/main
Browse files Browse the repository at this point in the history
fix bugs when indice is numpy object
  • Loading branch information
SongY123 authored Jan 13, 2025
2 parents 44e7a1a + 909ec40 commit d419b0e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion openhufu/dataset/splitters/generic/iid_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def __call__(self, dataset):
indices = np.arange(len(dataset))
np.random.shuffle(indices)
idx_slices = np.array_split(indices, self.n_clients)
subsets = [self.subset(dataset, idxs) for idxs in idx_slices]
subsets = [self.subset(dataset, [int(idx) for idx in idxs]) for idxs in idx_slices]
return subsets
2 changes: 1 addition & 1 deletion openhufu/dataset/splitters/generic/lda_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def __call__(self, dataset, get_labels_fn, alpha=0.5, **kwargs):
for client_id, split in enumerate(splits):
client_indices[client_id].extend(split)

subsets = [self.subset(dataset, idxs) for idxs in client_indices]
subsets = [self.subset(dataset, [int(idx) for idx in idxs]) for idxs in client_indices]
return subsets

0 comments on commit d419b0e

Please sign in to comment.