Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for TFIW dataset #42

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
---
noteId: "708531c0f33b11ec8e08f3de8bf47f07"
tags: []

---

# SimCLR
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al.
Including support for:
Expand Down
6 changes: 3 additions & 3 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ gpus: 1 # I recommend always assigning 1 GPU to 1 node
nr: 0 # machine nr. in node (0 -- nodes - 1)
dataparallel: 0 # Use DataParallel instead of DistributedDataParallel
workers: 8
dataset_dir: "./datasets"
dataset_dir: "/Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/all/train"

# train options
seed: 42 # sacred handles automatic seeding when passed in the config
batch_size: 128
image_size: 224
image_size: [108, 124]
start_epoch: 0
epochs: 100
dataset: "CIFAR10" # STL10
dataset: "TFIW" #"CIFAR10" # STL10
pretrain: True

# model options
Expand Down
18 changes: 17 additions & 1 deletion linear_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from simclr import SimCLR
from simclr.modules import LogisticRegression, get_resnet
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.tfiwDataset import TFIWDataset

from utils import yaml_config_hook

Expand Down Expand Up @@ -138,12 +139,26 @@ def test(args, loader, simclr_model, model, criterion, optimizer):
download=True,
transform=TransformsSimCLR(size=args.image_size).test_transform,
)

test_dataset = torchvision.datasets.CIFAR10(
args.dataset_dir,
train=False,
download=True,
transform=TransformsSimCLR(size=args.image_size).test_transform,
)

elif args.dataset == "TFIW":
train_dataset = TFIWDataset(
args.dataset_dir, #enter /Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/train
transform = TransformsSimCLR(size=args.image_size).test_transform,
)

test_dataset = TFIWDataset(
#args.dataset_dir,
"/Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/val",
transform = TransformsSimCLR(size=args.image_size).test_transform,
)

else:
raise NotImplementedError

Expand Down Expand Up @@ -174,7 +189,8 @@ def test(args, loader, simclr_model, model, criterion, optimizer):
simclr_model.eval()

## Logistic Regression
n_classes = 10 # CIFAR-10 / STL-10
#n_classes = 10 # CIFAR-10 / STL-10
n_classes = 571 #TFIW has 571 families in the training dataset
model = LogisticRegression(simclr_model.n_features, n_classes)
model = model.to(args.device)

Expand Down
24 changes: 20 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,23 @@
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.sync_batchnorm import convert_model

from simclr.modules.tfiwDataset import TFIWDataset
from net import LResNet50E_IR, LResNet

from model import load_optimizer, save_model
from utils import yaml_config_hook


def train(args, train_loader, model, criterion, optimizer, writer):
loss_epoch = 0
print(enumerate(train_loader))
for step, ((x_i, x_j), _) in enumerate(train_loader):
#for step, (x_i, x_j) in enumerate(train_loader):
#print(x_i)
#print(x_j)
optimizer.zero_grad()
x_i = x_i.cuda(non_blocking=True)
x_j = x_j.cuda(non_blocking=True)
#x_i = x_i.cuda(non_blocking=True)
#x_j = x_j.cuda(non_blocking=True)

# positive pair, with encoding
h_i, h_j, z_i, z_j = model(x_i, x_j)
Expand Down Expand Up @@ -76,6 +83,13 @@ def main(gpu, args):
download=True,
transform=TransformsSimCLR(size=args.image_size),
)

elif args.dataset == "TFIW":
train_dataset = TFIWDataset(
args.dataset_dir, #enter /Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/all/train
transform = TransformsSimCLR(size=args.image_size),
)

else:
raise NotImplementedError

Expand All @@ -96,8 +110,10 @@ def main(gpu, args):
)

# initialize ResNet
encoder = get_resnet(args.resnet, pretrained=False)
n_features = encoder.fc.in_features # get dimensions of fc layer
#encoder = get_resnet(args.resnet, pretrained=False)
encoder = LResNet50E_IR(is_gray=False)
n_features = 256 #encoder.fc.in_features # get dimensions of fc layer
print(encoder)

# initialize model
model = SimCLR(encoder, args.projection_dim, n_features)
Expand Down
10 changes: 9 additions & 1 deletion main_pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from simclr.modules import NT_Xent, get_resnet
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.sync_batchnorm import convert_model
from simclr.modules.tfiwDataset import TFIWDataset

from utils import yaml_config_hook

Expand Down Expand Up @@ -76,7 +77,7 @@ def configure_optimizers(self):

parser = argparse.ArgumentParser(description="SimCLR")

config = yaml_config_hook("./config/config.yaml")
config = yaml_config_hook("/Users/gaurav/Desktop/thesis-work/contrastive/SimCLR-Faces/config/config.yaml")
for k, v in config.items():
parser.add_argument(f"--{k}", default=v, type=type(v))

Expand All @@ -95,6 +96,13 @@ def configure_optimizers(self):
download=True,
transform=TransformsSimCLR(size=args.image_size),
)

