Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update deep_svdd.py #607

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 124 additions & 119 deletions pyod/models/deep_svdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,75 +73,91 @@ class InnerDeepSVDD(nn.Module):
def __init__(self, n_features, use_ae,
hidden_neurons, hidden_activation,
output_activation,
dropout_rate, l2_regularizer):
dropout_rate, l2_regularizer, feature_type, input_shape=None):
super(InnerDeepSVDD, self).__init__()
self.n_features = n_features
self.use_ae = use_ae
self.hidden_neurons = hidden_neurons or [64, 32]
self.hidden_activation = hidden_activation
self.output_activation = output_activation
self.dropout_rate = dropout_rate
self.l2_regularizer = l2_regularizer
self.model = self._build_model()
self.feature_type = feature_type
self.input_shape = input_shape
if self.feature_type == "obs":
self.embedder_features = n_features
self.linear_features = n_features
self.embedder = self._build_embedder()
elif self.feature_type in ["hidden", "dist"]:
self.linear_features = self.input_shape[1]
elif self.feature_type == "hidden_obs":
self.embedder_features = n_features
self.linear_features = n_features + self.input_shape[-1]
self.embedder = self._build_embedder()
self.fc_part = self._build_fc()
self.c = None # Center of the hypersphere for DeepSVDD

def _init_c(self, X_norm, eps=0.1):
intermediate_output = {}
hook_handle = self.model._modules.get(
'net_output').register_forward_hook(
lambda module, input, output: intermediate_output.update(
{'net_output': output})
hook_handle = self.fc_part._modules.get('net_output').register_forward_hook(
lambda module, input, output: intermediate_output.update({'net_output': output})
)
output = self.model(X_norm)

if self.feature_type in ["obs", "hidden", "dist"]:
output = self.forward(X_norm)
elif self.feature_type == "hidden_obs":
output = self.forward([X_norm[0], X_norm[1]])
out = intermediate_output['net_output']
hook_handle.remove()

self.c = torch.mean(out, dim=0)
self.c[(torch.abs(self.c) < eps) & (self.c < 0)] = -eps
self.c[(torch.abs(self.c) < eps) & (self.c > 0)] = eps

def _build_model(self):
def _build_embedder(self):
if len(self.input_shape) == 3:
channels = self.input_shape[0]
else:
channels = self.input_shape[1]
layers = nn.Sequential()
layers.add_module('cnn_layer1', nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1))
layers.add_module('cnn_activation1', nn.ReLU())
layers.add_module('cnn_pool', nn.MaxPool2d(kernel_size=2, stride=2))
layers.add_module('cnn_layer2', nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1))
layers.add_module('cnn_activation2', nn.ReLU())
layers.add_module('cnn_adaptive_pool', nn.AdaptiveMaxPool2d((32, 32)))
layers.add_module('flatten', nn.Flatten())
layers.add_module('cnn_fc', nn.Linear(32 * 32 * 32, self.embedder_features, bias=False))
layers.add_module('cnn_fc_activation', nn.ReLU())
return layers

def _build_fc(self):
layers = nn.Sequential()
layers.add_module('input_layer',
nn.Linear(self.n_features, self.hidden_neurons[0],
bias=False))
layers.add_module('hidden_activation_e0',
get_activation_by_name(self.hidden_activation))
layers.add_module('input_layer', nn.Linear(self.linear_features, self.hidden_neurons[0], bias=False))
layers.add_module('hidden_activation_e0', get_activation_by_name(self.hidden_activation))
for i in range(1, len(self.hidden_neurons) - 1):
layers.add_module(f'hidden_layer_e{i}',
nn.Linear(self.hidden_neurons[i - 1],
self.hidden_neurons[i], bias=False))
layers.add_module(f'hidden_activation_e{i}',
get_activation_by_name(self.hidden_activation))
layers.add_module(f'hidden_dropout_e{i}',
nn.Dropout(self.dropout_rate))
layers.add_module(f'net_output', nn.Linear(self.hidden_neurons[-2],
self.hidden_neurons[-1],
bias=False))
layers.add_module(f'hidden_activation_e{len(self.hidden_neurons)}',
get_activation_by_name(self.hidden_activation))
layers.add_module(f'hidden_layer_e{i}', nn.Linear(self.hidden_neurons[i - 1], self.hidden_neurons[i], bias=False))
layers.add_module(f'hidden_activation_e{i}', get_activation_by_name(self.hidden_activation))
layers.add_module(f'hidden_dropout_e{i}', nn.Dropout(self.dropout_rate))
layers.add_module('net_output', nn.Linear(self.hidden_neurons[-2], self.hidden_neurons[-1], bias=False))
layers.add_module(f'hidden_activation_e{len(self.hidden_neurons)}', get_activation_by_name(self.hidden_activation))

if self.use_ae:
# Add reverse layers for the autoencoder if needed
for j in range(len(self.hidden_neurons) - 1, 0, -1):
layers.add_module(f'hidden_layer_d{j}',
nn.Linear(self.hidden_neurons[j],
self.hidden_neurons[j - 1],
bias=False))
layers.add_module(f'hidden_activation_d{j}',
get_activation_by_name(
self.hidden_activation))
layers.add_module(f'hidden_dropout_d{j}',
nn.Dropout(self.dropout_rate))
layers.add_module(f'output_layer',
nn.Linear(self.hidden_neurons[0],
self.n_features, bias=False))
layers.add_module(f'output_activation',
get_activation_by_name(self.output_activation))
layers.add_module(f'hidden_layer_d{j}', nn.Linear(self.hidden_neurons[j], self.hidden_neurons[j - 1], bias=False))
layers.add_module(f'hidden_activation_d{j}', get_activation_by_name(self.hidden_activation))
layers.add_module(f'hidden_dropout_d{j}', nn.Dropout(self.dropout_rate))
layers.add_module('output_layer', nn.Linear(self.hidden_neurons[0], self.n_features, bias=False))
layers.add_module('output_activation', get_activation_by_name(self.output_activation))

return layers

def forward(self, x):
return self.model(x)

if self.feature_type == "obs":
x = self.embedder(x)
elif self.feature_type == "hidden_obs":
features = self.embedder(x[0])
x = torch.cat([features, x[1]], dim=-1)
x = self.fc_part(x)
return x

class DeepSVDD(BaseDetector):
"""Deep One-Class Classifier with AutoEncoder (AE) is a type of neural
Expand All @@ -155,7 +171,7 @@ class DeepSVDD(BaseDetector):

Parameters
----------
n_features: int,
n_features: int,
Number of features in the input data.

c: float, optional (default='forwad_nn_pass')
Expand Down Expand Up @@ -240,9 +256,9 @@ def __init__(self, n_features, c=None, use_ae=False, hidden_neurons=None,
hidden_activation='relu',
output_activation='sigmoid', optimizer='adam', epochs=100,
batch_size=32,
dropout_rate=0.2, l2_regularizer=0.1, validation_size=0.1,
dropout_rate=0.2, l2_regularizer=0.1, feature_type="obs", validation_size=0.1,
preprocessing=True,
verbose=1, random_state=None, contamination=0.1):
verbose=1, random_state=None, contamination=0.1, input_shape=None):
super(DeepSVDD, self).__init__(contamination=contamination)

self.n_features = n_features
Expand All @@ -256,24 +272,38 @@ def __init__(self, n_features, c=None, use_ae=False, hidden_neurons=None,
self.batch_size = batch_size
self.dropout_rate = dropout_rate
self.l2_regularizer = l2_regularizer
self.feature_type = feature_type
self.validation_size = validation_size
self.preprocessing = preprocessing
self.verbose = verbose
self.random_state = random_state
self.model_ = None
self.best_model_dict = None
self.input_shape = input_shape

if self.random_state is not None:
torch.manual_seed(self.random_state)
check_parameter(dropout_rate, 0, 1, param_name='dropout_rate',
include_left=True)
check_parameter(dropout_rate, 0, 1, param_name='dropout_rate', include_left=True)

# Initialize the DeepSVDD model with updated input shape
self.model_ = InnerDeepSVDD(
n_features=self.n_features, # Now determined by CNN output
use_ae=self.use_ae,
hidden_neurons=self.hidden_neurons,
hidden_activation=self.hidden_activation,
output_activation=self.output_activation,
dropout_rate=self.dropout_rate,
l2_regularizer=self.l2_regularizer,
feature_type=self.feature_type,
input_shape=self.input_shape,
)

def fit(self, X, y=None):
"""Fit detector. y is ignored in unsupervised methods.

Parameters
----------
X : numpy array of shape (n_samples, n_features)
X : list or numpy array of shape (n_samples, channels, height, width)
The input samples.

