-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathVAEEncodeForInpaint.py
61 lines (48 loc) · 2.37 KB
/
VAEEncodeForInpaint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Provides demonstration of PR https://github.com/comfyanonymous/ComfyUI/pull/1574/commits/297d1cff422198806cda40e4b6d71a6e6aa05453
import torch
class WAS_VAEEncodeForInpaint:
    """Encode an image to a latent for inpainting with an adjustable mask.

    The mask is resized to the image, optionally grown (positive
    ``mask_offset``) or shrunk (negative ``mask_offset``), and then used both
    to neutralize the masked pixels before encoding and as the returned
    ``noise_mask``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of the node's required inputs."""
        return {
            "required": {
                "pixels": ("IMAGE",),
                "vae": ("VAE",),
                "mask": ("MASK",),
                "mask_offset": (
                    "INT",
                    {"default": 6, "min": -128, "max": 128, "step": 1},
                ),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"
    CATEGORY = "latent/inpaint"

    def encode(self, vae, pixels, mask, mask_offset=6):
        """Encode ``pixels`` with ``vae`` after blanking the masked region.

        Args:
            vae: Object exposing ``encode(pixels)``; its result is returned
                unchanged under the ``"samples"`` key.
            pixels: Image tensor indexed as ``[batch, dim1, dim2, channel]``
                (presumably batch, height, width, channel -- confirm against
                the caller). The input is cloned, never modified.
            mask: Mask tensor whose last two axes are spatial; it is
                bilinearly resized to the spatial size of ``pixels``.
            mask_offset: Pixels to grow (> 0) or shrink (< 0) the mask by;
                0 leaves the mask untouched.

        Returns:
            A 1-tuple holding a dict with ``"samples"`` (the encoded latent)
            and ``"noise_mask"`` (the rounded, adjusted mask).
        """
        size_1, size_2 = pixels.shape[1], pixels.shape[2]
        # Largest multiples of 8 that fit in each spatial dimension
        # (presumably the VAE's downscale factor -- confirm).
        x = (size_1 // 8) * 8
        y = (size_2 // 8) * 8

        mask = torch.nn.functional.interpolate(
            mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
            size=(size_1, size_2),
            mode="bilinear",
        )

        pixels = pixels.clone()
        if size_1 != x or size_2 != y:
            # Center-crop image and mask down to the multiple-of-8 size.
            x_offset = (size_1 % 8) // 2
            y_offset = (size_2 % 8) // 2
            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
            mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset]

        mask_erosion = self.modify_mask(mask, mask_offset)

        # 1 where the image is kept, 0 where it will be inpainted.
        keep = (1.0 - mask_erosion.round()).squeeze(1)

        # Pull masked pixels of the first three channels to 0.5 (mid-gray).
        # ``rgb`` is a view, so the in-place ops update ``pixels`` directly.
        rgb = pixels[:, :, :, :3]
        rgb -= 0.5
        rgb *= keep.unsqueeze(-1)
        rgb += 0.5

        t = vae.encode(pixels)
        return ({"samples": t, "noise_mask": mask_erosion[:, :, :x, :y].round()},)

    def modify_mask(self, mask, modify_by):
        """Grow or shrink a mask by ``modify_by`` pixels.

        The mask is rounded to 0/1 and convolved with a square all-ones
        kernel of side ``2 * abs(modify_by) + 1``. A positive ``modify_by``
        dilates the mask; a negative one erodes it (by dilating the inverted
        mask and inverting the result).

        Args:
            mask: Floating-point tensor shaped ``(batch, 1, dim1, dim2)``, as
                required by the single-channel convolution.
            modify_by: Signed radius in pixels; 0 returns ``mask`` unchanged
                (the very same object, not rounded).

        Returns:
            The adjusted mask, clamped to ``[0, 1]``, with the same shape,
            device and dtype as ``mask``.
        """
        if modify_by == 0:
            return mask

        radius = abs(modify_by)
        side = 2 * radius + 1
        # Build the kernel on the mask's device/dtype; a default CPU/float32
        # kernel makes conv2d fail for CUDA or non-float32 masks.
        kernel_tensor = torch.ones(
            (1, 1, side, side), device=mask.device, dtype=mask.dtype
        )
        binary = mask.round()

        if modify_by > 0:
            grown = torch.nn.functional.conv2d(binary, kernel_tensor, padding=radius)
            return torch.clamp(grown, 0, 1)

        # Erosion: dilate the background, then invert back.
        grown_background = torch.nn.functional.conv2d(
            1 - binary, kernel_tensor, padding=radius
        )
        return torch.clamp(1 - grown_background, 0, 1)
# Registration tables, presumably read by the host application's custom-node
# loader (confirm): node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {"VAEEncodeForInpaint (WAS)": WAS_VAEEncodeForInpaint}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {"VAEEncodeForInpaint (WAS)": "Inpainting VAE Encode (WAS)"}