Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1.3.1 #274

Merged
merged 8 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ ENV RDMAV_HUGEPAGES_SAFE 0

EXPOSE 8008

COPY database-start.sh /
RUN chmod +x database-start.sh
COPY docker-entrypoint.sh /
RUN chmod +x docker-entrypoint.sh

Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ On Windows you need to install WSL 2 first, [one guide to do this](https://docs.

Run docker container with following command:
```commandline
docker run -d --rm --gpus all -p 8008:8008 -v refact-perm-storage:/perm_storage -v refact-database:/var/lib/cassandra smallcloud/refact_self_hosting:latest
docker run -d --rm --gpus all -p 8008:8008 -v refact-perm-storage:/perm_storage smallcloud/refact_self_hosting:latest
```

`perm-storage` is a volume that is mounted inside the container. All the configuration files, downloaded weights and logs are stored here.
`refact-database` is a volume for database where server stores statistics from your users.

To upgrade, stop the running container with `docker kill XXX` (the volume `perm-storage` will retain your
data), then run `docker pull smallcloud/refact_self_hosting` and start it again.
Expand Down
27 changes: 27 additions & 0 deletions database-start.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/sh
# Prepare and launch the local cassandra instance, keeping its data inside
# the permanent storage volume so it survives container upgrades.

REFACT_CASSANDRA_DIR="$REFACT_PERM_DIR/cassandra"

# First run only: create the persistent directory and seed it with whatever
# the image shipped in /var/lib/cassandra and /var/log/cassandra.
if [ ! -d "$REFACT_CASSANDRA_DIR" ]; then
    mkdir -p "$REFACT_CASSANDRA_DIR"
    chown cassandra:cassandra "$REFACT_CASSANDRA_DIR"
    if [ -n "$(ls /var/lib/cassandra)" ]; then
        cp -rp /var/lib/cassandra/* "$REFACT_CASSANDRA_DIR"
    fi
    cp -rp /var/log/cassandra "$REFACT_CASSANDRA_DIR/log"
fi

# Point cassandra's config at the persistent directory.
sed -i "s|/var/lib/cassandra|$REFACT_CASSANDRA_DIR|g" /etc/cassandra/cassandra.yaml

# Rewrite cassandra.in.sh: cap memory use and send logs to the persistent dir.
REFACT_CASSANDRA_INCLUDE=/usr/sbin/cassandra.in.sh
cp /usr/share/cassandra/cassandra.in.sh "$REFACT_CASSANDRA_INCLUDE"
{
    echo "MAX_HEAP_SIZE=2G"
    echo "HEAP_NEWSIZE=400M"
    echo "CASSANDRA_LOG_DIR=$REFACT_CASSANDRA_DIR/log"
} >> "$REFACT_CASSANDRA_INCLUDE"

# A stale pidfile from an unclean shutdown blocks the status check; drop it.
if service cassandra status | grep -q 'could not access pidfile'; then
    rm /var/run/cassandra/cassandra.pid
fi

# Start the service only when it is not already running.
if service cassandra status | grep -q 'not running'; then
    service cassandra start
    echo "cassandra database started on localhost"
fi
3 changes: 1 addition & 2 deletions docker-entrypoint.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/bin/sh
# Container entrypoint: bring up the local database unless an external one
# is configured via REFACT_DATABASE_HOST, then hand off to the watchdog.
# NOTE(review): this span is a unified-diff fragment -- the two lines that
# start the service directly appear to be the pre-change version and
# "sh database-start.sh" their replacement; confirm against the merged file.
if [ -z "$REFACT_DATABASE_HOST" ]; then
    sudo service cassandra start
    echo "cassandra database started on localhost"
    sh database-start.sh
fi
python -m self_hosting_machinery.watchdog.docker_watchdog
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def robot_human_ratio(robot: int, human: int) -> float:
return 1
if robot == 0:
return 0
# in older versions of the refact LSP, negative values of the human metric could occur
if robot + human == 0:
return 0
return round(robot / (robot + human), 2)


Expand Down Expand Up @@ -181,10 +184,10 @@ def extract_stats(df: pd.DataFrame, date_kind: str) -> Dict:
if lang not in languages:
continue
res_loc[lang] = {
"Assistant": (robot := int(group["robot_characters"].sum())),
"Refact": (robot := int(group["robot_characters"].sum())),
"Human": (human := int(group["human_characters"].sum())),
"Total (characters)": robot + human,
"A/(A+H)": robot_human_ratio(robot, human),
"Refact Impact": robot_human_ratio(robot, human),
"Completions": int(group["completions_cnt"].sum()),
"Users": int(group["tenant_name"].nunique()),
}
Expand All @@ -194,7 +197,7 @@ def extract_stats(df: pd.DataFrame, date_kind: str) -> Dict:
res_loc = {
'data': fmt_vals,
'columns': ['Language', *res_loc[list(res_loc.keys())[0]].keys()],
'title': f"Assistant's impact by language: {date_kind}"
'title': f"Refact's impact by language: {date_kind}"
}
return res_loc

Expand Down
3 changes: 3 additions & 0 deletions self_hosting_machinery/finetune/scripts/finetune_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def _get_file_loss(file) -> float:
except InvalidLossValueException as e:
files_status_context.reject_file(file, reason=str(e))
continue
except Exception as e:
files_status_context.reject_file(file, reason=str(e))
continue

if file_loss > filter_loss_threshold:
files_status_context.reject_file(file, reason=f"loss {file_loss:.3f}")
Expand Down
2 changes: 0 additions & 2 deletions self_hosting_machinery/scripts/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
FLAG_LAUNCH_FINETUNE = os.path.join(DIR_WATCHDOG_D, "_launch_finetune.flag")
FLAG_STOP_FINETUNE = os.path.join(DIR_WATCHDOG_D, "_stop_finetune.flag")

FLAG_RESTART_LSP = os.path.join(DIR_WATCHDOG_D, "_restart_lsp.flag")

def create_dirs():
os.makedirs(DIR_WATCHDOG_D, exist_ok=True)
os.makedirs(DIR_WEIGHTS, exist_ok=True)
Expand Down
2 changes: 1 addition & 1 deletion self_hosting_machinery/watchdog/docker_watchdog.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def log(*args):
list_of_files.sort()
while len(list_of_files) > 20:
try:
os.remove(list_of_files.pop())
os.remove(list_of_files.pop(0))
except OSError:
pass
with open(os.path.join(env.DIR_LOGS, "watchdog_%s.log" % date), "a") as f:
Expand Down
5 changes: 1 addition & 4 deletions self_hosting_machinery/watchdog/watchdog.d/lsp.cfg
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
{
"policy": ["always_on"],
"interrupt_when_file_appears": "%FLAG_RESTART_LSP%",
"command_line": [
"refact-lsp",
"--address-url", "http://127.0.0.1:8008",
"--http-port", "8001",
"--lsp-port", "8002",
"--logs-stderr"
"--http-port", "8001"
],
"gpus": []
}
104 changes: 14 additions & 90 deletions self_hosting_machinery/webgui/selfhost_fastapi_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import litellm

from fastapi import APIRouter, Request, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse

from self_hosting_machinery import env
from self_hosting_machinery.webgui.selfhost_model_resolve import completion_resolve_model
from self_hosting_machinery.webgui.selfhost_model_resolve import static_resolve_model
from self_hosting_machinery.webgui.selfhost_req_queue import Ticket
from self_hosting_machinery.webgui.selfhost_queue import Ticket
from self_hosting_machinery.webgui.selfhost_webutils import log
from self_hosting_machinery.webgui.selfhost_queue import InferenceQueue
from self_hosting_machinery.webgui.selfhost_model_assigner import ModelAssigner
Expand Down Expand Up @@ -107,7 +107,7 @@ class ChatContext(NlpSamplingParams):
function: str = Query(default="chat", regex="^[a-zA-Z0-9_\.\-]+$")


async def completion_streamer(ticket: Ticket, post: NlpCompletion, timeout, seen, created_ts):
async def _completion_streamer(ticket: Ticket, post: NlpCompletion, timeout, seen, created_ts, caps_version: int):
try:
packets_cnt = 0
while 1:
Expand All @@ -117,6 +117,7 @@ async def completion_streamer(ticket: Ticket, post: NlpCompletion, timeout, seen
log("TIMEOUT %s" % ticket.id())
msg = {"status": "error", "human_readable_message": "timeout"}
not_seen_resp = copy.deepcopy(msg)
not_seen_resp["caps_version"] = caps_version
if "choices" in not_seen_resp:
for i in range(post.n):
newtext = not_seen_resp["choices"][i]["text"]
Expand Down Expand Up @@ -152,40 +153,6 @@ async def completion_streamer(ticket: Ticket, post: NlpCompletion, timeout, seen
ticket.cancelled = True


async def diff_streamer(ticket: Ticket, post: DiffCompletion, timeout, created_ts):
    """Stream diff-completion results for *ticket* back to the HTTP client.

    Pulls messages from ``ticket.streaming_queue`` until a status other than
    "in_progress" arrives (or the queue read times out), yielding either one
    JSON body (``post.stream`` false) or SSE "data: ..." frames
    (``post.stream`` true).
    """
    try:
        while 1:
            try:
                msg = await asyncio.wait_for(ticket.streaming_queue.get(), timeout)
            except asyncio.TimeoutError:
                # Synthesize an error message so the client gets a response.
                log("TIMEOUT %s" % ticket.id())
                msg = {"status": "error", "human_readable_message": "timeout"}
            if not post.stream:
                # Non-streaming mode: skip intermediate updates, emit only the
                # final message as a single JSON body.
                if msg.get("status", "") == "in_progress":
                    continue
                yield json.dumps(msg)
                break
            tmp = json.dumps(msg)
            yield "data: " + tmp + "\n\n"
            log(" " + red_time(created_ts) + " stream %s <- %i bytes" % (ticket.id(), len(tmp)))
            if msg.get("status", "") != "in_progress":
                break
        if post.stream:
            # SSE termination marker expected by streaming clients.
            yield "data: [DONE]" + "\n\n"
        log(red_time(created_ts) + " /finished call %s" % ticket.id())
        ticket.done()
        # fastapi_stats.stats_accum[kt] += msg.get("generated_tokens_n", 0)
        # fastapi_stats.stats_accum[kcomp] += 1
        # fastapi_stats.stats_lists_accum["stat_latency_" + post.model].append(time.time() - created_ts)
    finally:
        # Reaching here with a live ticket id means the generator was closed
        # before done() ran (e.g. client disconnect) -- flag it cancelled.
        if ticket.id() is not None:
            log(" *** CANCEL *** cancelling %s " % ticket.id() + red_time(created_ts))
            # fastapi_stats.stats_accum["stat_api_cancelled"] += 1
            # fastapi_stats.stats_accum["stat_m_" + post.model + "_cancelled"] += 1
            ticket.cancelled = True
            ticket.done()


async def chat_streamer(ticket: Ticket, timeout, created_ts):
seen: Dict[int, str] = dict()
try:
Expand Down Expand Up @@ -240,7 +207,6 @@ def __init__(self,
# API for direct FIM and Chat usage
self.add_api_route("/v1/login", self._login, methods=["GET"])
self.add_api_route("/v1/secret-key-activate", self._secret_key_activate, methods=["GET"])
self.add_api_route("/v1/contrast", self._contrast, methods=["POST"])
self.add_api_route("/v1/chat", self._chat, methods=["POST"])

# API for LSP server
Expand Down Expand Up @@ -275,10 +241,11 @@ async def _coding_assistant_caps(self):
or model_name in litellm.model_list:
code_chat_default_model = model_name
break
return {
config_mtime = self._model_assigner.config_inference_mtime()
data = {
"cloud_name": "Refact Self-Hosted",
"endpoint_template": "v1/completions",
"endpoint_chat_passthrough": "v1/chat/completions",
"endpoint_template": "/v1/completions",
"endpoint_chat_passthrough": "/v1/chat/completions",
"endpoint_style": "openai",
"telemetry_basic_dest": "/stats/telemetry-basic",
"telemetry_corrected_snippets_dest": "/stats/telemetry-snippets",
Expand All @@ -291,7 +258,9 @@ async def _coding_assistant_caps(self):
for model in models_available
if model in self._model_assigner.models_db
},
"caps_version": config_mtime,
}
return Response(content=json.dumps(data, indent=4), media_type="application/json")

async def _login(self):
longthink_functions = dict()
Expand Down Expand Up @@ -348,10 +317,12 @@ async def _secret_key_activate(self):
async def _completions(self, post: NlpCompletion, account: str = "user"):
ticket = Ticket("comp-")
req = post.clamp()
caps_version = self._model_assigner.config_inference_mtime() # use mtime as a version, if that changes the client will know to refresh caps
model_name, err_msg = static_resolve_model(post.model, self._inference_queue)
if err_msg:
log("%s model resolve \"%s\" -> error \"%s\" from %s" % (ticket.id(), post.model, err_msg, account))
raise HTTPException(status_code=400, detail=err_msg)
return Response(status_code=400, content=json.dumps({"detail": err_msg, "caps_version": caps_version}, indent=4), media_type="application/json")

log("%s model resolve \"%s\" -> \"%s\" from %s" % (ticket.id(), post.model, model_name, account))
req.update({
"object": "text_completion_req",
Expand All @@ -367,57 +338,10 @@ async def _completions(self, post: NlpCompletion, account: str = "user"):
await q.put(ticket)
seen = [""] * post.n
return StreamingResponse(
completion_streamer(ticket, post, self._timeout, seen, req["created"]),
_completion_streamer(ticket, post, self._timeout, seen, req["created"], caps_version=caps_version),
media_type=("text/event-stream" if post.stream else "application/json"),
)

async def _contrast(self, post: DiffCompletion, request: Request, account: str = "user"):
    """Handle the /v1/contrast endpoint: diff-style completion over sources.

    Validates cursor/file inputs, resolves a model, enqueues a Ticket on the
    inference queue, and streams results via diff_streamer.

    Raises HTTPException(400) on invalid cursors, oversized files, or model
    resolution failure.
    """
    if post.function != "diff-anywhere":
        # Every function except "diff-anywhere" needs a valid cursor range
        # inside one of the supplied source files.
        if post.cursor_file not in post.sources:
            raise HTTPException(status_code=400, detail="cursor_file='%s' is not in sources=%s" % (post.cursor_file, list(post.sources.keys())))
        if post.cursor0 < 0 or post.cursor1 < 0:
            raise HTTPException(status_code=400, detail="cursor0=%d or cursor1=%d is negative" % (post.cursor0, post.cursor1))
        filetext = post.sources[post.cursor_file]
        if post.cursor0 > len(filetext) or post.cursor1 > len(filetext):
            raise HTTPException(status_code=400, detail="cursor0=%d or cursor1=%d is beyond file length=%d" % (post.cursor0, post.cursor1, len(filetext)))
    # Per-file size cap: 180 KiB.
    for fn, text in post.sources.items():
        if len(text) > 180*1024:
            raise HTTPException(status_code=400, detail="file '%s' is too long (%d bytes)" % (fn, len(text)))
    ticket = Ticket("comp-")
    # "infill" picks whatever completion model is loaded; other functions
    # resolve the model named in the request.
    if post.function == "infill":
        model_name, err_msg = completion_resolve_model(self._inference_queue)
    else:
        model_name, err_msg = static_resolve_model(post.model, self._inference_queue)
    if err_msg:
        log("%s model resolve \"%s\" func \"%s\" -> error \"%s\" from %s" % (ticket.id(), post.model, post.function, err_msg, account))
        raise HTTPException(status_code=400, detail=err_msg)
    log("%s model resolve \"%s\" func \"%s\" -> \"%s\" from %s" % (ticket.id(), post.model, post.function, model_name, account))
    if post.function == "highlight":
        # Highlighting needs no generated text, only scoring.
        post.max_tokens = 0
    req = post.clamp()
    req.update({
        "object": "diff_completion_req",
        "account": account,
        "model": model_name,
        "intent": post.intent,
        "sources": post.sources,
        "cursor_file": post.cursor_file,
        "cursor0": post.cursor0,
        "cursor1": post.cursor1,
        "function": post.function,
        "max_edits": post.max_edits,
        "stream": post.stream,
    })
    # Pass through the optional "poi" (points of interest) field from the
    # raw request body -- it is not part of the DiffCompletion schema.
    post_raw = await request.json()
    if "poi" in post_raw:
        req["poi"] = post_raw["poi"]
    ticket.call.update(req)
    q = self._inference_queue.model_name_to_queue(ticket, model_name)
    # kt, kcomp = await _model_hit(red, ticket, req, model_name, account)
    self._id2ticket[ticket.id()] = ticket
    await q.put(ticket)
    return StreamingResponse(diff_streamer(ticket, post, self._timeout, req["created"]))

async def _chat(self, post: ChatContext, request: Request, account: str = "user"):
ticket = Ticket("comp-")

Expand Down
2 changes: 1 addition & 1 deletion self_hosting_machinery/webgui/selfhost_fastapi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fastapi import APIRouter, Query, Request, Header, HTTPException

from self_hosting_machinery.webgui.selfhost_req_queue import Ticket
from self_hosting_machinery.webgui.selfhost_queue import Ticket
from self_hosting_machinery.webgui.selfhost_webutils import log
from self_hosting_machinery.webgui.selfhost_queue import InferenceQueue

Expand Down
11 changes: 7 additions & 4 deletions self_hosting_machinery/webgui/selfhost_model_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,10 @@ def model_assignment(self):
}
return j

@staticmethod
def restart_lsp():
    # Touch the restart flag file; per watchdog.d/lsp.cfg
    # ("interrupt_when_file_appears") the watchdog reacts to this file
    # appearing.  NOTE(review): removed in this PR together with
    # FLAG_RESTART_LSP -- confirm no remaining callers.
    with open(env.FLAG_RESTART_LSP, "w") as f:
        f.write("")
def config_inference_mtime(self) -> int:
    """Return the mtime of the inference config file as an integer.

    The value serves as a cheap monotone "caps version": callers compare it
    to detect that the inference configuration changed.

    Returns:
        The file's modification time truncated to an int, or 0 when the
        file is missing or unreadable.
    """
    try:
        return int(os.path.getmtime(env.CONFIG_INFERENCE))
    except OSError:
        # FileNotFoundError is an OSError, so a missing file lands here too;
        # a single EAFP branch avoids the exists()/getmtime() TOCTOU race.
        return 0
2 changes: 1 addition & 1 deletion self_hosting_machinery/webgui/selfhost_model_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def _family(model: str) -> str:
if not model_name or _family(model_name) == _family(have_model):
return have_model, ""
else:
return "", f"model is not loaded (3)"
return "", f"model \"{model_name}\" is not loaded (3)"
20 changes: 19 additions & 1 deletion self_hosting_machinery/webgui/selfhost_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,25 @@
from self_hosting_machinery import env
from self_hosting_machinery.webgui.selfhost_webutils import log
from fastapi import HTTPException
from typing import Dict, List
from typing import Dict, List, Any
import uuid


class Ticket:
    """One inference request tracked by the web server.

    Carries the request payload in ``call`` (keyed by a unique ``"id"``),
    plus cancellation state and the queue used to stream results back.
    """

    def __init__(self, id_prefix):
        # Unique ticket id: caller-supplied prefix + 12 hex chars of a UUID4.
        suffix = uuid.uuid4().hex[:12]
        self.call: Dict[str, Any] = {"id": id_prefix + suffix}
        self.cancelled: bool = False
        self.processed_by_infmod_guid: str = ""
        self.streaming_queue = asyncio.queues.Queue()

    def id(self):
        # None after done() has been called.
        return self.call.get("id")

    def done(self):
        # Dropping the id marks the ticket finished; safe to call twice.
        self.call.pop("id", None)


class InferenceQueue:
Expand Down
Loading
Loading