
Commit

Fix eval_time calculation
armbues committed Sep 30, 2024
1 parent 9c750c6 commit 5b9330d
Showing 1 changed file with 9 additions and 7 deletions.
sillm/core/llm.py: 9 additions & 7 deletions

@@ -2,6 +2,7 @@
 import time
 import pathlib
 import json
+import typing
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -14,6 +15,7 @@
 from sillm.training.dataset import Dataset
 from sillm.core.cache import KVCache, PromptCache
 from sillm.modules.switch import SwitchLinear
+from sillm.experimental.logit_filter import LogitFilter
 
 logger = logging.getLogger("sillm")
 
@@ -416,7 +418,7 @@ def generate(model,
              flush: int = 5,
              extra_stop_tokens: list = None,
              prompt_cache: PromptCache = None,
-             logit_mask: mx.array = None
+             logit_filter: LogitFilter = None
              ):
     start = time.perf_counter()
 
@@ -467,12 +469,12 @@ def sample(logits):
         # Apply temperature
         logits = logits * (1 / temperature)
 
-        # Apply logit mask
-        if logit_mask is not None:
-            logits = logits * logit_mask
+        # Apply structure enforcer
+        if logit_filter is not None:
+            logits = logit_filter(logits)
         # Apply repetition penalty
         if len(tokens) > 0 and repetition_penalty is not None:
-            logits = sampling.apply_repetition_penalty(logits, tokens)
+            logits = sampling.apply_repetition_penalty(logits, tokens, repetition_penalty=repetition_penalty, repetition_window=repetition_window)
         # Apply top-k sampling
         if top_k > 0:
             logits = sampling.top_k(logits, k=top_k)
@@ -520,7 +522,7 @@ def generate_step(model, inputs):
     # Main generation loop
     for (token,p), i in zip(generate_step(model, inputs), range(max_tokens)):
         if i == 0:
-            mx.async_eval(token)
+            mx.eval(token)
             timing["eval_time"] = time.perf_counter() - start
 
         if token.item() in stop_tokens:
@@ -533,7 +535,7 @@ def generate_step(model, inputs):
         metadata["logprobs"].append(p)
 
         if (len(tokens) % flush) == 0:
-            mx.async_eval(tokens)
+            mx.eval(tokens)
 
         text_offset = len(text)
         text = tokenizer.decode(tokens)
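A note on the headline fix: MLX builds its computation graph lazily, and mx.async_eval dispatches evaluation without blocking, so the old code recorded timing["eval_time"] before the first token had actually been computed and under-reported the prompt evaluation time. mx.eval blocks until the array is materialized, so the elapsed time now covers the real work. A minimal sketch of the difference, with an illustrative matmul workload that is not from the repository:

import time
import mlx.core as mx

a = mx.random.normal((4096, 4096))
b = mx.random.normal((4096, 4096))
mx.eval(a, b)  # materialize the inputs so both timings start from the same state

start = time.perf_counter()
mx.async_eval(a @ a)  # dispatches the matmul and returns immediately
print(f"async_eval: {time.perf_counter() - start:.6f}s (dispatch only)")

start = time.perf_counter()
mx.eval(b @ b)  # blocks until the matmul has actually run
print(f"eval:       {time.perf_counter() - start:.6f}s (includes compute)")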

0 comments on commit 5b9330d
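The repetition-penalty change fixes a second bug: the old call never forwarded repetition_penalty or repetition_window, so the function fell back to its own defaults and the values passed into generate were silently ignored. For reference, a generic CTRL-style penalty over a window of recent tokens looks roughly like the sketch below; it is illustrative, not sillm's sampling.apply_repetition_penalty, and the default values are made up:

import mlx.core as mx

def repetition_penalty_sketch(logits: mx.array,
                              tokens: list,
                              repetition_penalty: float = 1.1,
                              repetition_window: int = 25) -> mx.array:
    # Only the most recent `repetition_window` tokens are penalized
    indices = mx.array(tokens[-repetition_window:])
    selected = logits[:, indices]
    # CTRL-style rule: divide positive logits by the penalty, multiply
    # negative ones, so a repeated token always loses probability mass
    selected = mx.where(selected < 0,
                        selected * repetition_penalty,
                        selected / repetition_penalty)
    logits[:, indices] = selected
    return logits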

Please sign in to comment.
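Finally, the signature change from logit_mask: mx.array to logit_filter: LogitFilter replaces a multiplicative mask with a callable. Judging only from the call site (logits = logit_filter(logits)), any object mapping a logits array to a filtered logits array fits the interface. A hypothetical example of such a filter, not taken from sillm.experimental.logit_filter: an allow-list built on an additive bias, where adding -inf removes a token from the softmax entirely, whereas the old multiplicative mask moved a masked logit to 0, which still carries nonzero probability:

import mlx.core as mx

class AllowListFilter:
    # Hypothetical filter restricting sampling to a fixed set of token ids
    def __init__(self, allowed_ids: list, vocab_size: int):
        ids = mx.arange(vocab_size)
        allowed = mx.array(allowed_ids)
        # Boolean membership mask: True where a vocab id is in the allow-list
        is_allowed = (ids[:, None] == allowed[None, :]).any(axis=-1)
        # A -inf bias zeroes a token's post-softmax probability outright
        self.bias = mx.where(is_allowed, 0.0, float("-inf"))

    def __call__(self, logits: mx.array) -> mx.array:
        # Matches the call site in sample(): logits = logit_filter(logits)
        return logits + self.bias

Passing an instance as generate(..., logit_filter=AllowListFilter(digit_ids, vocab_size)) would then constrain sampling to the listed ids.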