Skip to content

Commit

Permalink
feat: Add model and BERT model download functionality
Browse files Browse the repository at this point in the history
- Implemented logic to download the main model and BERT model if not present locally
- Added remote URLs and MD5 checksums for model verification
  • Loading branch information
healthonrails committed Dec 23, 2024
1 parent 3debde1 commit eb9cf55
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion annolid/detector/countgd/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .GroundingDINO import build_groundingdino

def build_model(args):
return build(args)
return build_groundingdino(args)
19 changes: 19 additions & 0 deletions annolid/detector/countgd/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import torch
import os
import gdown
from PIL import Image
import numpy as np
import argparse
Expand Down Expand Up @@ -37,6 +38,24 @@ def __init__(self, model_path: str = "checkpoint_best_regular.pth",
self.here = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(self.here, config_path)
model_path = os.path.join(self.here, model_path)
self._REMOTE_MODEL_URL = "https://github.com/healthonrails/annolid/releases/download/v1.2.0/checkpoint_best_regular.pth"
self._MD5 = "1492bfdd161ac1de471d0aafb32b174d"
if not os.path.exists(model_path):
gdown.cached_download(self._REMOTE_MODEL_URL,
model_path,
md5=self._MD5
)

self._REMOTE_BERT_MODEL_URL = "https://github.com/healthonrails/annolid/releases/download/v1.2.0/model.safetensors"
self._BERT_MD5 = "cd18ceb6b110c04a8033ce01de41b0b7"
self._BERT_MODEL_PATH = os.path.join(
self.here, "checkpoints/bert-base-uncased/model.safetensors")
if not os.path.exists(self._BERT_MODEL_PATH):
gdown.cached_download(self._REMOTE_BERT_MODEL_URL,
self._BERT_MODEL_PATH,
md5=self._BERT_MD5
)

self.model = self._load_model(model_path, config_path, self.device)
self.transform = self._build_transforms()

Expand Down

0 comments on commit eb9cf55

Please sign in to comment.