elif args.dataset == "TFIW":
train_dataset = TFIWDataset(
args.dataset_dir, #enter /Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/train
transform = TransformsSimCLR(size=args.image_size),
)

else:
raise NotImplementedError

Expand Down
182 changes: 182 additions & 0 deletions net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn



# -------------------------------------- sphere network Begin --------------------------------------
class Block(nn.Module):
    """Residual unit for the sphere network: two 3x3 convolutions, each
    followed by a channel-wise PReLU, added to an identity skip connection.

    Spatial size and channel count are preserved (stride 1, padding 1).
    """

    def __init__(self, planes):
        super(Block, self).__init__()
        # Both convs are bias-free; PReLU carries one slope per channel.
        self.conv1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.prelu1 = nn.PReLU(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.prelu2 = nn.PReLU(planes)

    def forward(self, x):
        # Compute the residual branch step by step, then add the identity.
        branch = self.prelu1(self.conv1(x))
        branch = self.prelu2(self.conv2(branch))
        return x + branch


class sphere(nn.Module):
    """SphereFace-style backbone: four stride-2 stages of conv + residual
    Blocks, followed by a 512-d fully-connected embedding.

    The fc layer is sized 512 * 7 * 6, which implies a 112x96 input
    (four stride-2 stages: 112/16 = 7, 96/16 = 6).

    Args:
        type: network depth selector, 20 or 64 (per-stage Block counts).
        is_gray: if True, accept single-channel input instead of RGB.

    Raises:
        ValueError: if `type` is neither 20 nor 64.
    """

    def __init__(self, type=20, is_gray=False):
        super(sphere, self).__init__()
        block = Block
        # BUG FIX: original used `type is 20` / `type is 64`. Identity
        # comparison against int literals is implementation-dependent and a
        # SyntaxWarning on Python >= 3.8; equality is the correct check.
        if type == 20:
            layers = [1, 2, 4, 1]
        elif type == 64:
            layers = [3, 7, 16, 3]
        else:
            raise ValueError('sphere' + str(type) + " IS NOT SUPPORTED! (sphere20 or sphere64)")
        filter_list = [3, 64, 128, 256, 512]
        if is_gray:
            filter_list[0] = 1  # single input channel for grayscale

        self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
        self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
        self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
        self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)
        # 512 channels x 7 x 6 spatial -> assumes 112x96 input.
        self.fc = nn.Linear(512 * 7 * 6, 512)

        # Weight initialization, kept exactly as authored upstream:
        # xavier + zero bias when a bias exists, otherwise N(0, 0.01).
        # NOTE(review): this means the bias-free convs inside Block get the
        # normal init, not xavier -- confirm this matches the reference impl.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.constant_(m.bias, 0.0)
                else:
                    nn.init.normal_(m.weight, 0, 0.01)

    def _make_layer(self, block, inplanes, planes, blocks, stride):
        """One stage: a stride-`stride` conv changing the channel count,
        then `blocks` residual Blocks at constant width."""
        layers = []
        layers.append(nn.Conv2d(inplanes, planes, 3, stride, 1))
        layers.append(nn.PReLU(planes))
        for i in range(blocks):
            layers.append(block(planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = x.view(x.size(0), -1)  # flatten to (batch, 512*7*6)
        x = self.fc(x)

        return x

    def save(self, file_path):
        """Serialize the model's state_dict to `file_path`."""
        with open(file_path, 'wb') as f:
            torch.save(self.state_dict(), f)


# -------------------------------------- sphere network END --------------------------------------

# ---------------------------------- LResNet50E-IR network Begin ----------------------------------

class BlockIR(nn.Module):
    """Improved-residual (IR) block: BN -> conv3x3 -> BN -> PReLU ->
    conv3x3(stride) -> BN, summed with a shortcut.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        stride: stride of the second conv (controls downsampling).
        dim_match: True when input and output shapes already agree, so the
            shortcut is the identity; otherwise a 1x1 conv + BN projection
            brings the input to the output shape.
    """

    def __init__(self, inplanes, planes, stride, dim_match):
        super(BlockIR, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.prelu1 = nn.PReLU(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)

        # Projection shortcut only when shapes differ.
        self.downsample = None if dim_match else nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes),
        )

    def forward(self, x):
        # Main branch.
        out = self.bn1(x)
        out = self.bn2(self.conv1(out))
        out = self.bn3(self.conv2(self.prelu1(out)))
        # Shortcut branch: identity, or 1x1 projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)
        return out + shortcut


class LResNet(nn.Module):
    """ArcFace-style ResNet trunk built from `block` units.

    Four stride-2 stages (depths given by `layers`, widths by `filter_list`)
    over a stride-1 3x3 stem, ending in a BN -> Dropout -> Linear -> BN head
    producing a 512-d embedding. The head is sized filter_list[4] * 7 * 6,
    which implies a 112x96 input (see the comment in __init__).
    """

    def __init__(self, block, layers, filter_list, is_gray=False):
        # NOTE(review): `inplanes` is set but never read below -- looks like a
        # leftover from the torchvision ResNet template.
        self.inplanes = 64
        super(LResNet, self).__init__()
        # input is (mini-batch,3 or 1,112,96)
        # use (conv3x3, stride=1, padding=1) instead of (conv7x7, stride=2, padding=3)
        if is_gray:
            self.conv1 = nn.Conv2d(1, filter_list[0], kernel_size=3, stride=1, padding=1, bias=False) # gray
        else:
            self.conv1 = nn.Conv2d(3, filter_list[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filter_list[0])
        self.prelu1 = nn.PReLU(filter_list[0])
        # Each stage halves the spatial size: 112x96 -> 7x6 after four stages.
        self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
        self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
        self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
        self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)
        # Embedding head; input size assumes 7x6 final feature map.
        self.fc = nn.Sequential(
            nn.BatchNorm1d(filter_list[4] * 7 * 6),
            nn.Dropout(p=0.4),
            nn.Linear(filter_list[4] * 7 * 6, 512),
            nn.BatchNorm1d(512), # fix gamma ???
        )

        # Weight initialization: xavier for conv/linear weights (zero bias),
        # affine identity (gamma=1, beta=0) for all BatchNorm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight,1)
                nn.init.constant_(m.bias,0)


    def _make_layer(self, block, inplanes, planes, blocks, stride):
        """Build one stage: first block projects/downsamples (dim_match=False,
        given `stride`), remaining blocks keep shape (stride 1)."""
        layers = []
        layers.append(block(inplanes, planes, stride, False))
        for i in range(1, blocks):
            layers.append(block(planes, planes, stride=1, dim_match=True))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the 512-d embedding for a batch of images."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Flatten (batch, C, 7, 6) -> (batch, C*7*6) for the fc head.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

    def save(self, file_path):
        """Serialize the model's state_dict to `file_path`."""
        with open(file_path, 'wb') as f:
            torch.save(self.state_dict(), f)


def LResNet50E_IR(is_gray=False):
    """Build the LResNet50E-IR backbone (ArcFace) with a 512-d embedding.

    Stage depths [3, 4, 14, 3] with channel widths [64, 64, 128, 256, 512];
    the head sizing in LResNet implies a 112x96 input.

    Args:
        is_gray: if True, the network accepts single-channel input.
    """
    print("Using LResNet50E from ArcFace")
    return LResNet(
        BlockIR,
        layers=[3, 4, 14, 3],
        filter_list=[64, 64, 128, 256, 512],
        is_gray=is_gray,
    )
# ---------------------------------- LResNet50E-IR network End ----------------------------------
1 change: 1 addition & 0 deletions simclr/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .lars import LARS
from .resnet import get_resnet
from .gather import GatherLayer
#from simclr.modules.tfiwDataset import TFIWDataset
50 changes: 50 additions & 0 deletions simclr/modules/tfiwDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from types import NoneType
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
from torchvision import transforms
import pandas as pd


class TFIWDataset(Dataset):
    """Families In The Wild (FIW/TFIW) face-crop dataset.

    Scans `img_dir` for .jpg files and derives an integer family label from
    each file name (characters 1-4, e.g. "F0123..." -> 123).
    NOTE(review): assumes the "F<4-digit family id>..." naming convention of
    the FIW train-faces split -- confirm against the actual dataset layout.
    """

    def __init__(self, img_dir=None, transform=None):
        """
        Args:
            img_dir: directory containing the face crops (flat layout of
                .jpg files). Defaults to the current working directory.
                (BUG FIX: the original evaluated os.getcwd() at import time
                as a default argument, freezing a stale cwd.)
            transform: optional callable applied to each PIL image.
        """
        self.img_dir = os.getcwd() if img_dir is None else img_dir
        self.transform = transform

        # Keep only .jpg files (drops .DS_Store and other stray files);
        # sort for a deterministic index -> sample mapping across machines
        # (os.listdir order is arbitrary).
        self.img_names = sorted(
            name for name in os.listdir(self.img_dir) if name.endswith("jpg")
        )
        # Family id encoded in the file name: "F0123..." -> 123.
        self.labels = [int(name[1:5]) for name in self.img_names]
        # BUG FIX: removed the debug dump of the file list to a hard-coded
        # user-specific CSV path ("/Users/gaurav/..."), which crashed on any
        # other machine and wrote files as an import-side effect.

    def __getitem__(self, idx):
        """Return the (image, label) pair for sample `idx`.

        Raises IndexError for an out-of-range `idx` (the standard map-style
        Dataset contract) instead of printing and implicitly returning None
        as the original did. The dead NoneType check was removed: Image.open
        raises on failure, it never returns None.
        """
        image = Image.open(os.path.join(self.img_dir, self.img_names[idx]))
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

    def __len__(self):
        return len(self.img_names)

#tfiw = TFIWDataset(img_dir='/Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/all')
#example = tfiw[7]
#print(example)