diff --git a/phiml/math/_magic_ops.py b/phiml/math/_magic_ops.py index f47d655..9adf94f 100644 --- a/phiml/math/_magic_ops.py +++ b/phiml/math/_magic_ops.py @@ -809,15 +809,6 @@ def value_attributes(obj) -> Tuple[str, ...]: raise ValueError(f"{type(obj).__name__} must implement '__value_attrs__()' or be a dataclass to be used with value functions.") -def variable_values(obj) -> Tuple[str, ...]: - if hasattr(obj, '__variable_attrs__'): - values = obj.__value_attrs__() - variables = obj.__variable_attrs__() - return tuple([a for a in values if a in variables]) - else: - return obj.__value_attrs__() # this takes care of dataclasses as well - - def all_attributes(obj, assert_any=False) -> Tuple[str, ...]: if hasattr(obj, '__all_attrs__'): result = obj.__all_attrs__() diff --git a/phiml/math/_nd.py b/phiml/math/_nd.py index 5a5bc96..8536e2f 100644 --- a/phiml/math/_nd.py +++ b/phiml/math/_nd.py @@ -4,7 +4,7 @@ from . import _ops as math from . import extrapolation as extrapolation -from ._magic_ops import stack, rename_dims, concat, variable_values, tree_map +from ._magic_ops import stack, rename_dims, concat, tree_map, value_attributes from ._ops import choose_backend_t, reshaped_native, reshaped_tensor from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, instance, dual, auto, non_batch from ._tensors import Tensor, wrap, tensor, reshaped_numpy @@ -251,7 +251,7 @@ def l1_loss(x, reduce: DimFilter = math.non_batch) -> Tensor: if isinstance(x, Tensor): return math.sum_(abs(x), reduce) elif isinstance(x, PhiTreeNode): - return sum([l1_loss(getattr(x, a), reduce) for a in variable_values(x)]) + return sum([l1_loss(getattr(x, a), reduce) for a in value_attributes(x)]) else: try: backend = math.choose_backend(x) @@ -283,7 +283,7 @@ def l2_loss(x, reduce: DimFilter = math.non_batch) -> Tensor: x = abs(x) return math.sum_(x ** 2, reduce) * 0.5 elif isinstance(x, PhiTreeNode): - return sum([l2_loss(getattr(x, a), reduce) for a in variable_values(x)]) + return sum([l2_loss(getattr(x, a), reduce) for a in value_attributes(x)]) else: try: backend = math.choose_backend(x) @@ -329,7 +329,7 @@ def frequency_loss(x, diff_fft = math.sqrt(math.maximum(diff_fft, threshold)) return l2_loss(diff_fft) if n == 2 else l1_loss(diff_fft) elif isinstance(x, PhiTreeNode): - losses = [frequency_loss(getattr(x, a), frequency_falloff, threshold, ignore_mean, n) for a in variable_values(x)] + losses = [frequency_loss(getattr(x, a), frequency_falloff, threshold, ignore_mean, n) for a in value_attributes(x)] return sum(losses) else: raise ValueError(x)