Skip to content

Commit

Permalink
Fix derivative checkers
Browse files Browse the repository at this point in the history
  • Loading branch information
cgokmen committed Dec 12, 2024
1 parent 281f89b commit b37c0f3
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions bddl/knowledge_base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,38 +415,28 @@ def is_derivative(self):
return any(self.name.startswith(dw) for dw in derivative_words)

@cached_property
def derivative_parent(self):
def derivative_parents(self):
# Check if the synset is a derivative
if not self.is_derivative:
return None

# Find the synset that has this as a child
parent_candidates = [s for s in Synset.all_objects() if self in s.derivative_children]
assert len(parent_candidates) == 1, f"Expected 1 parent, got {parent_candidates}"
return parent_candidates[0]

@cached_property
def derivative_root(self):
if not self.derivative_parent:
return None

parent_root = self.derivative_parent.derivative_root
return parent_root if parent_root else self.derivative_parent
return {s for s in Synset.all_objects() if self in s.derivative_children}

@cached_property
def derivative_children_names(self):
sliceable_children = [
Synset.get(json.loads(p.parameters)["sliceable_derivative_synset"])
json.loads(p.parameters)["sliceable_derivative_synset"]
for p in self.properties
if p.name == "sliceable"
]
diceable_children = [
Synset.get(json.loads(p.parameters)["uncooked_diceable_derivative_synset"])
json.loads(p.parameters)["uncooked_diceable_derivative_synset"]
for p in self.properties
if p.name == "diceable"
]
cookable_children = [
Synset.get(json.loads(p.parameters)["substance_cooking_derivative_synset"])
json.loads(p.parameters)["substance_cooking_derivative_synset"]
for p in self.properties
if p.name == "cookable" and self.state == STATE_SUBSTANCE
]
Expand All @@ -458,11 +448,13 @@ def derivative_children(self):

@cached_property
def derivative_ancestors(self):
if not self.derivative_parent:
if not self.derivative_parents:
return {self}
return {self, self.derivative_parent} | set(
self.derivative_parent.derivative_ancestors
)
return {self} | self.derivative_parents | {
ancestor
for parent in self.derivative_parents
for ancestor in parent.derivative_ancestors
}

@cached_property
def derivative_descendants(self):
Expand Down Expand Up @@ -522,7 +514,7 @@ def view_unnecessary(cls):
def view_bad_derivative(cls):
"""Derivative synsets that exist even though the original synset is missing the expected property"""
return [
s for s in cls.all_objects() if s.is_derivative and not s.derivative_parent
s for s in cls.all_objects() if s.is_derivative and not s.derivative_parents
]

@classmethod
Expand Down

0 comments on commit b37c0f3

Please sign in to comment.