Skip to content

Commit

Permalink
Fix: model id and model dir check order (kserve#3680)
Browse files Browse the repository at this point in the history
* Fix Hugging Face runtime in chart

Signed-off-by: Dan Sun <[email protected]>

* Allow model_dir to be specified in the template

Signed-off-by: Dan Sun <[email protected]>

* Default model_dir to /mnt/models for HF

Signed-off-by: Dan Sun <[email protected]>

* Lint format

Signed-off-by: Dan Sun <[email protected]>

---------

Signed-off-by: Dan Sun <[email protected]>
  • Loading branch information
yuzisun authored May 14, 2024
1 parent 024f69b commit 4c6ce45
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
13 changes: 8 additions & 5 deletions python/huggingfaceserver/huggingfaceserver/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ def list_of_strings(arg):
parser.add_argument(
"--model_dir",
required=False,
default=None,
default="/mnt/models",
help="A URI pointer to the model binary",
)
parser.add_argument("--model_id", required=False, help="Huggingface model id")
parser.add_argument(
"--model_id", required=False, default=None, help="Huggingface model id"
)
parser.add_argument(
"--model_revision", required=False, default=None, help="Huggingface model revision"
)
Expand Down Expand Up @@ -131,10 +133,11 @@ def list_of_strings(arg):

def load_model():
engine_args = None
if args.model_dir:
model_id_or_path = Path(Storage.download(args.model_dir))
else:
# If --model_id is specified, pass it to the HF API; otherwise load the model from --model_dir (default: /mnt/models)
if args.model_id:
model_id_or_path = cast(str, args.model_id)
else:
model_id_or_path = Path(Storage.download(args.model_dir))

if model_id_or_path is None:
raise ValueError("You must provide a model_id or model_dir")
Expand Down
5 changes: 2 additions & 3 deletions python/lgbserver/lgbserver/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import kserve
from kserve.errors import ModelMissingError

DEFAULT_LOCAL_MODEL_DIR = "/tmp/model"
DEFAULT_NTHREAD = 1

parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -49,5 +48,5 @@
)
model_repository = LightGBMModelRepository(args.model_dir, args.nthread)
# LightGBM doesn't support multi-process, so the number of http server workers should be 1.
kfserver = kserve.ModelServer(workers=1, registered_models=model_repository)
kfserver.start([model] if model.ready else [])
server = kserve.ModelServer(workers=1, registered_models=model_repository)
server.start([model] if model.ready else [])
1 change: 0 additions & 1 deletion python/pmmlserver/pmmlserver/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import kserve
from kserve.errors import WorkersShouldBeLessThanMaxWorkersError

DEFAULT_LOCAL_MODEL_DIR = "/tmp/model"

parser = argparse.ArgumentParser(parents=[kserve.model_server.parser])
parser.add_argument(
Expand Down
2 changes: 0 additions & 2 deletions python/sklearnserver/sklearnserver/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import kserve
from kserve.errors import ModelMissingError

DEFAULT_LOCAL_MODEL_DIR = "/tmp/model"

parser = argparse.ArgumentParser(parents=[kserve.model_server.parser])
parser.add_argument(
"--model_dir", required=True, help="A local path to the model binary"
Expand Down

0 comments on commit 4c6ce45

Please sign in to comment.