-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathimage_inference.py
93 lines (77 loc) · 2.38 KB
/
image_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
import argparse
from PIL import Image
from models.pipeline import EmuGenerationPipeline
def parse_args(argv=None):
    """Parse the command-line options of the image generation demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script did before this parameter existed.

    Returns:
        argparse.Namespace: holds ``instruct`` (bool, ``False`` unless
        ``--instruct`` is passed) and ``ckpt_path`` (str, ``''`` unless
        ``--ckpt-path`` is passed).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--instruct",
        action="store_true",
        default=False,
        help="Load Emu-I",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="",
        help="Emu Decoder ckpt path",
    )
    # Forwarding argv lets callers and tests supply arguments explicitly.
    return parser.parse_args(argv)
def _save_unless_flagged(image, safety, out_path, label):
    """Save ``image`` to ``out_path`` unless ``safety`` flags it.

    A ``None`` or falsy ``safety`` value is treated as "no concern" and the
    image is written. Any other value skips the save and prints a warning
    that starts with ``label``.
    """
    if safety is None or not safety:
        image.save(out_path)
    else:
        print(f"{label} Generated Image Has Safety Concern!!!")


if __name__ == "__main__":
    args = parse_args()

    # NOTE:
    # The Emu decoder pipeline only supports the pretrained model. Using the
    # instruct-tuned model as the image encoder may cause unpredictable
    # results, so refuse to run with --instruct. This is an explicit raise
    # rather than an assert because asserts are removed under `python -O`.
    if args.instruct:
        raise ValueError(
            "Image Generation currently do not support instruct tuning model"
        )

    pipeline = EmuGenerationPipeline.from_pretrained(
        path=args.ckpt_path,
        args=args,
    )
    pipeline = pipeline.bfloat16().cuda()

    # Image blend case: two images as the prompt.
    # Alternative inputs: "examples/sunflower.png" and
    # "examples/oil_sunflower.jpg".
    # The `with` blocks close the input image files once the pipeline call
    # that consumes them has returned.
    with Image.open("examples/cat.jpg") as image_1, \
            Image.open("examples/tiger.jpg") as image_2:
        image, safety = pipeline(
            [image_1, image_2],
            height=512,
            width=512,
            guidance_scale=7.5,
        )
    _save_unless_flagged(image, safety, "image_blend_result.jpg", "ImageBlend")

    # Text-to-image case: a single text prompt.
    text = "An image of a dog wearing a pair of glasses."
    image, safety = pipeline(
        [text],
        height=512,
        width=512,
        guidance_scale=7.5,
    )
    _save_unless_flagged(image, safety, "text2image_result.jpg", "T2I")

    # In-context generation: interleaved text and images as the prompt.
    with Image.open("examples/dog.png") as image_1, \
            Image.open("examples/sunflower.png") as image_2:
        image, safety = pipeline(
            [
                "This is the first image: ",
                image_1,
                "This is the second image: ",
                image_2,
                "The animal in the first image surrounded with the plant in the second image: ",
            ],
            height=512,
            width=512,
            guidance_scale=10.,
        )
    _save_unless_flagged(image, safety, "incontext_result.jpg", "In-context")