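"""Train a feature extractor on an MVTec-style dataset.

Two training modes are supported (see --use_diffusion_model):
  * Autoencoder mode: the extractor is wrapped in an Autoencoder and trained so
    that the decoded output matches the input image.
  * Diffusion mode: a pretrained diffusion (UNet) checkpoint regenerates each
    input via DDIM, and the extractor is trained to match that reconstruction.
"""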
import argparse
import json
import os
from dataclasses import dataclass

import torch
from diffusers import UNet2DModel
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from feature_extraction.autoencoder import Autoencoder, AETrainer, DBTrainer
from feature_extraction.extractor import ResNetFE
from loader.loader import MVTecDataset
from pipe.inference import generate_samples
from schedulers.scheduling_ddim import DDIMScheduler


@dataclass
class TrainArgs:
    checkpoint_dir: str
    item: str
    flip: bool
    resolution: int
    epochs: int
    dataset_path: str
    train_steps: int
    beta_schedule: str
    device: str
    reconstruction_weight: float
    eta: float
    batch_size: int
    noise_kind: str
    crop: bool
    checkpoint_name: str
    save_to: str
    start_at_timestep: int
    steps_to_regenerate: int
    use_diffusion_model: bool
    log_dir: str
    run_name: str


def parse_args() -> TrainArgs:
    parser = argparse.ArgumentParser(description='Configuration for training the feature extractor')
    parser.add_argument('--checkpoint_dir', type=str, default="checkpoints",
                        help='Directory path to store the checkpoints in.')
    parser.add_argument('--log_dir', type=str, default="logs",
                        help='Directory to store the logs in. A sub-directory named after "run_name" is created.')
    parser.add_argument('--run_name', type=str, default="extractor",
                        help='Name of the run. Used by logging to create a new directory in the logging dir.')
    parser.add_argument('--item', type=str, required=True,
                        help='Name of the item (category) within the dataset to train on.')
    parser.add_argument('--resolution', type=int, default=256,
                        help='Resolution of the images to generate (the dataset is resized to this resolution during training).')
    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of epochs to train for.')
    parser.add_argument('--start_at_timestep', type=int, default=250,
                        help='Timestep from which the diffusion process is started.')
    parser.add_argument('--steps_to_regenerate', type=int, default=25,
                        help='Number of timesteps to generate during the DDIM process.')
    parser.add_argument('--flip', action='store_true',
                        help='Whether to augment the training data with a horizontal flip.')
    parser.add_argument('--train_steps', type=int, default=1000,
                        help='Number of steps for the full diffusion process.')
    parser.add_argument('--beta_schedule', type=str, default="linear",
                        help='Type of schedule for the beta/variance values.')
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='Directory path to the (MVTec) dataset.')
    parser.add_argument('--device', type=str, default="cuda",
                        help='Device to train on.')
    parser.add_argument('--checkpoint_name', type=str, default=None,
                        help='Checkpoint to load the diffusion model from (only relevant if use_diffusion_model is set).')
    parser.add_argument('--save_to', type=str, required=True,
                        help='Full path to save the trained extractor to (including the .pt ending).')
    parser.add_argument('--eta', type=float, default=0,
                        help='Stochasticity parameter of DDIM; eta=1 corresponds to DDPM, eta=0 means no randomness. '
                             'Only used during inference, not training.')
    parser.add_argument('--reconstruction_weight', type=float, default=.1,
                        help='Influence of the original sample during diffusion.')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size during training.')
    parser.add_argument('--noise_kind', type=str, default="gaussian",
                        choices=["simplex", "gaussian"],
                        help='Kind of noise to use for the noising steps.')
    parser.add_argument('--crop', action='store_true',
                        help='If set, the image is cropped to the resolution instead of resized.')
    parser.add_argument('--use_diffusion_model', action='store_true',
                        help='If not set, the feature extractor is trained autoencoder-style, i.e. the output image '
                             'should match the input image. If set and a diffusion model is given, the extractor is '
                             'trained to reduce the distance between an input image and the diffusion-model output '
                             'of that image.')
    return TrainArgs(**vars(parser.parse_args()))
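

# Example invocation (item, paths and file names below are illustrative):
#   python train_extractor.py --item hazelnut --dataset_path ./mvtec --save_to extractor.pt \
#       --use_diffusion_model --checkpoint_dir checkpoints --checkpoint_name unet.pt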

def main(args: TrainArgs, writer: SummaryWriter):
    print("**** training feature extractor ****")

    def transform_imgs(imgs):
        augmentations = transforms.Compose([
            transforms.RandomCrop(args.resolution) if args.crop
            else transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            # transforms.RandomHorizontalFlip(),
            # transforms.ColorJitter(.05, .05, .05),
            transforms.ToTensor(),
            # Standard ImageNet normalization statistics.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return [augmentations(image.convert("RGB")) for image in imgs]
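
    # With --use_diffusion_model, a pretrained UNet denoiser is loaded and used to
    # build reconstruction targets: each image is partially noised and regenerated
    # via DDIM, and the extractor is trained to match these reconstructions.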
    if args.use_diffusion_model:
        with open(f"{args.checkpoint_dir}/model_config.json", "r") as config_file:
            model_config = json.load(config_file)
        with open(f"{args.checkpoint_dir}/train_arg_config.json", "r") as train_arg_file:
            train_arg_config: dict = json.load(train_arg_file)

        model = UNet2DModel(**model_config)
        model.load_state_dict(torch.load(f"{args.checkpoint_dir}/{args.checkpoint_name}",
                                         map_location=args.device))
        model.eval()
        model.to(args.device)

        # Fall back to gaussian noise if the training config predates the noise_kind option.
        noise_kind = train_arg_config.get("noise_kind", "gaussian")
        noise_scheduler_inference = DDIMScheduler(args.train_steps, args.start_at_timestep,
                                                  beta_schedule=args.beta_schedule, timestep_spacing="leading",
                                                  reconstruction_weight=args.reconstruction_weight,
                                                  noise_type=noise_kind)

        def denoise_imgs(batch: torch.Tensor) -> torch.Tensor:
            _, imgs, _, _ = generate_samples(model, noise_scheduler_inference, None, batch, args.eta,
                                             args.steps_to_regenerate, args.start_at_timestep, args.crop,
                                             noise_kind)
            return imgs

    # Per the MVTec layout, only defect-free ("good") samples are used for training.
    data_train = MVTecDataset(args.dataset_path, True, args.item, ["good"], transform_imgs)
    train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True)
    data_test = MVTecDataset(args.dataset_path, False, args.item, ["good"], transform_imgs)
    test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=True)

    extractor = ResNetFE()
    ae = Autoencoder(extractor)
    ae.init_decoder((3, args.resolution, args.resolution))

    if args.use_diffusion_model:
        # Train against diffusion-model reconstructions of the inputs.
        trainer = DBTrainer(ae, denoise_imgs, train_loader, test_loader, writer=writer)
    else:
        # Plain autoencoder training: reconstruct the input image itself.
        trainer = AETrainer(ae, train_loader, test_loader, writer=writer)
    trainer.train(args.epochs)

    # Persist only the feature-extractor (encoder) weights, not the decoder.
    torch.save(extractor.state_dict(), args.save_to)
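

# The saved weights can later be restored with, e.g.:
#   extractor = ResNetFE(); extractor.load_state_dict(torch.load(path))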


if __name__ == '__main__':
    args: TrainArgs = parse_args()
    # Log into a per-run sub-directory, as promised by the --log_dir help text
    # (SummaryWriter's second positional argument is a comment suffix, not a run name).
    writer = SummaryWriter(os.path.join(args.log_dir, args.run_name))
    main(args, writer)
    writer.flush()
    writer.close()