Skip to content

Commit

Permalink
Use partial matmul when only some dims match
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 23, 2023
1 parent d9149be commit 1582dea
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,19 +732,19 @@ def __iter__(self):

def __matmul__(self, other):
assert isinstance(other, Tensor), f"Matmul '@' requires two Tensor arguments but got {type(other)}"
dims = self.shape.dual.as_batch().names
if not dims: # this is not a matrix
match_names = self.shape.dual.as_batch().names
if not match_names: # this is not a matrix
assert self.shape.primal.only(other.shape).is_empty, f"Cannot compute matmul {self.shape} @ {other.shape}. First argument is not a matrix; it has no dual dimensions."
return self * other
match = other.shape.only(dims, reorder=True)
if not match:
match_primal = other.shape.only(match_names, reorder=True)
if not match_primal:
assert non_batch(other).non_dual.rank == 1, f"Cannot multiply {self.shape} @ {other.shape} because arg2 does not have appropriate non-dual dimensions"
match = non_batch(other).non_dual
assert len(dims) == match.rank, f"Dual dimensions {dual} do not match shape of second argument {other.shape}"
left_arg = pack_dims(self, dual, dual('_reduce')) if len(dims) > 1 else self
right_arg = pack_dims(other, match, channel('_reduce'))
match_primal = non_batch(other).non_dual
match_dual = self.shape.dual.only(match_primal.as_dual(), reorder=True)
left_arg = pack_dims(self, match_dual, dual('_reduce'))
right_arg = pack_dims(other, match_primal, channel('_reduce'))
from ._ops import dot
return dot(left_arg, dual, right_arg, '_reduce')
return dot(left_arg, '~_reduce', right_arg, '_reduce')

# def __rmatmul__(self, other):

Expand Down

0 comments on commit 1582dea

Please sign in to comment.