diff --git a/annolid/detector/countgd/models/__init__.py b/annolid/detector/countgd/models/__init__.py index 8b123bb..0e21950 100644 --- a/annolid/detector/countgd/models/__init__.py +++ b/annolid/detector/countgd/models/__init__.py @@ -7,4 +7,4 @@ from .GroundingDINO import build_groundingdino def build_model(args): - return build(args) + return build_groundingdino(args) diff --git a/annolid/detector/countgd/predict.py b/annolid/detector/countgd/predict.py index b81fd32..b534d97 100644 --- a/annolid/detector/countgd/predict.py +++ b/annolid/detector/countgd/predict.py @@ -1,6 +1,7 @@ import random import torch import os +import gdown from PIL import Image import numpy as np import argparse @@ -37,6 +38,24 @@ def __init__(self, model_path: str = "checkpoint_best_regular.pth", self.here = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(self.here, config_path) model_path = os.path.join(self.here, model_path) + self._REMOTE_MODEL_URL = "https://github.com/healthonrails/annolid/releases/download/v1.2.0/checkpoint_best_regular.pth" + self._MD5 = "1492bfdd161ac1de471d0aafb32b174d" + if not os.path.exists(model_path): + gdown.cached_download(self._REMOTE_MODEL_URL, + model_path, + md5=self._MD5 + ) + + self._REMOTE_BERT_MODEL_URL = "https://github.com/healthonrails/annolid/releases/download/v1.2.0/model.safetensors" + self._BERT_MD5 = "cd18ceb6b110c04a8033ce01de41b0b7" + self._BERT_MODEL_PATH = os.path.join( + self.here, "checkpoints/bert-base-uncased/model.safetensors") + if not os.path.exists(self._BERT_MODEL_PATH): + gdown.cached_download(self._REMOTE_BERT_MODEL_URL, + self._BERT_MODEL_PATH, + md5=self._BERT_MD5 + ) + self.model = self._load_model(model_path, config_path, self.device) self.transform = self._build_transforms()