y : Ignored
Expand All @@ -284,81 +314,51 @@ def fit(self, X, y=None):
self : object
Fitted estimator.
"""
# validate inputs X and y (optional)
X = check_array(X)
self._set_n_classes(y)

# Verify and construct the hidden units
self.n_samples_, self.n_features_ = X.shape[0], X.shape[1]
X_norm = self.normalization(X)

# Standardize data for better performance
if self.preprocessing:
self.scaler_ = StandardScaler()
X_norm = self.scaler_.fit_transform(X)
else:
X_norm = np.copy(X)

# Shuffle the data for validation as Keras do not shuffling for
# Validation Split
np.random.shuffle(X_norm)

# Validate and complete the number of hidden neurons
if np.min(self.hidden_neurons) > self.n_features_ and self.use_ae:
raise ValueError("The number of neurons should not exceed "
"the number of features")

# Build DeepSVDD model & fit with X
self.model_ = InnerDeepSVDD(self.n_features, use_ae=self.use_ae,
hidden_neurons=self.hidden_neurons,
hidden_activation=self.hidden_activation,
output_activation=self.output_activation,
dropout_rate=self.dropout_rate,
l2_regularizer=self.l2_regularizer)
X_norm = torch.tensor(X_norm, dtype=torch.float32)
if self.c is None:
self.c = 0.0
self.model_._init_c(X_norm)

# Predict on X itself and calculate the reconstruction error as
# the outlier scores. Noted X_norm was shuffled has to recreate
if self.preprocessing:
X_norm = self.scaler_.transform(X)
# Prepare DataLoader for batch processing
if self.feature_type == "hidden_obs":
dataset = TensorDataset(*X_norm, *X_norm)
else:
X_norm = np.copy(X)

X_norm = torch.tensor(X_norm, dtype=torch.float32)
dataset = TensorDataset(X_norm, X_norm)
dataloader = DataLoader(dataset, batch_size=self.batch_size,
shuffle=True)
dataset = TensorDataset(X_norm, X_norm)
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

best_loss = float('inf')
best_model_dict = None

optimizer = optimizer_dict[self.optimizer](self.model_.parameters(),
weight_decay=self.l2_regularizer)
w_d = 1e-6 * sum(
[torch.linalg.norm(w) for w in self.model_.parameters()])
optimizer = optimizer_dict[self.optimizer](self.model_.parameters(), weight_decay=self.l2_regularizer)
# w_d = 1e-6 * sum([torch.linalg.norm(w) for w in self.model_.parameters()])

for epoch in range(self.epochs):
self.model_.train()
epoch_loss = 0
for batch_x, _ in dataloader:
optimizer.zero_grad()
for batch in dataloader:
if self.feature_type == "hidden_obs":
batch_x = batch[0], batch[1]
else:
batch_x = batch[0]
outputs = self.model_(batch_x)
dist = torch.sum((outputs - self.c) ** 2, dim=-1)

w_d = 1e-6 * sum([torch.linalg.norm(w) for w in self.model_.parameters()])

if self.use_ae:
loss = torch.mean(dist) + w_d + torch.mean(
torch.square(outputs - batch_x))
loss = torch.mean(dist) + w_d + torch.mean(torch.square(outputs - batch_x))
else:
loss = torch.mean(dist) + w_d

# loss.backward()
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
if epoch_loss < best_loss:
best_loss = epoch_loss
best_model_dict = self.model_.state_dict()
epoch_loss /= len(dataloader)
print(f"Epoch {epoch + 1}/{self.epochs}, Loss: {epoch_loss}")
if epoch_loss < best_loss:
best_loss = epoch_loss
best_model_dict = self.model_.state_dict()
self.best_model_dict = best_model_dict

self.decision_scores_ = self.decision_function(X)
Expand All @@ -368,32 +368,37 @@ def fit(self, X, y=None):
def decision_function(self, X):
"""Predict raw anomaly score of X using the fitted detector.

The anomaly score of an input sample is computed based on different
detector algorithms. For consistency, outliers are assigned with
larger anomaly scores.
The anomaly score of an input sample is computed based on the DeepSVDD model.
Outliers are assigned with larger anomaly scores.

Parameters
----------
X : numpy array of shape (n_samples, n_features)
The training input samples. Sparse matrices are accepted only
if they are supported by the base estimator.
X : numpy array of shape (n_samples, channels, height, width)
The input samples.

Returns
-------
anomaly_scores : numpy array of shape (n_samples,)
The anomaly score of the input samples.
"""
# check_is_fitted(self, ['model_', 'history_'])
X = check_array(X)

if self.preprocessing:
X_norm = self.scaler_.transform(X)
else:
X_norm = np.copy(X)
X_norm = torch.tensor(X_norm, dtype=torch.float32)
# Normalize data if pixel values are in [0, 255] range
X = self.normalization(X)
self.model_.eval()
with torch.no_grad():
outputs = self.model_(X_norm)
outputs = self.model_(X)
dist = torch.sum((outputs - self.c) ** 2, dim=-1)
anomaly_scores = dist.numpy()
anomaly_scores = dist.cpu().numpy()
return anomaly_scores

def normalization(self, X):
if self.feature_type in ["obs", "hidden_obs"]:
X_img = X if self.feature_type == "obs" else X[0]
# Normalize the image data if pixel values are in the range [0, 255]
if X_img.max() > 1:
X_img = X_img / 255.0
X_norm = X_img if self.feature_type == "obs" else [X_img, X[1]]
elif self.feature_type in ["hidden", "dist"]:
X_norm = X
else:
raise ValueError(f"Unknown feature type: {self.feature_type}")
return X_norm
Loading