Skip to content

Commit

Permalink
Remove variable_values()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 24, 2024
1 parent a5008a5 commit 3fbcead
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 13 deletions.
9 changes: 0 additions & 9 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,15 +809,6 @@ def value_attributes(obj) -> Tuple[str, ...]:
raise ValueError(f"{type(obj).__name__} must implement '__value_attrs__()' or be a dataclass to be used with value functions.")


def variable_values(obj) -> Tuple[str, ...]:
if hasattr(obj, '__variable_attrs__'):
values = obj.__value_attrs__()
variables = obj.__variable_attrs__()
return tuple([a for a in values if a in variables])
else:
return obj.__value_attrs__() # this takes care of dataclasses as well


def all_attributes(obj, assert_any=False) -> Tuple[str, ...]:
if hasattr(obj, '__all_attrs__'):
result = obj.__all_attrs__()
Expand Down
8 changes: 4 additions & 4 deletions phiml/math/_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import _ops as math
from . import extrapolation as extrapolation
from ._magic_ops import stack, rename_dims, concat, variable_values, tree_map
from ._magic_ops import stack, rename_dims, concat, tree_map, value_attributes
from ._ops import choose_backend_t, reshaped_native, reshaped_tensor
from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, instance, dual, auto, non_batch
from ._tensors import Tensor, wrap, tensor, reshaped_numpy
Expand Down Expand Up @@ -251,7 +251,7 @@ def l1_loss(x, reduce: DimFilter = math.non_batch) -> Tensor:
if isinstance(x, Tensor):
return math.sum_(abs(x), reduce)
elif isinstance(x, PhiTreeNode):
return sum([l1_loss(getattr(x, a), reduce) for a in variable_values(x)])
return sum([l1_loss(getattr(x, a), reduce) for a in value_attributes(x)])
else:
try:
backend = math.choose_backend(x)
Expand Down Expand Up @@ -283,7 +283,7 @@ def l2_loss(x, reduce: DimFilter = math.non_batch) -> Tensor:
x = abs(x)
return math.sum_(x ** 2, reduce) * 0.5
elif isinstance(x, PhiTreeNode):
return sum([l2_loss(getattr(x, a), reduce) for a in variable_values(x)])
return sum([l2_loss(getattr(x, a), reduce) for a in value_attributes(x)])
else:
try:
backend = math.choose_backend(x)
Expand Down Expand Up @@ -329,7 +329,7 @@ def frequency_loss(x,
diff_fft = math.sqrt(math.maximum(diff_fft, threshold))
return l2_loss(diff_fft) if n == 2 else l1_loss(diff_fft)
elif isinstance(x, PhiTreeNode):
losses = [frequency_loss(getattr(x, a), frequency_falloff, threshold, ignore_mean, n) for a in variable_values(x)]
losses = [frequency_loss(getattr(x, a), frequency_falloff, threshold, ignore_mean, n) for a in value_attributes(x)]
return sum(losses)
else:
raise ValueError(x)
Expand Down

0 comments on commit 3fbcead

Please sign in to comment.