diff --git a/refact_webgui/webgui/selfhost_fastapi_completions.py b/refact_webgui/webgui/selfhost_fastapi_completions.py index 19cff522..86e5f955 100644 --- a/refact_webgui/webgui/selfhost_fastapi_completions.py +++ b/refact_webgui/webgui/selfhost_fastapi_completions.py @@ -545,7 +545,7 @@ async def litellm_streamer(): # NOTE: DONE needed by refact-lsp server yield prefix + "[DONE]" + postfix except BaseException as e: - err_msg = f"litellm error: {e}" + err_msg = f"litellm error (1): {e}" log(err_msg) yield prefix + json.dumps({"error": err_msg}) + postfix @@ -575,7 +575,7 @@ async def litellm_non_streamer(): data = {"choices": [{"finish_reason": finish_reason}]} yield json.dumps(data) except BaseException as e: - err_msg = f"litellm error: {e}" + err_msg = f"litellm error (2): {e}" log(err_msg) yield json.dumps({"error": err_msg}) diff --git a/self_hosting_machinery/inference/inference_worker.py b/self_hosting_machinery/inference/inference_worker.py index ece4add7..18e77761 100644 --- a/self_hosting_machinery/inference/inference_worker.py +++ b/self_hosting_machinery/inference/inference_worker.py @@ -1,3 +1,4 @@ +import os import sys import logging import time @@ -80,8 +81,9 @@ def check_cancelled(*args, **kwargs): log("STATUS serving %s" % model_name) req_session = infserver_session() + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").replace(",", "") description_dict = validate_description_dict( - model_name + "_" + socket.getfqdn(), + f'{model_name}_{socket.getfqdn()}_{cuda_visible_devices}', "account_name", model=model_name, B=1, max_thinking_time=10, ) @@ -121,7 +123,7 @@ def check_cancelled(*args, **kwargs): time.sleep(10) upload_proxy.stop() - log("clean shutdown") + log("inference_worker.py clean shutdown") def catch_sigkill(signum, frame): diff --git a/self_hosting_machinery/inference/stream_results.py b/self_hosting_machinery/inference/stream_results.py index 7aa0df9f..e2f59370 100644 --- a/self_hosting_machinery/inference/stream_results.py +++ b/self_hosting_machinery/inference/stream_results.py @@ -1,6 +1,7 @@ import os import json import re +import psutil import time import datetime import termcolor @@ -171,9 +172,12 @@ def start_upload_result_daemon(self): self.proc.start() return self.proc - def stop(self): + def stop(self, timeout=10): if self.proc: self.upload_q.put(dict(exit=1)) + self.proc.join(timeout) + if self.proc.is_alive(): + self.proc.terminate() self.proc = None def __del__(self): @@ -260,7 +264,12 @@ def _upload_results_loop(upload_q: multiprocessing.Queue, cancelled_q: multiproc setproctitle.setproctitle("upload_results_loop") req_session = infserver_session() exit_flag = False + parent_pid = os.getppid() while not exit_flag: + if not psutil.pid_exists(parent_pid): + logger.warning("Parent process no longer exists, exiting.") + exit_flag = True + break try: upload_dict = upload_q.get(timeout=600) except queue.Empty as e: