Commit 8af45c1

[Bug] [Feature] Fix issue #1039 by exposing And, Not regex operators from `llguidance` (#1043)

Fixes issue #1039 by ensuring that the keys for JSON properties
validated against `additionalProperties` aren't allowed to match any of
the keys for the properties themselves.

It does so by exposing and using the `And` and `Not` regex primitives from `llguidance` (see the sketch below).
hudson-ai authored Oct 7, 2024
1 parent af27044 commit 8af45c1
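
Restated outside the diff, the core of the change builds the key grammar for `additionalProperties` roughly as follows (condensed from the new code in `guidance/library/_json.py`; the `names` set here is an illustrative stand-in for the schema's declared and required property names):

    from guidance._grammar import And, Not, quote_regex
    from guidance.library._subgrammar import as_regular_grammar, lexeme

    names = {"a", "b"}  # illustrative: the schema's declared/required property names

    # A valid additionalProperties key is any JSON string AND NOT one of the
    # declared names, expressed with the newly exposed And/Not regex nodes.
    key_grammar = as_regular_grammar(
        And([
            lexeme(r'"([^"\\]|\\["\\/bfnrt]|\\u[0-9a-fA-F]{4})*"'),  # any JSON string
            Not(lexeme('"(' + '|'.join(quote_regex(name) for name in names) + ')"')),
        ]),
        lexeme=True,
    )

The new unit test below exercises exactly this: with properties `a` and `b` and only `b` required, an out-of-order key `"a"` must be validated against its declared schema rather than slipping through as an additional property.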
Showing 6 changed files with 113 additions and 36 deletions.
103 changes: 73 additions & 30 deletions guidance/_grammar.py
@@ -237,14 +237,11 @@ class Terminal(GrammarFunction):
     def __init__(self, *, temperature: float, capture_name: Union[str, None]):
         super().__init__(capture_name=capture_name)
         self.temperature = temperature
+        self.max_tokens = 1000000000000
 
     def match_byte(self, byte):
         pass # abstract
 
-    @property
-    def max_tokens(self):
-        return 1000000000000
-
 class DeferredReference(Terminal):
     """Container to hold a value that is resolved at a later time. This is useful for recursive definitions."""
     __slots__ = "_value"
@@ -496,7 +493,6 @@ class Gen(Terminal):
         "stop_regex",
         "save_stop_text",
         "name",
-        "_max_tokens",
     )
 
     def __init__(
@@ -512,11 +508,7 @@ def __init__(
         self.stop_regex = stop_regex
         self.name = name if name is not None else GrammarFunction._new_name()
         self.save_stop_text = save_stop_text
-        self._max_tokens = max_tokens
-
-    @property
-    def max_tokens(self) -> int:
-        return self._max_tokens
+        self.max_tokens = max_tokens
 
     def __repr__(self, indent="", done=None, lbl="Gen"):
         if done is None:
@@ -560,22 +552,50 @@ def __repr__(self, indent="", done=None):
         return super().__repr__(indent, done, "Lex")
 
 
-class RegularGrammar(Gen):
-    __slots__ = ("grammar",)
+class RegularGrammar(Terminal):
+    __slots__ = ("grammar", "name")
 
     def __init__(
         self,
         grammar: GrammarFunction,
+        lexeme: bool = False,
         name: Union[str, None] = None,
         max_tokens=100000000,
     ) -> None:
-        super().__init__("", "", name=name, max_tokens=max_tokens)
+        super().__init__(capture_name=None, temperature=-1)
         self.grammar = grammar
+        self.lexeme = lexeme
+        self.name = name if name is not None else GrammarFunction._new_name()
+        self.max_tokens = max_tokens
 
     def __repr__(self, indent="", done=None):
         # TODO add grammar repr
         return super().__repr__(indent, done, "RegularGrammar")
 
+class And(Terminal):
+    __slots__ = ("values", "name")
+
+    def __init__(
+        self,
+        values: Sequence[GrammarFunction],
+        name: Union[str, None] = None,
+    ):
+        super().__init__(temperature=-1, capture_name=None)
+        self.values = list(values)
+        self.name = name if name is not None else GrammarFunction._new_name()
+
+class Not(Terminal):
+    __slots__ = ("value", "name")
+
+    def __init__(
+        self,
+        value: GrammarFunction,
+        name: Union[str, None] = None,
+    ):
+        super().__init__(temperature=-1, capture_name=None)
+        self.value = value
+        self.name = name if name is not None else GrammarFunction._new_name()
+
 
 class Subgrammar(Gen):
     __slots__ = (
@@ -874,15 +894,6 @@ def _is_string_literal(node: ComposableGrammar) -> bool:
     return False
 
 
-def as_regular_grammar(value) -> RegularGrammar:
-    # TODO: assert that value is not empty since we don't yet support that
-    if isinstance(value, str):
-        value = string(value)
-    # check if it serializes
-    _ignore = LLSerializer().regex(value)
-    return RegularGrammar(value)
-
-
 class LLSerializer:
     def __init__(self) -> None:
         self.nodes: list[dict] = []
@@ -906,12 +917,21 @@ def _add_regex_json(self, json):
     def _add_regex(self, key: str, val):
         return self._add_regex_json({key: val})
 
-    def _regex_or(self, nodes: list[GrammarFunction]):
+    def _regex_or(self, nodes: Sequence[GrammarFunction]):
         if len(nodes) == 1:
             return self.regex_id_cache[nodes[0]]
         else:
             return self._add_regex("Or", [self.regex_id_cache[v] for v in nodes])
 
+    def _regex_and(self, nodes: Sequence[GrammarFunction]):
+        if len(nodes) == 1:
+            return self.regex_id_cache[nodes[0]]
+        else:
+            return self._add_regex("And", [self.regex_id_cache[v] for v in nodes])
+
+    def _regex_not(self, node: GrammarFunction):
+        return self._add_regex("Not", self.regex_id_cache[node])
+
     def regex(self, node: GrammarFunction):
         """
         Serialize node as regex. Throws if impossible.
@@ -1019,6 +1039,20 @@ def check_unserializable_attrs(node: GrammarFunction):
                 if node.json_string:
                     raise ValueError("Cannot serialize lexeme with `json_string=True` as regex: " + node.__repr__())
                 res = self._add_regex("Regex", node.body_regex)
+            elif isinstance(node, And):
+                if not all_finished(node.values):
+                    add_todo(node)
+                    pending.add(node)
+                    add_todos(node.values)
+                    continue
+                res = self._regex_and(node.values)
+            elif isinstance(node, Not):
+                if not node_finished(node.value):
+                    add_todo(node)
+                    pending.add(node)
+                    add_todo(node.value)
+                    continue
+                res = self._regex_not(node.value)
             else:
                 raise ValueError("Cannot serialize as regex: " + node.__repr__())
             if node in pending:
@@ -1092,14 +1126,23 @@ def process(self, node: GrammarFunction):
                 }
             }
         elif isinstance(node, RegularGrammar):
-            obj = {
-                "Gen": {
-                    "body_rx": self.regex(node.grammar),
-                    "stop_rx": "",
-                    "lazy": False, # TODO this should be True
-                    "temperature": node.temperature if node.temperature >= 0 else None,
-                }
-            }
+            if node.lexeme:
+                obj = {
+                    "Lexeme" : {
+                        "rx": self.regex(node.grammar),
+                        "contextual": False,
+                        "json_string": False,
+                    }
+                }
+            else:
+                obj = {
+                    "Gen": {
+                        "body_rx": self.regex(node.grammar),
+                        "stop_rx": "",
+                        "lazy": False, # TODO this should be True
+                        "temperature": node.temperature if node.temperature >= 0 else None,
+                    }
+                }
         elif isinstance(node, Gen):
             obj = {
                 "Gen": {
20 changes: 16 additions & 4 deletions guidance/library/_json.py
@@ -26,9 +26,9 @@
 from ..library import char_range, gen, one_or_more, optional, sequence
 from ..library._regex_utils import rx_int_range, rx_float_range
 
-from .._grammar import GrammarFunction, select, capture, with_temperature
+from .._grammar import GrammarFunction, select, capture, with_temperature, Not, And, quote_regex
 from ._pydantic import pydantic_to_json_schema
-from ._subgrammar import lexeme, subgrammar
+from ._subgrammar import as_regular_grammar, lexeme, subgrammar
 
 JSONSchema = Union[bool, Mapping[str, Any]]
 
@@ -494,9 +494,21 @@ def _gen_json_object(
     ]
     grammars = tuple(f'"{name}":' + _gen_json(json_schema=schema, definitions=definitions) for name, schema in items)
     required_items = tuple(name in required for name, _ in items)
 
+    names = set(properties.keys()) | set(required)
+    key_grammar: GrammarFunction
+    if len(names) > 0:
+        # If there are any properties, we need to disallow them as additionalProperties
+        key_grammar = as_regular_grammar(
+            And([
+                lexeme(r'"([^"\\]|\\["\\/bfnrt]|\\u[0-9a-fA-F]{4})*"'),
+                Not(lexeme('"(' + '|'.join(quote_regex(name) for name in names) + ')"')),
+            ]),
+            lexeme = True,
+        )
+    else:
+        key_grammar = _gen_json_string()
     if additional_properties is not False:
-        additional_item_grammar = _gen_json_string() + ':' + _gen_json(json_schema=additional_properties, definitions=definitions)
+        additional_item_grammar = key_grammar + ':' + _gen_json(json_schema=additional_properties, definitions=definitions)
         additional_items_grammar = sequence(additional_item_grammar + ',') + additional_item_grammar
         grammars += (additional_items_grammar,)
         required_items += (False,)
10 changes: 10 additions & 0 deletions guidance/library/_subgrammar.py
@@ -1,3 +1,4 @@
+from guidance._grammar import LLSerializer, RegularGrammar, string
 from .._grammar import Subgrammar, Lexeme, GrammarFunction, capture
 from typing import Optional
 
@@ -40,3 +41,12 @@ def subgrammar(
     if name:
         r = capture(r, name)
     return r
+
+
+def as_regular_grammar(value, lexeme=False) -> RegularGrammar:
+    # TODO: assert that value is not empty since we don't yet support that
+    if isinstance(value, str):
+        value = string(value)
+    # check if it serializes
+    _ignore = LLSerializer().regex(value)
+    return RegularGrammar(value, lexeme=lexeme)
4 changes: 3 additions & 1 deletion guidance/library/_substring.py
@@ -1,9 +1,11 @@
 from typing import Optional, Dict, Union
 
+from ._subgrammar import as_regular_grammar
+
 from .._guidance import guidance
 
 # from ._prefix_tree import prefix_tree
-from .._grammar import string, select, capture, as_regular_grammar, Terminal, GrammarFunction
+from .._grammar import string, select, capture, Terminal, GrammarFunction
 from ._optional import optional
 
 
10 changes: 10 additions & 0 deletions tests/unit/library/test_json.py
@@ -1882,6 +1882,16 @@ def test_combined_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_byte
             compact=compact,
         )
 
+    def test_out_of_order_non_required_properties_not_validated_as_additionalProperties(self):
+        schema = {
+            "type": "object",
+            "properties": {"a": {"const": "foo"}, "b": {"const": "bar"}},
+            "required": ["b"],
+        }
+        test_string = '{"b": "bar", "a": "BAD"}'
+        grammar = gen_json(schema=schema)
+        assert grammar.match(test_string) is None
+
 
 class TestRecursiveStructures:
     @pytest.mark.parametrize(
2 changes: 1 addition & 1 deletion tests/unit/test_ll.py
@@ -15,7 +15,7 @@
     string,
     capture,
 )
-from guidance._grammar import as_regular_grammar
+from guidance.library._subgrammar import as_regular_grammar
 from guidance.library._subgrammar import subgrammar, lexeme
 
 log_level = 10
