Tracing fixes
* Clean up tracing state on error
* Fix linear trace with renamed dims
holl- committed Nov 11, 2023
1 parent cd1e213 commit 13fc568
Showing 3 changed files with 18 additions and 13 deletions.
24 changes: 14 additions & 10 deletions phiml/math/_functional.py
@@ -168,14 +168,16 @@ def _jit_compile(self, in_key: SignatureKey):
         def jit_f_native(*natives):
             ML_LOGGER.debug(f"Φ-ML-jit: Tracing '{f_name(self.f)}'")
             _TRACING_JIT.append(self)
-            self._tracing_in_key = in_key
-            in_tensors = assemble_tensors(natives, in_key.specs)
-            kwargs = assemble_tree(in_key.tree, in_tensors)
-            f_output = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
-            tree, out_tensors = disassemble_tree((f_output, self._extract_tensors))
-            result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
-            self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
-            assert _TRACING_JIT.pop(-1) is self
+            try:
+                self._tracing_in_key = in_key
+                in_tensors = assemble_tensors(natives, in_key.specs)
+                kwargs = assemble_tree(in_key.tree, in_tensors)
+                f_output = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
+                tree, out_tensors = disassemble_tree((f_output, self._extract_tensors))
+                result_natives, result_shapes, specs = disassemble_tensors(out_tensors, expand=True)
+                self.recorded_mappings[in_key] = SignatureKey(jit_f_native, tree, result_shapes, specs, in_key.backend, in_key.tracing)
+            finally:
+                assert _TRACING_JIT.pop(-1) is self
             self._tracing_in_key = None
             return result_natives
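Why the try/finally: `_TRACING_JIT` is a module-level stack, and before this change an exception raised while tracing the user function skipped the closing pop, leaving a stale tracer on the stack for every later trace. A minimal standalone sketch of the failure mode and the fix; the names `_TRACING` and `trace` below are simplified stand-ins, not phiml's actual API:

# Simplified illustration of the cleanup pattern in the hunk above.
_TRACING = []  # module-level stack of currently active tracers

def trace(tracer, f, *args):
    _TRACING.append(tracer)
    try:
        return f(*args)  # may raise; without try/finally, the pop below is skipped
    finally:
        assert _TRACING.pop(-1) is tracer, "tracing stack corrupted"

The finally block runs on both the success and error paths, so the stack is always restored while the original exception still propagates to the caller.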

@@ -309,8 +311,10 @@ def _get_or_trace(self, key: SignatureKey, args: tuple, f_kwargs: dict):
             if self.forget_traces:
                 self.matrices_and_biases.clear()
             _TRACING_LINEAR.append(self)
-            matrix, bias = matrix_from_function(self.f, *args, **f_kwargs, auto_compress=True)
-            assert _TRACING_LINEAR.pop(-1) is self
+            try:
+                matrix, bias = matrix_from_function(self.f, *args, **f_kwargs, auto_compress=True)
+            finally:
+                assert _TRACING_LINEAR.pop(-1) is self
             if not key.tracing:
                 self.matrices_and_biases[key] = matrix, bias
                 if len(self.matrices_and_biases) >= 4:
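The same guard around `matrix_from_function` means a failed linear trace no longer leaves a stale entry on `_TRACING_LINEAR`. A hypothetical usage sketch of what this enables, assuming phiml's public `jit_compile_linear` decorator; `broken` and `scale` are made-up example functions:

from phiml import math

@math.jit_compile_linear
def broken(x):
    raise ValueError("raised while tracing")

try:
    broken(math.ones(math.spatial(x=4)))
except ValueError:
    pass  # before this commit, the tracing stack would retain a stale entry here

# With the fix, the stack is clean again and later traces behave normally:
@math.jit_compile_linear
def scale(x):
    return 2 * x

scale(math.ones(math.spatial(x=4)))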
3 changes: 2 additions & 1 deletion phiml/math/_trace.py
@@ -587,7 +587,8 @@ def tracer_to_coo(tracer: Tensor, sparsify_batch: bool, separate_independent: bo
                     offset = shift_.get_size(dim, default=0)
                     src_idx_all.append(np.zeros_like(src_idx[0]) + offset)
                 else:
-                    src_idx_all.append(src_idx[out_shape.index(dim)])
+                    out_dim = {v: k for k, v in tracer._renamed.items()}.get(dim.name, dim.name)
+                    src_idx_all.append(src_idx[out_shape.index(out_dim)])
             src_indices.append(src_idx_all)
     indices_np = np.concatenate([np.concatenate(src_indices, axis=1), np.concatenate(out_indices, axis=1)]).T
     indices = wrap(indices_np, instance('entries'), channel(vector=(sliced_src_shape if separate_independent else src_shape).names + out_shape.names))
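The fix resolves a source dimension to the name it carries in `out_shape` by inverting the tracer's `_renamed` dict, falling back to the unchanged name. A standalone illustration of that reverse-lookup idiom, using made-up dimension names rather than a real tracer:

# Invert a one-way rename record and look a name up with a pass-through default.
renamed = {'x': 'u'}  # records that dim 'x' was renamed, e.g. to 'u'

inverse = {v: k for k, v in renamed.items()}  # {'u': 'x'}
print(inverse.get('u', 'u'))  # -> 'x': a renamed dim resolves to its counterpart
print(inverse.get('y', 'y'))  # -> 'y': un-renamed dims pass through unchanged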
Expand Down
4 changes: 2 additions & 2 deletions phiml/math/extrapolation.py
@@ -656,8 +656,8 @@ def pad_values(self, value: Tensor, width: int, dim: str, upper_edge: bool, alre
             return value[{dim: slice(-width, None)}]
 
     def _pad_linear_tracer(self, value: 'ShiftLinTracer', widths: dict) -> 'ShiftLinTracer':
-        if value.shape.get_sizes(tuple(widths.keys())) != value.source.shape.get_sizes(tuple(widths.keys())):
-            raise NotImplementedError("Periodicity does not match input: %s but input has %s. This can happen when padding an already padded or sliced tensor." % (value.shape.only(tuple(widths.keys())), value.source.shape.only(tuple(widths.keys()))))
+        if value.shape.get_sizes(tuple(widths.keys())) != value._source.shape.get_sizes(tuple(widths.keys())):
+            raise NotImplementedError("Periodicity does not match input: %s but input has %s. This can happen when padding an already padded or sliced tensor." % (value.shape.only(tuple(widths.keys())), value._source.shape.only(tuple(widths.keys()))))
         lower = {dim: -lo for dim, (lo, _) in widths.items()}
         return value.shift(lower, new_shape=value.shape.after_pad(widths), val_fun=lambda v: self.pad(v, widths), bias_fun=lambda b: ZERO.pad(b, widths))
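Only the attribute access changes here: `value.source` becomes `value._source`, presumably tracking a rename of the ShiftLinTracer field, so the size check runs instead of failing with an AttributeError. For context, a hypothetical sketch of the situation the error message warns about; it assumes `math.pad` with a widths dict and `extrapolation.PERIODIC`, and that double padding inside a linear trace takes this code path:

from phiml import math
from phiml.math import extrapolation, spatial

@math.jit_compile_linear
def pad_twice(x):
    y = math.pad(x, {'x': (1, 1)}, extrapolation.PERIODIC)
    # Padding the already-padded tracer changes its shape relative to the
    # trace source, which _pad_linear_tracer rejects.
    return math.pad(y, {'x': (1, 1)}, extrapolation.PERIODIC)

pad_twice(math.ones(spatial(x=4)))  # expected to raise NotImplementedError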
