Skip to content

Commit

Permalink
move _resolve_model_lora info base model
Browse files Browse the repository at this point in the history
  • Loading branch information
mitya52 committed Mar 9, 2024
1 parent 2d117b7 commit d07d816
Showing 1 changed file with 38 additions and 35 deletions.
73 changes: 38 additions & 35 deletions refact_webgui/webgui/selfhost_fastapi_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,40 +364,8 @@ async def _secret_key_activate(self, authorization: str = Header(None)):
"human_readable_message": "API key verified",
}

def _resolve_model_lora(self, model_name: str) -> Tuple[str, Optional[Dict[str, str]]]:
model_name, run_id, checkpoint_id = (*model_name.split(":"), None, None)[:3]

if model_name not in self._model_assigner.models_db:
return model_name, None

active_loras: List[Dict[str, str]] = get_active_loras({
model_name: self._model_assigner.models_db[model_name]
})[model_name].get("loras", [])

if not active_loras:
return model_name, None

if run_id is None:
run_id = active_loras[0]["run_id"]
checkpoint_id = active_loras[0]["checkpoint"]
else:
run_checkpoints = [
lora_info["checkpoint"]
for lora_info in active_loras
if lora_info["run_id"] == run_id
]
if not run_checkpoints:
return model_name, None

if checkpoint_id is None:
checkpoint_id = run_checkpoints[0]
elif checkpoint_id not in run_checkpoints:
return model_name, None

return model_name, {
"run_id": run_id,
"checkpoint_id": checkpoint_id,
}
async def _resolve_model_lora(self, model_name: str, account: str) -> Tuple[str, Optional[Dict[str, str]]]:
raise NotImplementedError()

async def _completions(self, post: NlpCompletion, authorization: str = Header(None)):
account = await self._account_from_bearer(authorization)
Expand All @@ -406,7 +374,7 @@ async def _completions(self, post: NlpCompletion, authorization: str = Header(No
req = post.clamp()
caps_version = self._model_assigner.config_inference_mtime() # use mtime as a version, if that changes the client will know to refresh caps

model_name, lora_config = self._resolve_model_lora(post.model)
model_name, lora_config = await self._resolve_model_lora(post.model, account)
model_name, err_msg = static_resolve_model(post.model, self._inference_queue)
if err_msg:
log("%s model resolve \"%s\" -> error \"%s\" from %s" % (ticket.id(), post.model, err_msg, account))
Expand Down Expand Up @@ -630,3 +598,38 @@ async def _account_from_bearer(self, authorization: str) -> str:
return self._session.header_authenticate(authorization)
except BaseException as e:
raise HTTPException(status_code=401, detail=str(e))

async def _resolve_model_lora(self, model_name: str, account: str) -> Tuple[str, Optional[Dict[str, str]]]:
model_name, run_id, checkpoint_id = (*model_name.split(":"), None, None)[:3]

if model_name not in self._model_assigner.models_db:
return model_name, None

active_loras: List[Dict[str, str]] = get_active_loras({
model_name: self._model_assigner.models_db[model_name]
})[model_name].get("loras", [])

if not active_loras:
return model_name, None

if run_id is None:
run_id = active_loras[0]["run_id"]
checkpoint_id = active_loras[0]["checkpoint"]
else:
run_checkpoints = [
lora_info["checkpoint"]
for lora_info in active_loras
if lora_info["run_id"] == run_id
]
if not run_checkpoints:
return model_name, None

if checkpoint_id is None:
checkpoint_id = run_checkpoints[0]
elif checkpoint_id not in run_checkpoints:
return model_name, None

return model_name, {
"run_id": run_id,
"checkpoint_id": checkpoint_id,
}

0 comments on commit d07d816

Please sign in to comment.