Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve spade test #9

Open
wants to merge 2 commits into
base: spadecutoff
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions scripts/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import torch.utils.data as data
import os.path
from pathlib import Path


def default_loader(path):
Expand Down Expand Up @@ -92,22 +93,13 @@ def __len__(self):
import os
import os.path

# File extensions (leading dot included, matched case-sensitively) that are
# treated as images. Kept as a set so membership tests are O(1).
IMG_EXTENSIONS = {
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
}


def is_image_file(filename):
    """Tell whether *filename* carries one of the known image extensions.

    Args:
        filename: File name or path, either a ``str`` or a ``pathlib.Path``.
            Only the suffix of the final path component is inspected; the
            file itself is never opened.

    Returns:
        bool: True when the suffix is a member of ``IMG_EXTENSIONS``.
    """
    # Going through Path (instead of str.endswith) lets callers pass Path
    # objects directly, e.g. the entries yielded by Path.glob().
    return Path(filename).suffix in IMG_EXTENSIONS


def make_dataset(dir):
Expand Down
77 changes: 49 additions & 28 deletions scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,51 @@
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
from __future__ import print_function
from utils import get_config, pytorch03_to_pytorch04, sorted_nicely
from utils import get_config, sorted_nicely
from trainer import MUNIT_Trainer
import argparse
from torch.autograd import Variable
import torchvision.utils as vutils
import sys
import torch
import os
from torchvision import transforms
from PIL import Image
import tqdm as tq
import glob
import numpy as np
from pathlib import Path
from data import is_image_file
from datetime import datetime

# Parse arguments
# Each option is registered exactly once: argparse raises an error when the
# same option string is added twice to a parser.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="network configuration file")
parser.add_argument("--input", type=str, help="directory of input images")
parser.add_argument(
    "--mask_dir", type=str, help="directory of masks corresponding to input images"
)
parser.add_argument("--output_folder", type=str, help="output image directory")
parser.add_argument("--checkpoint", type=str, help="checkpoint of generator")

parser.add_argument("--seed", type=int, default=10, help="random seed")

parser.add_argument(
    "--synchronized",
    action="store_true",
    help="whether use synchronized style code or not",
)
parser.add_argument(
    "--save_input",
    action="store_true",
    # Help text used to be a copy of the --synchronized one.
    help="whether to save the input image or not",
)
parser.add_argument(
    "--output_path",
    type=str,
    default=".",
    help="path for logs, checkpoints, and VGG model weight",
)
parser.add_argument(
    "--save_mask", action="store_true", help="whether to save mask or not"
)
opts = parser.parse_args()

Expand All @@ -58,12 +69,10 @@

trainer = MUNIT_Trainer(config)

# Load the model (here we currently only load the latest model architecture: one single style)
try:
state_dict = torch.load(opts.checkpoint)
trainer.gen.load_state_dict(state_dict["2"])
except:
sys.exit("Cannot load the checkpoints")
# Load the model
# (here we currently only load the latest model architecture: one single style)
state_dict = torch.load(opts.checkpoint)
trainer.gen.load_state_dict(state_dict["2"])

# Send the trainer to cuda
trainer.cuda()
Expand All @@ -73,23 +82,31 @@
new_size = config["new_size"]

# Define the list of non-flooded images
list_non_flooded = glob.glob(opts.input + "*")
list_non_flooded = [
str(im) for im in Path(opts.input).resolve().glob("*") if is_image_file(im)
]

list_non_flooded = sorted_nicely(list_non_flooded)
# Define list of masks:

list_masks = glob.glob(opts.mask_dir + "*")
list_masks = [
str(im) for im in Path(opts.mask_dir).resolve().glob("*") if is_image_file(im)
]

list_masks = sorted_nicely(list_masks)

if len(list_non_flooded) != len(list_masks):
sys.exit("Image list and mask list differ in length")
assert len(list_non_flooded) == len(
list_masks
), "Image list and mask list differ in length"


# Assert there are some elements inside
if len(list_non_flooded) == 0:
sys.exit("Image list is empty. Please ensure opts.input ends with a /")
assert list_non_flooded, "Image list is empty"

output_folder = Path(opts.output_folder).resolve()
output_folder.mkdir(parents=True, exist_ok=True)

run_id = str(datetime.now())[:19].replace(" ", "_")

# Inference
with torch.no_grad():
Expand All @@ -103,11 +120,13 @@
)

mask_transform = transforms.Compose(
[transforms.Resize((new_size, new_size)), transforms.ToTensor(),]
[transforms.Resize((new_size, new_size)), transforms.ToTensor()]
)

