Skip to content

Commit

Permalink
optimize pytorch dataset loader
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhao062 committed Jul 2, 2024
1 parent fbce11b commit 6e5b103
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 48 deletions.
32 changes: 7 additions & 25 deletions pyod/models/ae1svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,7 @@

from .base import BaseDetector
from ..utils.stat_models import pairwise_distances_no_broadcast
from ..utils.torch_utility import get_activation_by_name


class PyODDataset(torch.utils.data.Dataset):
"""PyOD Dataset class for PyTorch Dataloader"""

def __init__(self, X, y=None, mean=None, std=None):
super(PyODDataset, self).__init__()
self.X = X
self.mean = mean
self.std = std

def __len__(self):
return self.X.shape[0]

def __getitem__(self, idx):
sample = self.X[idx, :]
if self.mean is not None and self.std is not None:
sample = (sample - self.mean) / self.std
return torch.from_numpy(sample), idx
from ..utils.torch_utility import get_activation_by_name, TorchDataset


class InnerAE1SVM(nn.Module):
Expand Down Expand Up @@ -135,9 +116,10 @@ def fit(self, X, y=None):
if self.preprocessing:
self.mean, self.std = np.mean(X, axis=0), np.std(X, axis=0)
self.std[self.std == 0] = 1e-6 # Avoid division by zero
train_set = PyODDataset(X=X, mean=self.mean, std=self.std)
train_set = TorchDataset(X=X, mean=self.mean, std=self.std,
return_idx=True)
else:
train_set = PyODDataset(X=X)
train_set = TorchDataset(X=X, return_idx=True)

train_loader = torch.utils.data.DataLoader(train_set,
batch_size=self.batch_size,
Expand Down Expand Up @@ -193,9 +175,9 @@ def _train_autoencoder(self, train_loader):
def decision_function(self, X):
check_is_fitted(self, ['model', 'best_model_dict'])
X = check_array(X)
dataset = PyODDataset(X=X, mean=self.mean,
std=self.std) if self.preprocessing else (
PyODDataset(X=X))
dataset = TorchDataset(X=X, mean=self.mean,
std=self.std, return_idx=True) \
if self.preprocessing else (TorchDataset(X=X, return_idx=True))
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=self.batch_size,
shuffle=False)
Expand Down
24 changes: 5 additions & 19 deletions pyod/models/mo_gaal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,6 @@
from .gaal_base import create_discriminator, create_generator


class PyODDataset(torch.utils.data.Dataset):
"""Custom Dataset for handling data operations in PyTorch for outlier detection."""

def __init__(self, X):
super(PyODDataset, self).__init__()
self.X = torch.tensor(X, dtype=torch.float32)

def __len__(self):
return len(self.X)

def __getitem__(self, idx):
return self.X[idx]


class MO_GAAL(BaseDetector):
"""Multi-Objective Generative Adversarial Active Learning.
Expand Down Expand Up @@ -143,8 +129,8 @@ def fit(self, X, y=None):

dataloader = DataLoader(TensorDataset(
torch.tensor(X, dtype=torch.float32).to(self.device)),
batch_size=min(500, data_size),
shuffle=True)
batch_size=min(500, data_size),
shuffle=True)

stop = 0

Expand All @@ -166,15 +152,15 @@ def fit(self, X, y=None):
if i != (self.k - 1):
noise_start = int(
(((self.k + (self.k - i + 1)) * i) / 2) * (
batch_size // block))
batch_size // block))
noise_end = int(
(((self.k + (self.k - i)) * (i + 1)) / 2) * (
batch_size // block))
batch_size // block))
names['noise' + str(i)] = noise[noise_start:noise_end]
else:
noise_start = int(
(((self.k + (self.k - i + 1)) * i) / 2) * (
batch_size // block))
batch_size // block))
names['noise' + str(i)] = noise[noise_start:batch_size]

names['generated_data' + str(i)] = names[
Expand Down
17 changes: 13 additions & 4 deletions pyod/utils/torch_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@

class TorchDataset(torch.utils.data.Dataset):
def __init__(self, X, y=None, mean=None, std=None, eps=1e-8,
X_dtype=torch.float32, y_dtype=torch.float32):
X_dtype=torch.float32, y_dtype=torch.float32,
return_idx=False):
self.X = X
self.y = y
self.mean = mean
self.std = std
self.eps = eps
self.X_dtype = X_dtype
self.y_dtype = y_dtype
self.return_idx = return_idx

def __len__(self):
return len(self.X)
Expand All @@ -31,10 +33,17 @@ def __getitem__(self, idx):
sample = (sample - self.mean) / (self.std + self.eps)

if self.y is not None:
return torch.as_tensor(sample, dtype=self.X_dtype), \
torch.as_tensor(self.y[idx], dtype=self.y_dtype)
if self.return_idx:
return torch.as_tensor(sample, dtype=self.X_dtype), \
torch.as_tensor(self.y[idx], dtype=self.y_dtype), idx
else:
return torch.as_tensor(sample, dtype=self.X_dtype), \
torch.as_tensor(self.y[idx], dtype=self.y_dtype)
else:
return torch.as_tensor(sample, dtype=self.X_dtype)
if self.return_idx:
return torch.as_tensor(sample, dtype=self.X_dtype), idx
else:
return torch.as_tensor(sample, dtype=self.X_dtype)


class LinearBlock(nn.Module):
Expand Down

0 comments on commit 6e5b103

Please sign in to comment.