Skip to content

Commit

Permalink
Custom gradient only for variable attributes
Browse files Browse the repository at this point in the history
Otherwise, constant geometry properties would be traced
  • Loading branch information
holl- committed Dec 30, 2024
1 parent e2c2ace commit 4339df2
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions phiml/math/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,10 +913,10 @@ def __init__(self, f: Callable, gradient: Callable, auxiliary_args: Set[str]):
def _trace(self, in_key: SignatureKey):
def forward_native(*natives):
in_tensors = assemble_tensors(natives, in_key.specs)
kwargs = assemble_tree(in_key.tree, in_tensors, attr_type=all_attributes)
kwargs = assemble_tree(in_key.tree, in_tensors, attr_type=variable_attributes)
ML_LOGGER.debug(f"Running forward pass of custom op {forward_native.__name__} given args {tuple(kwargs.keys())} containing {len(natives)} native tensors")
result = self.f(**kwargs, **in_key.auxiliary_kwargs) # Tensor or tuple/list of Tensors
nest, out_tensors = disassemble_tree(result, cache=True, attr_type=all_attributes)
nest, out_tensors = disassemble_tree(result, cache=True, attr_type=variable_attributes)
result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
self.recorded_mappings[in_key] = SignatureKey(forward_native, nest, result_shapes, specs, in_key.backend, in_key.tracing)
return result_natives
Expand All @@ -928,11 +928,11 @@ def backward_native(x_natives, y_natives, dy_natives):
x_tensors = assemble_tensors(x_natives, in_key.specs)
y_tensors = assemble_tensors(y_natives, out_key.specs)
dy_tensors = assemble_tensors(dy_natives, out_key.specs)
kwargs = assemble_tree(in_key.tree, x_tensors, attr_type=all_attributes)
kwargs = assemble_tree(in_key.tree, x_tensors, attr_type=variable_attributes)
if in_key.auxiliary_kwargs:
kwargs = {**kwargs, **in_key.auxiliary_kwargs}
y = assemble_tree(out_key.tree, y_tensors, attr_type=all_attributes)
dy = assemble_tree(out_key.tree, dy_tensors, attr_type=all_attributes)
y = assemble_tree(out_key.tree, y_tensors, attr_type=variable_attributes)
dy = assemble_tree(out_key.tree, dy_tensors, attr_type=variable_attributes)
result = self.gradient(kwargs, y, dy)
assert isinstance(result, dict) and all(key in kwargs for key in result.keys()), f"gradient function must return a dict containing only parameter names of the forward function. Forward '{f_name(self.f)}' has arguments {kwargs}."
full_result = tuple(result.get(name, None) for name in in_key.tree.keys())
Expand Down

0 comments on commit 4339df2

Please sign in to comment.