Commit

Refactor logit filter
armbues committed Sep 30, 2024
1 parent 5b9330d commit a99e0f0
Showing 2 changed files with 50 additions and 5 deletions.
12 changes: 7 additions & 5 deletions sillm/chat.py
@@ -4,7 +4,6 @@
 
 import sillm
 import sillm.utils as utils
-import sillm.experimental.structure as structure
 
 if __name__ == "__main__":
     # Parse commandline arguments
@@ -77,6 +76,11 @@
     if args.cache > 0:
         prompt_cache = sillm.PromptCache(max_size=args.cache)
 
+    # Initialize logit filters
+    logit_filter = None
+    if args.ascii:
+        logit_filter = sillm.experimental.logit_filter.ASCIIFilter(model.tokenizer, model.args.vocab_size)
+
     generate_args = {
         "temperature": args.temperature,
         "top_k": args.top_k,
@@ -85,12 +89,10 @@
         "repetition_window": args.repetition_window,
         "max_tokens": args.max_tokens,
         "flush": args.flush,
-        "prompt_cache": prompt_cache
+        "prompt_cache": prompt_cache,
+        "logit_filter": logit_filter
     }
 
-    if args.ascii:
-        generate_args["logit_mask"] = structure.ascii_token_logit_mask(model.tokenizer, model.args.vocab_size)
-
     # Init conversation template
     template = sillm.init_template(model.tokenizer, model.args, args.template)
     conversation = sillm.Conversation(template, system_prompt=args.system_prompt)
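For context, here is a minimal sketch of how a logit filter passed through generate_args could be applied inside a sampling step. The helper name sample_next_token, its signature, and the greedy/temperature fallback are illustrative assumptions for this page, not sillm's actual generation API:

    import mlx.core as mx

    def sample_next_token(logits: mx.array, logit_filter=None, temperature: float = 0.0) -> mx.array:
        # Hypothetical helper: run the optional logit filter before sampling.
        if logit_filter is not None:
            logits = logit_filter(logits)

        if temperature <= 0.0:
            # Greedy decoding: pick the token with the highest (filtered) logit.
            return mx.argmax(logits, axis=-1)

        # Temperature sampling over the filtered logits.
        return mx.random.categorical(logits / temperature)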
43 changes: 43 additions & 0 deletions sillm/experimental/logit_filter.py
@@ -0,0 +1,43 @@
import string
import re

import mlx.core as mx

class LogitFilter:
    def __init__(self,
                 tokenizer,
                 output_size: int
                 ):
        self.tokenizer = tokenizer
        self.output_size = output_size

    def reset(self):
        pass

    def __call__(self,
                 logits: mx.array
                 ) -> mx.array:
        raise NotImplementedError("Class LogitFilter is used for inheritance only")

class ASCIIFilter(LogitFilter):
    """
    Static logit mask filtering out tokens with non-ASCII printable characters.
    """
    def __init__(self,
                 tokenizer,
                 output_size: int
                 ):
        mask = mx.zeros(output_size)
        for i, s in enumerate(tokenizer.vocab_strings):
            if all(c in string.printable for c in s.strip()):
                mask[i] = 1.0

        for i in tokenizer.special_ids:
            mask[i] = 1.0

        self.mask = mask

    def __call__(self,
                 logits: mx.array
                 ) -> mx.array:
        return logits * self.mask
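A brief usage sketch of the new ASCIIFilter outside of chat.py. The dummy tokenizer below is a stand-in for a real sillm tokenizer, which the filter assumes exposes vocab_strings and special_ids:

    import mlx.core as mx
    from sillm.experimental.logit_filter import ASCIIFilter

    class DummyTokenizer:
        # Stand-in exposing the two attributes ASCIIFilter reads.
        vocab_strings = ["hello", "wörld", "<eos>"]
        special_ids = [2]

    tokenizer = DummyTokenizer()
    ascii_filter = ASCIIFilter(tokenizer, output_size=len(tokenizer.vocab_strings))

    logits = mx.array([1.5, 2.0, 0.5])
    print(ascii_filter(logits))  # the logit for "wörld" is zeroed by the mask

Note that the filter multiplies by a 0/1 mask, so filtered logits become 0 rather than -inf; how strongly those tokens are suppressed therefore depends on the magnitude of the surviving logits.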