for j in tq.tqdm(range(len(list_non_flooded))):

file_id = f"{run_id}-{j}"

# Define image path
path_xa = list_non_flooded[j]

Expand All @@ -123,17 +142,19 @@
mask = mask[0].unsqueeze(0).unsqueeze(0)

# Load and transform the non_flooded image
x_a = Variable(transform(Image.open(path_xa).convert("RGB")).unsqueeze(0).cuda())
x_a = Variable(
transform(Image.open(path_xa).convert("RGB")).unsqueeze(0).cuda()
)
if opts.save_input:
inputs = (x_a + 1) / 2.0
path = os.path.join(opts.output_folder, "{:03d}input.jpg".format(j))
vutils.save_image(inputs.data, path, padding=0, normalize=True)
path = output_folder / "{}-input.jpg".format(file_id)
vutils.save_image(inputs.data, str(path), padding=0, normalize=True)

if opts.save_mask:
path = os.path.join(opts.output_folder, "{:03d}mask.jpg".format(j))
path = output_folder / "{}-mask.jpg".format(file_id)
# overlay mask onto image
save_m_a = x_a - (x_a * mask.repeat(1, 3, 1, 1)) + mask.repeat(1, 3, 1, 1)
vutils.save_image(save_m_a, path, padding=0, normalize=True)
vutils.save_image(save_m_a, str(path), padding=0, normalize=True)

# Extract content and style
x_a_augment = torch.cat([x_a, mask], dim=1)
Expand All @@ -146,7 +167,7 @@
outputs = (x_ab + 1) / 2.0

# Define output path
path = os.path.join(opts.output_folder, "{:03d}output.jpg".format(j))
path = output_folder / "{}-output.jpg".format(file_id)

