Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

444 #4

Open
dreamlychina opened this issue Sep 24, 2024 · 0 comments
Open

444 #4

dreamlychina opened this issue Sep 24, 2024 · 0 comments

Comments

@dreamlychina
Copy link
Owner

import torch
import clip
from PIL import Image
from torchvision import transforms

class ClipEmbeding:
    """Wrapper around a CLIP ``ViT-B/32`` model for label scoring and embedding.

    The model and its image preprocessing callable are loaded once in
    ``__init__`` via ``clip.load``; text is tokenized with ``clip.tokenize``.
    """

    # Evaluated once, when the class is defined, and shared by all instances.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Candidate captions scored by ``probs`` when the caller passes none.
    DEFAULT_LABELS = ("This is a medical laboratory report.", "cat", "dog")

    def __init__(self):
        """Load the CLIP model and its preprocessing callable onto ``device``."""
        self.model, self.processor = clip.load("ViT-B/32", device=self.device)
        self.tokenizer = clip.tokenize

    def probs(self, image: "Image.Image", labels=None):
        """Score ``image`` against a set of text labels.

        Args:
            image: Image handed to the preprocessing callable returned by
                ``clip.load``.
            labels: Optional iterable of caption strings. Defaults to
                ``DEFAULT_LABELS``.

        Returns:
            NumPy array holding the softmax (over the last dimension) of the
            image-to-text logits. Presumably shaped ``(1, len(labels))`` --
            confirm against the CLIP model's output.
        """
        if labels is None:
            labels = self.DEFAULT_LABELS

        # unsqueeze(0) adds the leading batch axis before the model call.
        process_image = self.processor(image).unsqueeze(0).to(self.device)
        text = self.tokenizer(list(labels)).to(self.device)

        # Inference only: no autograd bookkeeping is needed here.
        with torch.no_grad():
            logits_per_image, logits_per_text = self.model(process_image, text)
            print("logits_per_image: ", logits_per_image.shape)
            print("logits_per_text; ", logits_per_text.shape)
            label_probs = logits_per_image.softmax(dim=-1).cpu().numpy()

        print("Label probs:", label_probs)
        # Previously the result was only printed; returning it lets callers
        # use the scores programmatically.
        return label_probs

    def embeding(self, image: "Image.Image", text: str):
        """Return a fused image+text embedding.

        Args:
            image: Image handed to the preprocessing callable.
            text: Caption passed to the tokenizer.

        Returns:
            Element-wise sum of ``model.encode_image`` and
            ``model.encode_text`` outputs.

        NOTE(review): unlike ``probs`` this is not wrapped in
        ``torch.no_grad()``; callers that only need inference may want to
        disable gradients themselves.
        """
        process_image = self.processor(image).unsqueeze(0).to(self.device)
        tokens = self.tokenizer(text).to(self.device)

        image_features = self.model.encode_image(process_image)
        text_features = self.model.encode_text(tokens)
        print(image_features.shape)
        print(text_features.shape)

        # Fusion is by addition, so both encoders must produce matching
        # shapes; concatenating along dim=1 would be the alternative.
        img_text_features = image_features + text_features
        print(img_text_features.shape)

        return img_text_features

import time

if __name__ == "__main__":
    # Ad-hoc timing run: preprocess and encode a batch of identical
    # image/text pairs with CLIP and report the wall-clock time taken.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = clip.load("ViT-B/32", device=device)
    tokenizer = clip.tokenize

    # The same image path and caption are repeated to form the batch.
    batch_size = 500
    images_list = ["med_1.png"] * batch_size
    text_list = ["This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report This is a medical laboratory report"] * batch_size

    start_time = time.time()

    image_inputs = []
    for img_name in images_list:
        # Context manager closes the file handle once preprocessing is done.
        with Image.open(img_name) as img:
            image_inputs.append(processor(img))

    # Stack the per-image tensors along a new leading batch axis and move the
    # batch to the device once. This replaces the earlier
    # unsqueeze -> stack -> squeeze round-trip, which would also drop the
    # batch axis for a single-image batch.
    temp = torch.stack(image_inputs).to(device)

    print("temp", temp.shape)
    print("temp", temp.device)
    with torch.no_grad():
        all_image_features = model.encode_image(temp)
        text = tokenizer(text_list).to(device)
        text_features = model.encode_text(text)
    end_time = time.time()
    # Message reads "elapsed: <seconds> s".
    print(f"耗时: {end_time - start_time}秒")

    print(all_image_features.shape)
    print(text_features.shape)
    print((all_image_features + text_features).shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant