
Commit

Merge pull request #53 from unum-cloud/gen
Generative models
ashvardanian authored Dec 28, 2023
2 parents 8e73b01 + 908f8c6 commit 0a894ff
Showing 8 changed files with 742 additions and 184 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/package.json
@@ -22,6 +22,10 @@
       "tag": "Add",
       "release": "minor"
     },
+    {
+      "tag": "Break",
+      "release": "major"
+    },
     {
       "tag": "Improve",
       "release": "patch"
@@ -46,6 +50,10 @@
       "tag": "Add",
       "release": "minor"
     },
+    {
+      "tag": "Break",
+      "release": "major"
+    },
     {
       "tag": "Improve",
       "release": "patch"
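These rules extend the semantic-release configuration used by the CI workflow: commits tagged `Add` publish a minor release, `Improve` a patch, and the newly added `Break` a major. Assuming the workflow reads the tag from the start of the commit subject, a hypothetical commit message like the following would now trigger a major release:

```
Break: remove the tritonclient and poptorch optional dependencies
```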
321 changes: 156 additions & 165 deletions README.md

Large diffs are not rendered by default.

14 changes: 5 additions & 9 deletions pyproject.toml
@@ -2,6 +2,9 @@
 requires = ["setuptools>=42"]
 build-backend = "setuptools.build_meta"

+[project.scripts]
+uform-chat = "uform.chat:main"
+
 [project]
 name = "uform"
 version = "0.4.8"
@@ -17,6 +20,7 @@ dependencies = [
     "torch>=1.13.1",
     "tokenizers>=0.13.3",
     "huggingface_hub>=0.16.4",
+    "transformers>=4.36.2",
     "torchvision"
 ]
 description = "Multi-Modal Transformers library for Semantic Search and other Vision-Language tasks"
@@ -46,12 +50,4 @@ classifiers = [
 ]

 [project.urls]
-"Homepage" = "https://github.com/unum-cloud/uform"
-
-[project.optional-dependencies]
-remote = [
-    "tritonclient[all]"
-]
-ipu = [
-    "poptorch"
-]
+"Homepage" = "https://github.com/unum-cloud/uform"
Empty file removed src/__init__.py
4 changes: 2 additions & 2 deletions src/uform.py → src/uform/__init__.py
@@ -1,10 +1,10 @@
 from json import load
-from typing import Optional, Tuple, Mapping
+from typing import Mapping, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

-from models import *
+from uform.models import *


 def get_checkpoint(model_name, token) -> Tuple[str, Mapping, str]:
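Promoting `src/uform.py` to a package (`src/uform/__init__.py`) makes the bare `from models import *` unreliable, since `models` would only resolve when the source directory itself sits on `sys.path`; the package-qualified form fixes that. Assuming the star import keeps re-exporting the same public helpers, downstream usage is unchanged, e.g.:

```python
import uform

# Loads a checkpoint from the Hugging Face Hub via helpers re-exported
# from uform.models; the model name here is illustrative.
model = uform.get_model("unum-cloud/uform-vl-english")
```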
114 changes: 114 additions & 0 deletions src/uform/chat.py
@@ -0,0 +1,114 @@
from argparse import ArgumentParser

import requests
import torch
from PIL import Image
from transformers import TextStreamer

from uform.gen_model import VLMForCausalLM, VLMProcessor

# Token id that terminates generation for this chat model.
EOS_TOKEN = 32001


def parse_args():
    parser = ArgumentParser(description="Chat with UForm generative model")

    parser.add_argument("--model", type=str, default="unum-cloud/uform-gen-chat")
    parser.add_argument("--image", type=str, required=True, help="Local path or URL of an image to chat about")
    parser.add_argument("--device", type=str, required=True, help="Device to run on, e.g. cpu or cuda:0")
    parser.add_argument("--fp16", action="store_true", help="Run in half precision (bfloat16)")

    return parser.parse_args()


def run_chat(opts, model, processor):
    # Stream generated tokens to stdout as soon as they are decoded.
    streamer = TextStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    is_first_message = True

    # Load the image from a URL or a local path, add a batch dimension,
    # and match the model's dtype and device.
    if opts.image.startswith("http"):
        image = Image.open(requests.get(opts.image, stream=True).raw)
    else:
        image = Image.open(opts.image)
    image = (
        processor.image_processor(image)
        .unsqueeze(0)
        .to(torch.bfloat16 if opts.fp16 else torch.float32)
        .to(opts.device)
    )

    while True:
        if messages[-1]["role"] in ("system", "assistant"):
            # The <image> placeholder is injected once, into the first user turn.
            message = input("User: ")
            if is_first_message:
                message = f" <image> {message}"
                is_first_message = False
            messages.append({"role": "user", "content": message})

            print()

        else:
            input_ids = processor.tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True
            ).to(opts.device)

            # One attention position per text token, plus one per image latent,
            # minus the single <image> placeholder token the latents replace.
            attention_mask = torch.ones(
                1, input_ids.shape[1] + processor.num_image_latents - 1
            ).to(opts.device)
            x = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "images": image,
            }

            print("Assistant: ", end="")
            with torch.inference_mode():
                y = model.generate(
                    **x,
                    do_sample=False,
                    use_cache=True,
                    max_new_tokens=1024,
                    eos_token_id=EOS_TOKEN,
                    pad_token_id=processor.tokenizer.pad_token_id,
                    streamer=streamer,
                )
            print()

            # Keep only the newly generated tokens: drop the prompt prefix
            # and the trailing end-of-sequence token.
            message = processor.batch_decode(y[:, x["input_ids"].shape[1] : -1])[0]

            messages.append({"role": "assistant", "content": message})


def main():
    try:
        opts = parse_args()

        model = (
            VLMForCausalLM.from_pretrained(
                opts.model,
                torch_dtype=torch.bfloat16 if opts.fp16 else torch.float32,
            )
            .eval()
            .to(opts.device)
        )
        processor = VLMProcessor.from_pretrained(opts.model)

        run_chat(opts, model, processor)

    except KeyboardInterrupt:
        print("Bye!")


if __name__ == "__main__":
    main()
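The same pieces also support one-shot, non-interactive captioning. A minimal sketch reusing the API exercised above (model name, image path, and prompt are illustrative):

```python
import torch
from PIL import Image

from uform.gen_model import VLMForCausalLM, VLMProcessor

model = VLMForCausalLM.from_pretrained("unum-cloud/uform-gen-chat").eval()
processor = VLMProcessor.from_pretrained("unum-cloud/uform-gen-chat")

# A single-turn chat prompt; the <image> placeholder marks where the
# image latents are spliced into the sequence.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": " <image> Describe this picture."},
]
input_ids = processor.tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
)
attention_mask = torch.ones(
    1, input_ids.shape[1] + processor.num_image_latents - 1
)
image = processor.image_processor(Image.open("example.jpg")).unsqueeze(0)

with torch.inference_mode():
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        images=image,
        do_sample=False,
        max_new_tokens=128,
        eos_token_id=32001,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

# Strip the prompt prefix and the trailing end-of-sequence token.
print(processor.batch_decode(output[:, input_ids.shape[1]:-1])[0])
```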
