Commit 0586b7d
YGZWQZD committed Sep 28, 2022 · 1 parent 724b319
Showing 8 changed files with 784 additions and 3 deletions.
Empty file added Examples/CIFAR10_imbalance.py
138 changes: 138 additions & 0 deletions LAMDA_SSL/Algorithm/Classification/ADFixMatch.py
@@ -0,0 +1,138 @@
import copy
from LAMDA_SSL.Base.InductiveEstimator import InductiveEstimator
from LAMDA_SSL.Base.DeepModelMixin import DeepModelMixin
from sklearn.base import ClassifierMixin
import LAMDA_SSL.Config.FixMatch as config
from LAMDA_SSL.Loss.Cross_Entropy import Cross_Entropy
from LAMDA_SSL.Loss.Semi_Supervised_Loss import Semi_Supervised_Loss

import torch

class FixMatch(InductiveEstimator,DeepModelMixin,ClassifierMixin):
    def __init__(self,
                 threshold=config.threshold,
                 lambda_u=config.lambda_u,
                 T=config.T,
                 mu=config.mu,
                 weight_decay=config.weight_decay,
                 ema_decay=config.ema_decay,
                 epoch=config.epoch,
                 num_it_epoch=config.num_it_epoch,
                 num_it_total=config.num_it_total,
                 eval_epoch=config.eval_epoch,
                 eval_it=config.eval_it,
                 optimizer=config.optimizer,
                 scheduler=config.scheduler,
                 device=config.device,
                 train_dataset=config.train_dataset,
                 labeled_dataset=config.labeled_dataset,
                 unlabeled_dataset=config.unlabeled_dataset,
                 valid_dataset=config.valid_dataset,
                 test_dataset=config.test_dataset,
                 train_dataloader=config.train_dataloader,
                 valid_dataloader=config.valid_dataloader,
                 test_dataloader=config.test_dataloader,
                 augmentation=config.augmentation,
                 network=config.network,
                 train_sampler=config.train_sampler,
                 train_batch_sampler=config.train_batch_sampler,
                 valid_sampler=config.valid_sampler,
                 valid_batch_sampler=config.valid_batch_sampler,
                 test_sampler=config.test_sampler,
                 test_batch_sampler=config.test_batch_sampler,
                 labeled_dataloader=config.labeled_dataloader,
                 unlabeled_dataloader=config.unlabeled_dataloader,
                 labeled_sampler=config.labeled_sampler,
                 unlabeled_sampler=config.unlabeled_sampler,
                 labeled_batch_sampler=config.labeled_batch_sampler,
                 unlabeled_batch_sampler=config.unlabeled_batch_sampler,
                 parallel=config.parallel,
                 evaluation=config.evaluation,
                 file=config.file,
                 verbose=config.verbose
                 ):

        # >> Parameter:
        # >> - threshold: The confidence threshold for selecting pseudo-labeled samples.
        # >> - lambda_u: The weight of the unsupervised loss.
        # >> - T: The sharpening temperature for the pseudo-label distribution.
        # >> - mu: The ratio of unlabeled to labeled samples within each mini-batch.
        # >> - weight_decay: The weight decay coefficient used during optimization.
        # >> - ema_decay: The decay rate of the exponential moving average of model parameters.

        DeepModelMixin.__init__(self,train_dataset=train_dataset,
                                valid_dataset=valid_dataset,
                                labeled_dataset=labeled_dataset,
                                unlabeled_dataset=unlabeled_dataset,
                                test_dataset=test_dataset,
                                train_dataloader=train_dataloader,
                                valid_dataloader=valid_dataloader,
                                test_dataloader=test_dataloader,
                                augmentation=augmentation,
                                network=network,
                                train_sampler=train_sampler,
                                train_batch_sampler=train_batch_sampler,
                                valid_sampler=valid_sampler,
                                valid_batch_sampler=valid_batch_sampler,
                                test_sampler=test_sampler,
                                test_batch_sampler=test_batch_sampler,
                                labeled_dataloader=labeled_dataloader,
                                unlabeled_dataloader=unlabeled_dataloader,
                                labeled_sampler=labeled_sampler,
                                unlabeled_sampler=unlabeled_sampler,
                                labeled_batch_sampler=labeled_batch_sampler,
                                unlabeled_batch_sampler=unlabeled_batch_sampler,
                                epoch=epoch,
                                num_it_epoch=num_it_epoch,
                                num_it_total=num_it_total,
                                eval_epoch=eval_epoch,
                                eval_it=eval_it,
                                mu=mu,
                                weight_decay=weight_decay,
                                ema_decay=ema_decay,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                device=device,
                                evaluation=evaluation,
                                parallel=parallel,
                                file=file,
                                verbose=verbose
                                )
        self.lambda_u=lambda_u
        self.threshold=threshold
        self.T=T
        self.weight_decay=weight_decay
        self._estimator_type=ClassifierMixin._estimator_type

    def init_transform(self):
        # Duplicate the unlabeled pipeline so every unlabeled sample yields two views.
        self._train_dataset.add_unlabeled_transform(copy.copy(self.train_dataset.unlabeled_transform),dim=0,x=1)
        # Weak augmentation for labeled data and the first unlabeled view;
        # strong augmentation for the second unlabeled view.
        self._train_dataset.add_transform(self.weak_augmentation,dim=1,x=0,y=0)
        self._train_dataset.add_unlabeled_transform(self.weak_augmentation,dim=1,x=0,y=0)
        self._train_dataset.add_unlabeled_transform(self.strong_augmentation,dim=1,x=1,y=0)

    def train(self,lb_X,lb_y,ulb_X,lb_idx=None,ulb_idx=None,*args,**kwargs):
        lb_X=lb_X[0] if isinstance(lb_X,(tuple,list)) else lb_X
        lb_y=lb_y[0] if isinstance(lb_y,(tuple,list)) else lb_y
        # Each unlabeled sample arrives as a weakly and a strongly augmented view.
        w_ulb_X,s_ulb_X=ulb_X[0],ulb_X[1]
        batch_size = lb_X.shape[0]
        # A single forward pass over the concatenated labeled and unlabeled views.
        inputs=torch.cat((lb_X, w_ulb_X, s_ulb_X))
        logits = self._network(inputs)
        lb_logits = logits[:batch_size]
        w_ulb_logits, s_ulb_logits = logits[batch_size:].chunk(2)
        train_result=(lb_logits,lb_y,w_ulb_logits,s_ulb_logits)
        return train_result

    def get_loss(self,train_result,*args,**kwargs):
        lb_logits, lb_y, w_ulb_logits, s_ulb_logits = train_result
        # Supervised cross-entropy on the labeled batch.
        sup_loss=Cross_Entropy(reduction='mean')(logits=lb_logits,targets=lb_y)
        # Pseudo-labels come from the weakly augmented view, sharpened by temperature T.
        pseudo_label = torch.softmax(w_ulb_logits.detach() / self.T, dim=-1)
        max_probs, targets_u = torch.max(pseudo_label, dim=-1)
        # Only predictions whose confidence reaches the threshold contribute.
        mask = max_probs.ge(self.threshold).float()
        # Consistency: the strongly augmented view is trained to match the pseudo-labels.
        unsup_loss = (Cross_Entropy(reduction='none')(s_ulb_logits, targets_u) * mask).mean()
        loss=Semi_Supervised_Loss(lambda_u=self.lambda_u)(sup_loss,unsup_loss)
        return loss

    def predict(self,X=None,valid=None):
        return DeepModelMixin.predict(self,X=X,valid=valid)
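
Since the new class is a scikit-learn-style inductive estimator, it is driven through fit and predict. A minimal usage sketch follows; the data arrays and their shapes are placeholders, and it assumes the default config objects imported above plus the fit(X, y, unlabeled_X) interface that Assemble.fit below also exposes. Whether raw float arrays pass through unchanged depends on the configured dataset transforms.

import numpy as np
from LAMDA_SSL.Algorithm.Classification.ADFixMatch import FixMatch

# Placeholder arrays; real usage supplies images matching the configured network.
labeled_X = np.random.rand(40, 3, 32, 32).astype('float32')
labeled_y = np.random.randint(0, 10, 40)
unlabeled_X = np.random.rand(280, 3, 32, 32).astype('float32')
test_X = np.random.rand(20, 3, 32, 32).astype('float32')

model = FixMatch(threshold=0.95, lambda_u=1.0, T=0.5, mu=7)
model.fit(X=labeled_X, y=labeled_y, unlabeled_X=unlabeled_X)
pred_y = model.predict(X=test_X)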
8 changes: 5 additions & 3 deletions LAMDA_SSL/Algorithm/Classification/Assemble.py
Original file line number Diff line number Diff line change
@@ -66,7 +66,8 @@ def fit(self,X,y,unlabeled_X):
         for _ in range(l+u):
             if _y_all[_]!=y_all[_]:
                 epsilon+=sample_weight[_]
-        w=np.log((1-epsilon)/epsilon)*0.5
+
+        w=np.log((1-epsilon)/(epsilon+1e-8))*0.5
         self.w.append(w)
         if epsilon>0.5:
             break
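
The added 1e-8 keeps the AdaBoost-style ensemble weight finite when the weighted error epsilon is exactly 0, i.e. when no sample is misclassified, where the old expression divided by zero. A quick illustration of the patched failure mode (the value is illustrative):

import numpy as np

epsilon = 0.0
# Old: np.log((1 - epsilon) / epsilon) * 0.5 -> divide-by-zero, w = inf
w = np.log((1 - epsilon) / (epsilon + 1e-8)) * 0.5
print(w)  # ~9.21, a large but finite weight for a perfect round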
@@ -80,13 +81,14 @@ def fit(self,X,y,unlabeled_X):
         else:
             alpha=self.alpha
         sample_weight=alpha*np.exp(-logits)
-        sample_weight=sample_weight/sample_weight.sum()
+        sample_weight=(sample_weight+1e-8)/(sample_weight+1e-8).sum()
         # num_sample=min((sample_weight!=0).sum(),l)
         idx_sample=np.random.choice(l+u,l,False,p=sample_weight.tolist())
         X_sample=X_all[idx_sample]
         y_sample=y_all[idx_sample]
         sample_weight_sample=sample_weight[idx_sample]
         classfier=copy.deepcopy(self.base_estimator)
-        classfier.fit(X_sample,y_sample,sample_weight_sample)
+        classfier.fit(X_sample,y_sample,sample_weight_sample*X_sample.shape[0])
     return self

def evaluate(self,X,y=None):
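
The other two changes work together: alpha*np.exp(-logits) can underflow to a vector containing zeros, in which case the old normalization handed np.random.choice a probability vector with zero entries and sampling without replacement could fail outright. The smoothing keeps every probability strictly positive, and the appended *X_sample.shape[0] rescales the tiny normalized probabilities toward the unit-scale sample weights sklearn estimators conventionally expect. A small sketch under assumed illustrative values:

import numpy as np

alpha = 1.0
logits = np.array([0.0, 800.0, 800.0])   # very confident -> np.exp underflows to 0
sample_weight = alpha * np.exp(-logits)  # array([1., 0., 0.])

# Old normalization keeps the zeros; drawing 2 items without replacement from
# p=[1., 0., 0.] raises "Fewer non-zero entries in p than size".
sample_weight = (sample_weight + 1e-8) / (sample_weight + 1e-8).sum()
idx_sample = np.random.choice(3, 2, False, p=sample_weight.tolist())

# Rescale the selected probabilities back to roughly unit-scale fit weights.
fit_weight = sample_weight[idx_sample] * idx_sample.shape[0]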
(The remaining 5 changed files are not loaded in this view.)