# Save image
vutils.save_image(outputs.data, path, padding=0, normalize=True)
vutils.save_image(outputs.data, str(path), padding=0, normalize=True)
80 changes: 59 additions & 21 deletions scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,19 @@ def get_data_loader_list(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform_list = (
[transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
[transforms.RandomCrop((height, width))] + transform_list
if crop
else transform_list
)
transform_list = (
[transforms.Resize((new_size, new_size))] + transform_list
if new_size is not None
else transform_list
)
transform_list = (
[transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
[transforms.RandomHorizontalFlip()] + transform_list
if train
else transform_list
)
transform = transforms.Compose(transform_list)
dataset = ImageFilelist(root, file_list, transform=transform)
Expand Down Expand Up @@ -314,7 +318,9 @@ def transform(self, image, mask):
image = resize(image)
to_tensor = transforms.ToTensor()
# Random crop
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(self.height, self.width))
i, j, h, w = transforms.RandomCrop.get_params(
image, output_size=(self.height, self.width)
)
image = F.crop(image, i, j, h, w)

if type(mask) is not torch.Tensor:
Expand Down Expand Up @@ -436,7 +442,9 @@ def __len__(self):
return len(self.image_paths)


def get_fid_data_loader(file_list_a, file_list_b, batch_size, train, new_size=256, num_workers=4):
def get_fid_data_loader(
file_list_a, file_list_b, batch_size, train, new_size=256, num_workers=4
):
"""
Masks and images lists-based data loader with transformations
(horizontal flip, resizing, random crop, normalization are handled)
Expand Down Expand Up @@ -518,7 +526,7 @@ def transform(self, image_a, image_b, mask, semantic_a, semantic_b):
# print('dim image after resize',image.size)

# Resize mask
#mask = mask.resize((image_b.width, image_b.height), Image.NEAREST)
# mask = mask.resize((image_b.width, image_b.height), Image.NEAREST)
mask = resize(mask)
semantic_a = semantic_a.resize((image_b.width, image_b.height), Image.NEAREST)
semantic_b = semantic_b.resize((image_b.width, image_b.height), Image.NEAREST)
Expand Down Expand Up @@ -627,7 +635,14 @@ def get_synthetic_data_loader(
loader -- data loader with transformed dataset
"""
dataset = MyDatasetSynthetic(
file_list_a, file_list_b, mask_list, sem_list_a, sem_list_b, new_size, height, width,
file_list_a,
file_list_b,
mask_list,
sem_list_a,
sem_list_b,
new_size,
height,
width,
)
loader = DataLoader(
dataset=dataset,
Expand Down Expand Up @@ -682,7 +697,14 @@ def get_data_loader_mask_and_im(


def get_data_loader_folder(
input_folder, batch_size, train, new_size=None, height=256, width=256, num_workers=4, crop=True,
input_folder,
batch_size,
train,
new_size=None,
height=256,
width=256,
num_workers=4,
crop=True,
):
"""
Folder-based data loader with transformations
Expand Down Expand Up @@ -711,15 +733,19 @@ def get_data_loader_folder(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform_list = (
[transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
[transforms.RandomCrop((height, width))] + transform_list
if crop
else transform_list
)
transform_list = (
[transforms.Resize((new_size, new_size))] + transform_list
if new_size is not None
else transform_list
)
transform_list = (
[transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
[transforms.RandomHorizontalFlip()] + transform_list
if train
else transform_list
)
transform = transforms.Compose(transform_list)
dataset = ImageFolder(input_folder, transform=transform)
Expand Down Expand Up @@ -763,14 +789,18 @@ def __write_images(image_outputs, display_image_num, file_name):
image_outputs = [
images.expand(-1, 3, -1, -1) for images in image_outputs
] # expand gray-scale images to 3 channels
image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
image_tensor = torch.cat(
[images[:display_image_num] for images in image_outputs], 0
)
image_grid = vutils.make_grid(
image_tensor.data, nrow=display_image_num, padding=0, normalize=True
)
vutils.save_image(image_grid, file_name, nrow=1)


def write_2images(image_outputs, display_image_num, image_directory, postfix, comet_exp=None):
def write_2images(
image_outputs, display_image_num, image_directory, postfix, comet_exp=None
):
"""Write images from both worlds a and b of the cycle A-B-A as jpg
Arguments:
image_outputs {Tensor list} -- list of images, the first half being outputs in B,
Expand Down Expand Up @@ -859,7 +889,9 @@ def get_slerp_interp(nb_latents, nb_interp, z_dim):
low = np.random.randn(z_dim)
high = np.random.randn(z_dim) # low + np.random.randn(512) * 0.7
interp_vals = np.linspace(0, 1, num=nb_interp)
latent_interp = np.array([slerp(v, low, high) for v in interp_vals], dtype=np.float32)
latent_interp = np.array(
[slerp(v, low, high) for v in interp_vals], dtype=np.float32
)
latent_interps = np.vstack((latent_interps, latent_interp))

return latent_interps[:, :, np.newaxis, np.newaxis]
Expand Down Expand Up @@ -919,7 +951,10 @@ def __init__(self, num_classes=1000):
# Load the pretrained weights, remove avg pool
# layer and get the output stride of 8
resnet34_8s = resnet34(
fully_conv=True, pretrained=True, output_stride=8, remove_avg_pool_layer=True,
fully_conv=True,
pretrained=True,
output_stride=8,
remove_avg_pool_layer=True,
)

# Randomly initialize the 1x1 Conv scoring layer
Expand Down Expand Up @@ -1070,7 +1105,9 @@ def get_scheduler(optimizer, hyperparameters, iterations=-1):
def weights_init(init_type="gaussian"):
def init_fun(m):
classname = m.__class__.__name__
if (classname.find("Conv") == 0 or classname.find("Linear") == 0) and hasattr(m, "weight"):
if (classname.find("Conv") == 0 or classname.find("Linear") == 0) and hasattr(
m, "weight"
):
# print m.__class__.__name__
if init_type == "gaussian":
init.normal_(m.weight.data, 0.0, 0.02)
Expand Down Expand Up @@ -1279,7 +1316,9 @@ def __init__(
self.stride = stride
self.downsample = downsample
if stride != 1 or inplanes != planes:
self.downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))
self.downsample = nn.Sequential(
conv1x1(inplanes, planes, stride), norm_layer(planes)
)

def forward(self, x):
identity = x
Expand Down Expand Up @@ -1408,12 +1447,11 @@ def p(d, prefix="", vals=[]):
return dict(values_list)


def sorted_nicely(l):
    """Sort strings in natural ("human") order, e.g. ``img2`` before ``img10``.

    Args:
        l: Iterable of strings (file names, paths given as ``str``, ...).

    Returns:
        list: A new list with the elements of *l* ordered so that runs of
        ASCII digits compare by numeric value and everything else compares
        as plain text.
    """

    def alphanum_key(key):
        # re.split with a capturing group yields text, digits, text, ...:
        # odd positions are exactly the [0-9]+ runs. Converting by position
        # (rather than by str.isdigit) keeps every key laid out as
        # str/int/str/..., and avoids int() being called on characters such
        # as "²" that isdigit() accepts but int() rejects.
        pieces = re.split(r"([0-9]+)", key)
        return [
            int(piece) if position % 2 else piece
            for position, piece in enumerate(pieces)
        ]

    return sorted(l, key=alphanum_key)



Expand Down