Skip to content

Commit

Permalink
Resolve indirect cached_property dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 9, 2024
1 parent 21534a8 commit 79324a7
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions phiml/dataclasses/_dep.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import dataclasses
import inspect
from functools import cached_property
from typing import Set
Expand All @@ -21,7 +22,10 @@ def get_dependencies(cls: type, cls_property) -> Set[str]:
return cls.__phiml_dep__[cls_property]
if isinstance(cls_property, cached_property):
method = cls_property.func
elif isinstance(cls_property, property):
method = cls_property.fget
else:
assert callable(cls_property) and hasattr(cls_property, '__qualname__'), f"Dependency resolver failed on {cls_property} of {cls.__name__}"
method = cls_property
ML_LOGGER.debug(f"Analyzing dependencies of {method.__qualname__}")
source_code = inspect.getsource(method)
Expand All @@ -30,12 +34,25 @@ def get_dependencies(cls: type, cls_property) -> Set[str]:
tree = ast.parse(source_code_top)
analyzer = MemberVariableAnalyzer()
analyzer.visit(tree)
result = analyzer.member_vars
direct_deps = analyzer.member_vars
fields = set([f.name for f in dataclasses.fields(cls)])
field_deps = direct_deps & fields
for prop_dep in direct_deps - fields:
if not hasattr(cls, prop_dep): # may be a dynamic dim, such as vector
if hasattr(cls, 'shape'):
prop_dep = 'shape'
else:
field_deps.update(fields) # automatic shape() depends on all attributes
continue
if isinstance(getattr(cls, prop_dep), (property, cached_property)):
field_deps.update(get_dependencies(cls, getattr(cls, prop_dep)))
elif callable(getattr(cls, prop_dep)):
raise NotImplementedError
if not hasattr(cls, '__phiml_dep__'):
cls.__phiml_dep__ = {cls_property: result}
cls.__phiml_dep__ = {cls_property: field_deps}
else:
cls.__phiml_dep__[cls_property] = result
return result
cls.__phiml_dep__[cls_property] = field_deps
return field_deps


class MemberVariableAnalyzer(ast.NodeVisitor):
Expand Down

0 comments on commit 79324a7

Please sign in to comment.