convert.py
import torch
from torchvision import transforms
from torchvision.utils import save_image
import argparse
import os
from PIL import Image
from cyclegan import Generator


def main():
    parser = argparse.ArgumentParser(description='Converts an old damaged image into a new colored image. Outputs file(s) named <input_file>.<step>.png')
    parser.add_argument('-i', '--intermediate', action='store_true', default=False, help='Output (save) the intermediate states')
    parser.add_argument('-m', '--multi-domain', action='store_true', default=False, help='Use the multiple domain/model approach (damaged<>fixed<>color)')
    parser.add_argument('-r', '--reverse', action='store_true', default=False, help='Reverses the process (corrected image --> old image)')
    parser.add_argument('input_file', nargs='+', help='input file(s)')
    args = parser.parse_args()

    r = args.reverse
    device = "cuda" if torch.cuda.is_available() else "cpu"
    models = torch.load("ganban.pth", map_location=device)

    # Build the chain of generators to apply, in order.
    flow = []
    if args.multi_domain:
        # Two-stage pipeline: damaged <-> fixed, then gray <-> color.
        flow.append(Generator(3).to(device))
        flow[-1].load_state_dict(models["fixed2broken" if r else "broken2fixed"])
        flow[-1].eval()
        flow.append(Generator(3).to(device))
        flow[-1].load_state_dict(models["color2gray" if r else "gray2color"])
        flow[-1].eval()
    else:
        # Single-stage pipeline: one model handles both repair and colorization.
        flow.append(Generator(3).to(device))
        flow[-1].load_state_dict(models["fixedcolor2broken" if r else "broken2fixedcolor"])
        flow[-1].eval()
    if r:
        flow.reverse()

    tf = transforms.Compose([
        #transforms.Resize((256, 256), Image.Resampling.BICUBIC),
        transforms.Resize(256, Image.Resampling.BICUBIC),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Could be made faster with a DataLoader.
    for f_image in args.input_file:
        filename = os.path.splitext(f_image)[0]
        input_image = Image.open(f_image).convert("RGB")
        img_tensor = tf(input_image).unsqueeze(0).to(device)
        if args.intermediate:
            save_image(img_tensor, f"{filename}.0.png", normalize=True)
        with torch.no_grad():
            for i, gen in enumerate(flow):
                img_tensor = gen(img_tensor)
                if args.intermediate and i != (len(flow) - 1):
                    save_image(img_tensor, f"{filename}.{i + 1}.png", normalize=True)
        save_image(img_tensor, f"{filename}.out.png", normalize=True)


if __name__ == "__main__":
    main()
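
Example usage, as a minimal sketch: the checkpoint name ganban.pth comes from the script itself and must sit in the working directory, while old_photo.jpg is only an illustrative input filename.

    python convert.py old_photo.jpg            # writes old_photo.out.png
    python convert.py -m -i old_photo.jpg      # multi-domain; also writes old_photo.0.png and old_photo.1.png

The -r flag runs the same generator chain in the opposite direction, turning a restored color image back into an "old" one.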