Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 29, 2024
1 parent 3ce24c4 commit 1a1e224
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion phiml/math/_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def vec(name: Union[str, Shape] = 'vector', *sequence, tuple_dim=spatial('sequen
assert not components, "vec() must be given either positional or keyword arguments but not both"
if len(sequence) == 1 and isinstance(sequence[0], (tuple, list)):
sequence = sequence[0]
dim = dim.with_size([str(v) for v in sequence])
item_names = [str(v) for v in sequence]
if len(set(item_names)) == len(item_names):
dim = dim.with_size(item_names)
return wrap(sequence, dim)
else:
def wrap_sequence(value):
Expand Down
2 changes: 1 addition & 1 deletion phiml/math/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def to_sparse_tracer(tracer: Tensor, ref: Optional[Tensor]) -> SparseLinTracer:
if isinstance(tracer, ShiftLinTracer):
matrix, bias = tracer_to_coo(tracer, sparsify_batch=False, separate_independent=False)
src_dims = dual(matrix) - set(tracer._renamed)
matrix = rename_dims(matrix, src_dims, [n + '_src' for n in src_dims.as_batch().names])
matrix = rename_dims(matrix, src_dims, [f'~{n}_src' for n in src_dims.as_batch().names])
return SparseLinTracer(tracer._source, matrix, bias, tracer.shape)
assert isinstance(tracer, GatherLinTracer)
if tracer._selection is None:
Expand Down

0 comments on commit 1a1e224

Please sign in to comment.