-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathstable_diffusion.py
81 lines (65 loc) · 2.31 KB
/
stable_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import time
import os
import requests
import json
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
# Output resolution passed as `height`/`width` to the pipeline call in run_model.
# Defaults chosen to fit into small-ish memory (less than 12G).
# TODO make configurable via args
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 768
def load_model():
    """Create the Stable Diffusion v1.4 pipeline in float16 and move it to CUDA.

    The checkpoint is requested with ``use_auth_token=True``, so make sure
    you're logged in with `huggingface-cli login` before calling this.

    Returns:
        The ``StableDiffusionPipeline`` after ``.to("cuda")``.
    """
    pretrained_options = {
        "revision": "fp16",
        "torch_dtype": torch.float16,
        "use_auth_token": True,
    }
    pipeline = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        **pretrained_options,
    )
    return pipeline.to("cuda")
# Characters that must not end up in a file name: the space (kept from the
# original behaviour), path separators, characters reserved on common
# filesystems, and ASCII control codes. Each one is mapped to "_".
_UNSAFE_FILENAME_CHARS = str.maketrans(
    {ch: "_" for ch in ' /\\:*?"<>|' + "".join(map(chr, range(32)))}
)


def _prompt_to_filename_fragment(prompt, max_len=100):
    """Return the first *max_len* characters of *prompt*, safe to embed in a file name."""
    return prompt[:max_len].translate(_UNSAFE_FILENAME_CHARS)


def run_model(pipe, prompt, save_image=False):
    """Generate one image for *prompt* and optionally save it as a PNG.

    Args:
        pipe: The pipeline object returned by ``load_model``; it is called
            with the prompt and the module-level IMAGE_HEIGHT / IMAGE_WIDTH.
        prompt: Text prompt passed to the pipeline.
        save_image: When true, write the image to the current working
            directory as ``sample-<prompt fragment>-<unix timestamp>.png``.

    Returns:
        The first element of the pipeline's ``"sample"`` output.
    """
    with autocast("cuda"):
        # NOTE(review): indexing the result with "sample" matches the pipeline
        # version this script was written for; newer diffusers releases may
        # expose the images differently -- confirm against the installed version.
        image = pipe(prompt, height=IMAGE_HEIGHT, width=IMAGE_WIDTH)["sample"][0]

    if save_image:
        ts = str(int(time.time()))
        # The prompt is free-form user text: without sanitising, a "/" or a
        # newline in it would make save() fail (or write into another
        # directory) only after the expensive generation step has finished.
        image_name = f"sample-{_prompt_to_filename_fragment(prompt)}-{ts}.png"
        print(f"Image name is {image_name}")
        image.save(image_name)
    return image
def add_prompt_modifiers(plain_prompt, timeout=60):
    """Append GPT-3-generated modifiers to a plain-English prompt.

    The contents of ``effective_prompts_fs.txt`` (read from the current
    working directory) are prepended to *plain_prompt*, the result is sent to
    the OpenAI completions endpoint, and the returned completion (generation
    stops at the first newline) is appended to *plain_prompt*.

    Args:
        plain_prompt: The prompt to extend.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the script indefinitely.

    Returns:
        ``plain_prompt`` followed by the completion text.

    Raises:
        KeyError: If the ``OPENAI_TOKEN`` environment variable is not set.
        RuntimeError: If the response body is not valid JSON, or it contains
            no completion (for example an API error payload).
    """
    openai_token = os.environ['OPENAI_TOKEN']
    # Explicit encoding so the prompt file reads the same on every platform.
    with open('effective_prompts_fs.txt', 'r', encoding='utf-8') as f:
        prefix = f.read()
    prompt = prefix + '\n' + plain_prompt

    response = requests.post(
        "https://api.openai.com/v1/completions",
        headers={
            'authorization': "Bearer " + openai_token,
            "content-type": "application/json",
        },
        json={
            "model": "davinci",
            "prompt": prompt,
            "max_tokens": 50,
            "temperature": 0.7,
            "stop": "\n",
        },
        timeout=timeout,
    )
    text = response.text

    # json.JSONDecodeError is a ValueError; catching only that keeps
    # KeyboardInterrupt / SystemExit from being swallowed.
    try:
        result = json.loads(text)
    except ValueError as err:
        raise RuntimeError(f'Cannot load: {text}, {response}') from err

    # A well-formed JSON body without "choices" would otherwise surface as an
    # opaque KeyError; include the body so the cause is visible.
    try:
        prompt_modifiers = result['choices'][0]['text']
    except (KeyError, IndexError, TypeError) as err:
        raise RuntimeError(f'No completion in response: {text}, {response}') from err

    engineered_prompt = plain_prompt + prompt_modifiers
    print(f'New engineered prompt: {engineered_prompt}')
    return engineered_prompt
def main(argv=None):
    """Parse command-line options, optionally expand the prompt, and generate an image.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: Raised by argparse when ``--prompt`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser()
    # --prompt is required: without it the prompt would be None, which only
    # fails later -- after the model has been loaded or inside the promptgen
    # request -- with a much less helpful error.
    parser.add_argument("--prompt", "-p", type=str, required=True, help="Prompt into model")
    parser.add_argument("--promptgen", "-g", action='store_true', help="Use GPT-3 to prompt engineer plain English to prompt English")
    args = parser.parse_args(argv)

    prompt = args.prompt
    if args.promptgen:
        prompt = add_prompt_modifiers(prompt)

    pipe = load_model()
    run_model(pipe, prompt, save_image=True)


if __name__ == "__main__":
    main()