From 38241071818c1b25ccb1cad186ac61acbbc625e4 Mon Sep 17 00:00:00 2001 From: chenxwh Date: Tue, 24 Sep 2024 22:12:05 +0000 Subject: [PATCH 1/5] replicate --- cog.yaml | 40 +++++++++++++ predict.py | 108 +++++++++++++++++++++++++++++++++++ predict_video.py | 142 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 290 insertions(+) create mode 100644 cog.yaml create mode 100644 predict.py create mode 100644 predict_video.py diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..9b85302 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,40 @@ +# Configuration for Cog ⚙️ +# Reference: https://cog.run/yaml + +build: + # set to true if your model requires a GPU + gpu: true + + # a list of ubuntu apt packages to install + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + + # python version in the form '3.11' or '3.11.4' + python_version: "3.11" + + # a list of packages in the format == + python_packages: + - decord>=0.6.0 + - pytorchvideo==0.1.5 + - xformers + - torch==2.1.0 + - torchvision==0.16.0 + - transformers==4.42.4 + - huggingface-hub>=0.23.0 + - pillow + - chainlit>=1.0 + - timm>=0.9.16 + - openai>=1.30.1 + - loguru>=0.7.2 + - accelerate + - einops + - sse-starlette>=2.1.0 + - bitsandbytes>=0.43.1 # for int4 quantization + run: + - pip install ipython + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget + +# predict.py defines how predictions are run on your model +# predict: "predict.py:Predictor" +predict: "predict_video.py:Predictor" diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..eea017b --- /dev/null +++ b/predict.py @@ -0,0 +1,108 @@ +# Prediction interface for Cog ⚙️ +# https://cog.run/python + + +import os +import time +import subprocess +import torch +from PIL import Image +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from cog import BasePredictor, Input, Path + +MODEL_CACHE = "model_cache_image" +MODEL_URL = ( + f"https://weights.replicate.delivery/default/THUDM/CogVLM2/{MODEL_CACHE}.tar" +) +os.environ["HF_DATASETS_OFFLINE"] = "1" +os.environ["TRANSFORMERS_OFFLINE"] = "1" +os.environ["HF_HOME"] = MODEL_CACHE +os.environ["TORCH_HOME"] = MODEL_CACHE +os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE +os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE +os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE + +TORCH_TYPE = torch.bfloat16 +DEVICE = "cuda:0" + + +def download_weights(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + + +class Predictor(BasePredictor): + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + + if not os.path.exists(MODEL_CACHE): + download_weights(MODEL_URL, MODEL_CACHE) + + # model_id: THUDM/cogvlm2-llama3-chat-19B, use 8 bit quantization + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_CACHE, + torch_dtype=TORCH_TYPE, + trust_remote_code=True, + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + low_cpu_mem_usage=True, + ).eval() + + self.tokenizer = AutoTokenizer.from_pretrained( + MODEL_CACHE, trust_remote_code=True + ) + + def predict( + self, + input_image: Path = Input(description="Input image"), + prompt: str = Input(description="Input prompt", default="Describe this image."), + top_p: float = Input( + description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", + ge=0.0, + le=1.0, + default=0.9, + ), + temperature: float = Input( + description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", + default=0.7, + ge=0.0, + ), + max_new_tokens: int = Input( + description="Maximum number of tokens to generate. A word is generally 2-3 tokens", + default=2048, + ge=0, + ), + ) -> str: + """Run a single prediction on the model""" + image = Image.open(str(input_image)).convert("RGB") + + input_by_model = self.model.build_conversation_input_ids( + self.tokenizer, query=prompt, images=[image], template_version="chat" + ) + + inputs = { + "input_ids": input_by_model["input_ids"].unsqueeze(0).to(DEVICE), + "token_type_ids": input_by_model["token_type_ids"].unsqueeze(0).to(DEVICE), + "attention_mask": input_by_model["attention_mask"].unsqueeze(0).to(DEVICE), + "images": ( + [[input_by_model["images"][0].to(DEVICE).to(TORCH_TYPE)]] + if image is not None + else None + ), + } + gen_kwargs = { + "max_new_tokens": max_new_tokens, + "pad_token_id": 128002, + # "top_k": 1, + "do_sample": True, + "top_p": top_p, + "temperature": temperature, + } + with torch.no_grad(): + outputs = self.model.generate(**inputs, **gen_kwargs) + outputs = outputs[:, inputs["input_ids"].shape[1] :] + response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + + return response diff --git a/predict_video.py b/predict_video.py new file mode 100644 index 0000000..25db133 --- /dev/null +++ b/predict_video.py @@ -0,0 +1,142 @@ +# Prediction interface for Cog ⚙️ +# https://cog.run/python + + +import os +import io +import time +import subprocess +import numpy as np +import torch +from PIL import Image +from decord import cpu, VideoReader, bridge +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from cog import BasePredictor, Input, Path + +MODEL_CACHE = "model_cache_video" +MODEL_URL = ( + f"https://weights.replicate.delivery/default/THUDM/CogVLM2/{MODEL_CACHE}.tar" +) +os.environ["HF_DATASETS_OFFLINE"] = "1" +os.environ["TRANSFORMERS_OFFLINE"] = "1" +os.environ["HF_HOME"] = MODEL_CACHE +os.environ["TORCH_HOME"] = MODEL_CACHE +os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE +os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE +os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE + +TORCH_TYPE = torch.bfloat16 +DEVICE = "cuda:0" + + +def download_weights(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + + +class Predictor(BasePredictor): + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + + if not os.path.exists(MODEL_CACHE): + download_weights(MODEL_URL, MODEL_CACHE) + + # model_id: THUDM/cogvlm2-video-llama3-chat, use 8 bit quantization + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_CACHE, + torch_dtype=TORCH_TYPE, + trust_remote_code=True, + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + bnb_4bit_compute_dtype=TORCH_TYPE, + ), + low_cpu_mem_usage=True, + ).eval() + + self.tokenizer = AutoTokenizer.from_pretrained( + MODEL_CACHE, trust_remote_code=True + ) + + def predict( + self, + input_video: Path = Input(description="Input video"), + prompt: str = Input(description="Input prompt", default="Describe this video."), + top_p: float = Input( + description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", + ge=0.0, + le=1.0, + default=0.1, + ), + temperature: float = Input( + description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", + default=0.1, + ge=0.0, + ), + max_new_tokens: int = Input( + description="Maximum number of tokens to generate. A word is generally 2-3 tokens", + default=2048, + ge=0, + ), + ) -> str: + """Run a single prediction on the model""" + video = load_video(str(input_video)) + + inputs = self.model.build_conversation_input_ids( + tokenizer=self.tokenizer, + query=prompt, + images=[video], + template_version="chat", + ) + + inputs = { + "input_ids": inputs["input_ids"].unsqueeze(0).to(DEVICE), + "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to(DEVICE), + "attention_mask": inputs["attention_mask"].unsqueeze(0).to(DEVICE), + "images": [[inputs["images"][0].to("cuda").to(TORCH_TYPE)]], + } + gen_kwargs = { + "max_new_tokens": max_new_tokens, + "pad_token_id": 128002, + # "top_k": 1, + "do_sample": True, + "top_p": top_p, + "temperature": temperature, + } + with torch.no_grad(): + outputs = self.model.generate(**inputs, **gen_kwargs) + outputs = outputs[:, inputs["input_ids"].shape[1] :] + response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + + return response + + +def load_video(video_path): + bridge.set_bridge("torch") + with open(video_path, "rb") as f: + mp4_stream = f.read() + num_frames = 24 + + if mp4_stream is not None: + decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0)) + else: + decord_vr = VideoReader(video_path, ctx=cpu(0)) + frame_id_list = None + total_frames = len(decord_vr) + + # strategy == 'chat': + timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames)) + timestamps = [i[0] for i in timestamps] + max_second = round(max(timestamps)) + 1 + frame_id_list = [] + for second in range(max_second): + closest_num = min(timestamps, key=lambda x: abs(x - second)) + index = timestamps.index(closest_num) + frame_id_list.append(index) + if len(frame_id_list) >= num_frames: + break + video_data = decord_vr.get_batch(frame_id_list) + video_data = video_data.permute(3, 0, 1, 2) + return video_data From ff64af6292314bd3721976f77d2daea97e3b0086 Mon Sep 17 00:00:00 2001 From: Chenxi Date: Wed, 25 Sep 2024 15:59:48 +0100 Subject: [PATCH 2/5] Rename cog.yaml to web_demo/replicate/cog.yaml --- cog.yaml => web_demo/replicate/cog.yaml | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename cog.yaml => web_demo/replicate/cog.yaml (100%) diff --git a/cog.yaml b/web_demo/replicate/cog.yaml similarity index 100% rename from cog.yaml rename to web_demo/replicate/cog.yaml From 7ecbb692107bb364e6fee8822e00d36cfb86cb85 Mon Sep 17 00:00:00 2001 From: Chenxi Date: Wed, 25 Sep 2024 16:00:29 +0100 Subject: [PATCH 3/5] Rename predict.py to web_demo/replicate/predict.py --- predict.py => web_demo/replicate/predict.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename predict.py => web_demo/replicate/predict.py (100%) diff --git a/predict.py b/web_demo/replicate/predict.py similarity index 100% rename from predict.py rename to web_demo/replicate/predict.py From 35f0f9e08b1bfe283619d6fa64659a594a576126 Mon Sep 17 00:00:00 2001 From: Chenxi Date: Wed, 25 Sep 2024 16:00:45 +0100 Subject: [PATCH 4/5] Rename predict_video.py to web_demo/replicate/predict_video.py --- predict_video.py => web_demo/replicate/predict_video.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename predict_video.py => web_demo/replicate/predict_video.py (100%) diff --git a/predict_video.py b/web_demo/replicate/predict_video.py similarity index 100% rename from predict_video.py rename to web_demo/replicate/predict_video.py From bf90631633df9fbb6e24cff7b40e4addcc920e4f Mon Sep 17 00:00:00 2001 From: Chenxi Date: Wed, 25 Sep 2024 16:03:10 +0100 Subject: [PATCH 5/5] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 655cbf5..8d8efdb 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@

## Recent updates +- 🔥 **News**: ``2024/9/25``: Web demos are availble on Replicate! Try CogVLM2 here [![Replicate](https://replicate.com/chenxwh/cogvlm2/badge)](https://replicate.com/chenxwh/cogvlm2) and CogVLM2-Video here [![Replicate](https://replicate.com/chenxwh/cogvlm2-video/badge)](https://replicate.com/chenxwh/cogvlm2-video). - 🔥 **News**: ``2024/8/30``: The [CogVLM2 paper](https://arxiv.org/abs/2408.16500) has been published on arXiv. - 🔥 **News**: ``2024/7/12``: We have released CogVLM2-Video [online web demo](http://cogvlm2-online.cogviewai.cn:7868/), welcome to experience it. - 🔥 **News**: ``2024/7/8``: We released the video understanding version of the CogVLM2 model, the CogVLM2-Video model.