Skip to content

Commit

Permalink
PyTorch now uses GPU!
Browse files Browse the repository at this point in the history
- Fixed bug described here: junyanz#86
  • Loading branch information
bastian43 committed Aug 8, 2023
1 parent c132925 commit eef85d0
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 20 deletions.
10 changes: 4 additions & 6 deletions data/colorize_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ def __init__(self, Xd: int = 256, maskcent: bool = False):
def prep_net(self, gpu_id: int | None = None, path: str = '', dist: bool = False) -> None:
import torch
import models.pytorch.model as model
print(f"ColorizeImageTorch: path = {path}")
print(f"ColorizeImageTorch: dist mode = {dist}")
self.net = model.SIGGRAPHGenerator(dist=dist)
self.net = model.SIGGRAPHGenerator(use_gpu=gpu_id is not None, dist=dist)
state_dict = torch.load(path)
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
Expand All @@ -229,8 +227,8 @@ def prep_net(self, gpu_id: int | None = None, path: str = '', dist: bool = False
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, self.net, key.split('.'))
self.net.load_state_dict(state_dict)
if gpu_id != None:
self.net.cuda()
if gpu_id is not None:
self.net.cuda(gpu_id)
self.net.eval()
self.net_set = True

Expand Down Expand Up @@ -305,7 +303,7 @@ def net_forward(self, input_ab: np.ndarray, input_mask: np.ndarray) -> int | np.
# embed()
if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
return -1

# set distribution
(function_return, self.dist_ab) = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)
function_return = function_return[0, :, :, :].cpu().data.numpy()
Expand Down
2 changes: 1 addition & 1 deletion docker/data/colorize_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def prep_net(self, gpu_id=None, path='', dist=False):
import models.pytorch.model as model
print('path = %s' % path)
print('Model set! dist mode? ', dist)
self.net = model.SIGGRAPHGenerator(dist=dist)
self.net = model.SIGGRAPHGenerator(use_gpu=gpu_id is not None, dist=dist)
state_dict = torch.load(path)
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
Expand Down
13 changes: 8 additions & 5 deletions ideepcolor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,27 @@ def parse_args() -> argparse.Namespace:
print(f"[{arg}] = {getattr(args, arg)}")
print("\n")

if args.cpu_mode:
args.gpu = -1

args.win_size = int(args.win_size / 4.0) * 4 # make sure the width of the image can be divided by 4

if args.backend == 'caffe':
if args.cpu_mode:
args.gpu = -1

# initialize the colorization model
colorModel = CI.ColorizeImageCaffe(Xd=args.load_size)
colorModel.prep_net(args.gpu, args.color_prototxt, args.color_caffemodel)

distModel = CI.ColorizeImageCaffeDist(Xd=args.load_size)
distModel.prep_net(args.gpu, args.dist_prototxt, args.dist_caffemodel)
elif args.backend == 'pytorch':
if args.cpu_mode:
args.gpu = None

colorModel = CI.ColorizeImageTorch(Xd=args.load_size,maskcent=args.pytorch_maskcent)
colorModel.prep_net(path=args.color_model)
colorModel.prep_net(args.gpu, path=args.color_model)

distModel = CI.ColorizeImageTorchDist(Xd=args.load_size,maskcent=args.pytorch_maskcent)
distModel.prep_net(path=args.color_model, dist=True)
distModel.prep_net(args.gpu, path=args.color_model, dist=True)
else:
print(f"Backend type [{args.backend}] unknown")
sys.exit()
Expand Down
18 changes: 10 additions & 8 deletions models/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@


class SIGGRAPHGenerator(nn.Module):
def __init__(self, dist=False):
def __init__(self, use_gpu=False, dist=False):
super(SIGGRAPHGenerator, self).__init__()
self.use_gpu = use_gpu
self.dist = dist
use_bias = True
norm_layer = nn.BatchNorm2d
Expand Down Expand Up @@ -136,14 +137,15 @@ def forward(self, input_A, input_B, mask_B, maskcent=0):
# input_B \in [-110, +110]
# mask_B \in [0, +1.0]

input_A = torch.Tensor(input_A)[None, :, :, :]
input_B = torch.Tensor(input_B)[None, :, :, :]
mask_B = torch.Tensor(mask_B)[None, :, :, :]
if self.use_gpu:
input_A = torch.Tensor(input_A).cuda()[None, :, :, :]
input_B = torch.Tensor(input_B).cuda()[None, :, :, :]
mask_B = torch.Tensor(mask_B).cuda()[None, :, :, :]
else:
input_A = torch.Tensor(input_A)[None, :, :, :]
input_B = torch.Tensor(input_B)[None, :, :, :]
mask_B = torch.Tensor(mask_B)[None, :, :, :]
mask_B = mask_B - maskcent

# input_A = torch.Tensor(input_A).cuda()[None, :, :, :]
# input_B = torch.Tensor(input_B).cuda()[None, :, :, :]
# mask_B = torch.Tensor(mask_B).cuda()[None, :, :, :]

conv1_2 = self.model1(torch.cat((input_A / 100., input_B / 110., mask_B), dim=1))
conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
Expand Down

0 comments on commit eef85d0

Please sign in to comment.