diff --git a/phiml/math/_functional.py b/phiml/math/_functional.py index 2ca145f..f7e63fd 100644 --- a/phiml/math/_functional.py +++ b/phiml/math/_functional.py @@ -913,10 +913,10 @@ def __init__(self, f: Callable, gradient: Callable, auxiliary_args: Set[str]): def _trace(self, in_key: SignatureKey): def forward_native(*natives): in_tensors = assemble_tensors(natives, in_key.specs) - kwargs = assemble_tree(in_key.tree, in_tensors, attr_type=all_attributes) + kwargs = assemble_tree(in_key.tree, in_tensors, attr_type=variable_attributes) ML_LOGGER.debug(f"Running forward pass of custom op {forward_native.__name__} given args {tuple(kwargs.keys())} containing {len(natives)} native tensors") result = self.f(**kwargs, **in_key.auxiliary_kwargs) # Tensor or tuple/list of Tensors - nest, out_tensors = disassemble_tree(result, cache=True, attr_type=all_attributes) + nest, out_tensors = disassemble_tree(result, cache=True, attr_type=variable_attributes) result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True) self.recorded_mappings[in_key] = SignatureKey(forward_native, nest, result_shapes, specs, in_key.backend, in_key.tracing) return result_natives @@ -928,11 +928,11 @@ def backward_native(x_natives, y_natives, dy_natives): x_tensors = assemble_tensors(x_natives, in_key.specs) y_tensors = assemble_tensors(y_natives, out_key.specs) dy_tensors = assemble_tensors(dy_natives, out_key.specs) - kwargs = assemble_tree(in_key.tree, x_tensors, attr_type=all_attributes) + kwargs = assemble_tree(in_key.tree, x_tensors, attr_type=variable_attributes) if in_key.auxiliary_kwargs: kwargs = {**kwargs, **in_key.auxiliary_kwargs} - y = assemble_tree(out_key.tree, y_tensors, attr_type=all_attributes) - dy = assemble_tree(out_key.tree, dy_tensors, attr_type=all_attributes) + y = assemble_tree(out_key.tree, y_tensors, attr_type=variable_attributes) + dy = assemble_tree(out_key.tree, dy_tensors, attr_type=variable_attributes) result = self.gradient(kwargs, y, dy) assert isinstance(result, dict) and all(key in kwargs for key in result.keys()), f"gradient function must return a dict containing only parameter names of the forward function. Forward '{f_name(self.f)}' has arguments {kwargs}." full_result = tuple(result.get(name, None) for name in in_key.tree.keys())