From 4c6ce450b6d937862aba6597e673ca6622b45194 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Tue, 14 May 2024 08:22:43 -0400 Subject: [PATCH] Fix: model id and model dir check order (#3680) * fix huggingface runtime in chart Signed-off-by: Dan Sun * Allow model_dir to be specified on template Signed-off-by: Dan Sun * Default model_dir to /mnt/models for HF Signed-off-by: Dan Sun * Lint format Signed-off-by: Dan Sun --------- Signed-off-by: Dan Sun --- .../huggingfaceserver/huggingfaceserver/__main__.py | 13 ++++++++----- python/lgbserver/lgbserver/__main__.py | 5 ++--- python/pmmlserver/pmmlserver/__main__.py | 1 - python/sklearnserver/sklearnserver/__main__.py | 2 -- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/huggingfaceserver/huggingfaceserver/__main__.py b/python/huggingfaceserver/huggingfaceserver/__main__.py index 5a94e098f17..5213fd55690 100644 --- a/python/huggingfaceserver/huggingfaceserver/__main__.py +++ b/python/huggingfaceserver/huggingfaceserver/__main__.py @@ -53,10 +53,12 @@ def list_of_strings(arg): parser.add_argument( "--model_dir", required=False, - default=None, + default="/mnt/models", help="A URI pointer to the model binary", ) -parser.add_argument("--model_id", required=False, help="Huggingface model id") +parser.add_argument( + "--model_id", required=False, default=None, help="Huggingface model id" +) parser.add_argument( "--model_revision", required=False, default=None, help="Huggingface model revision" ) @@ -131,10 +133,11 @@ def list_of_strings(arg): def load_model(): engine_args = None - if args.model_dir: - model_id_or_path = Path(Storage.download(args.model_dir)) - else: + # If --model_id is specified, pass it to the HF API; otherwise load the model from --model_dir (defaults to /mnt/models) + if args.model_id: model_id_or_path = cast(str, args.model_id) + else: + model_id_or_path = Path(Storage.download(args.model_dir)) if model_id_or_path is None: raise ValueError("You must provide a model_id or model_dir") diff --git 
a/python/lgbserver/lgbserver/__main__.py b/python/lgbserver/lgbserver/__main__.py index 59db7a464a9..2a286e3a7f6 100644 --- a/python/lgbserver/lgbserver/__main__.py +++ b/python/lgbserver/lgbserver/__main__.py @@ -21,7 +21,6 @@ import kserve from kserve.errors import ModelMissingError -DEFAULT_LOCAL_MODEL_DIR = "/tmp/model" DEFAULT_NTHREAD = 1 parser = argparse.ArgumentParser( @@ -49,5 +48,5 @@ ) model_repository = LightGBMModelRepository(args.model_dir, args.nthread) # LightGBM doesn't support multi-process, so the number of http server workers should be 1. - kfserver = kserve.ModelServer(workers=1, registered_models=model_repository) - kfserver.start([model] if model.ready else []) + server = kserve.ModelServer(workers=1, registered_models=model_repository) + server.start([model] if model.ready else []) diff --git a/python/pmmlserver/pmmlserver/__main__.py b/python/pmmlserver/pmmlserver/__main__.py index b080f31e392..88e1481c3d5 100644 --- a/python/pmmlserver/pmmlserver/__main__.py +++ b/python/pmmlserver/pmmlserver/__main__.py @@ -19,7 +19,6 @@ import kserve from kserve.errors import WorkersShouldBeLessThanMaxWorkersError -DEFAULT_LOCAL_MODEL_DIR = "/tmp/model" parser = argparse.ArgumentParser(parents=[kserve.model_server.parser]) parser.add_argument( diff --git a/python/sklearnserver/sklearnserver/__main__.py b/python/sklearnserver/sklearnserver/__main__.py index 798213c1d2a..6e4b01291f0 100644 --- a/python/sklearnserver/sklearnserver/__main__.py +++ b/python/sklearnserver/sklearnserver/__main__.py @@ -20,8 +20,6 @@ import kserve from kserve.errors import ModelMissingError -DEFAULT_LOCAL_MODEL_DIR = "/tmp/model" - parser = argparse.ArgumentParser(parents=[kserve.model_server.parser]) parser.add_argument( "--model_dir", required=True, help="A local path to the model binary"