Skip to content

Commit

Permalink
Improved shape spec parsing, allow single-letter dims for spatial
Browse files: browse the repository at this point in the history
  • Loading branch information
holl- committed Sep 7, 2024
1 parent 9e5475a commit 539b7a7
Showing 1 changed file with 32 additions and 39 deletions.
71 changes: 32 additions & 39 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,7 @@ def dual(*args, **dims: Union[int, str, tuple, list, Shape, 'Tensor']) -> Shape:
raise AssertionError(f"dual() must be called either as a selector dual(Shape) or dual(Tensor) or as a constructor dual(*names, **dims). Got *args={args}, **dims={dims}")


def auto(spec: Union[str, Shape], default_type=None) -> Shape:
def auto(spec: Union[str, Shape], default_type: Callable = None) -> Shape:
"""
Create a single-dimension `Shape` from a specification string.
Expand All @@ -1803,55 +1803,38 @@ def auto(spec: Union[str, Shape], default_type=None) -> Shape:
return spec # allow multi-dim Shapes as well, as the main application is stacking
assert isinstance(spec, str), f"spec must be a Shape or str but got {type(spec)}"
assert ',' not in spec, f"auto dim only supported for single dimensions"
dim_type = None
dim_name = spec
for dt, char in SUPERSCRIPT.items():
if spec.endswith(char):
dim_type = dt
dim_name = spec[:-1]
break
else:
for dt, char in CHAR.items():
if spec.endswith(':' + char):
dim_type = dt
dim_name = spec[:-2]
break
if spec.startswith('~'):
assert dim_type == DUAL_DIM or dim_type is None, f"Inconsistent dim types for '{spec}'. '~' indicates dual dimension but declared type is {dim_type}"
dim_type = DUAL_DIM
dim_name = dim_name[1:]
if dim_type is None:
assert default_type is not None, f"No dim type specified: '{spec}'"
if callable(default_type):
return default_type(spec)
else:
assert isinstance(default_type, str), f"default_type must be a dimension generator or str but got {type(default_type)}"
dim_type = default_type
return _construct_shape(dim_type, dim_name)
return parse_shape_spec(spec, default_type=default_type)


class InvalidShapeSpec(ValueError):
    """
    Raised when a shape specification string cannot be parsed.

    `parse_shape_spec` raises this error with two positional arguments,
    available through `args`: the complete specification string and a
    message describing where and why parsing failed.

    Because it derives from `ValueError`, callers that already catch
    `ValueError` also catch this error.
    """


def parse_shape_spec(input_string) -> Shape:
name_type_items = re.compile(r'(~?)(\w+):(\w+)=\(([^)]*)\)')
name_type = re.compile(r'(~?)(\w+):(\w+)')
name_items = re.compile(r'(~?)(\w+)=\(([^)]*)\)')
items = re.compile(r'(~?)\(([^)]*)\)')
dual_name = re.compile(r'(~\w+)')
# Pre-compiled regular expressions for the individual forms a dimension may
# take inside a shape specification string. `parse_shape_spec` looks them up
# by key and matches them at the current parse position.
#
# A leading `(~?)` group captures an optional '~' prefix. `\w` matches any
# word character (letters, digits and underscore), not only letters.
SPEC_PATTERNS = {
    label: re.compile(regex)
    for label, regex in (
        # [~]name:type=(item, item, ...)
        ('name_type_items', r'(~?)(\w+):(\w+)=\(([^)]*)\)'),
        # [~]name:type
        ('name_type', r'(~?)(\w+):(\w+)'),
        # [~]name=(item, item, ...)
        ('name_items', r'(~?)(\w+)=\(([^)]*)\)'),
        # [~](item, item, ...) with no name and no type
        ('items', r'(~?)\(([^)]*)\)'),
        # ~name, captured including the '~'
        ('dual_name', r'(~\w+)'),
        # exactly one word character, followed by ',' or the end of the string
        ('single_letter', r'(\w)(?=,|$)'),
        # one or more word characters, followed by ',' or the end of the string
        ('name_only', r'(\w+)(?=,|$)'),
    )
}


def parse_shape_spec(input_string, default_type: Callable = None) -> Shape:
results = []
pos = 0
while pos < len(input_string):
if match := name_type_items.match(input_string, pos):
if match := SPEC_PATTERNS['name_type_items'].match(input_string, pos):
tilde, name, type_, values = match.groups()
if tilde and type_ not in ('d', 'dual'):
raise InvalidShapeSpec(input_string, f"Dimension names starting with ~ must be of type dual. Failed at index {pos}: {input_string[pos:]}")
elif not tilde and type_ in ('d', 'dual'):
raise InvalidShapeSpec(input_string, f"Dual dims must start with ~. Failed at index {pos}: {input_string[pos:]}")
results.append({'name': '~' + name if tilde else name, 'type': type_, 'values': values.split(',')})
items = [n.strip() for n in values.split(',') if n.strip()]
results.append({'name': '~' + name if tilde else name, 'type': type_, 'values': items})
pos = match.end() + 1
elif match := name_type.match(input_string, pos):
elif match := SPEC_PATTERNS['name_type'].match(input_string, pos):
tilde, name, type_ = match.groups()
if tilde and type_ not in ('d', 'dual'):
raise InvalidShapeSpec(input_string, f"Dimension names starting with ~ must be of type dual. Failed at index {pos}: {input_string[pos:]}")
Expand All @@ -1863,18 +1846,28 @@ def parse_shape_spec(input_string) -> Shape:
raise ValueError(f"Invalid format at position {pos}: values must be inside parentheses")
results.append({'name': '~' + name if tilde else name, 'type': type_})
pos = match.end() + 1
elif match := name_items.match(input_string, pos):
elif match := SPEC_PATTERNS['name_items'].match(input_string, pos):
tilde, name, values = match.groups()
results.append({'name': '~' + name if tilde else name, 'type': 'd' if tilde else 'c', 'values': values.split(',')})
items = [n.strip() for n in values.split(',') if n.strip()]
results.append({'name': '~' + name if tilde else name, 'type': 'd' if tilde else 'c', 'values': items})
pos = match.end() + 1
elif match := items.match(input_string, pos):
elif match := SPEC_PATTERNS['items'].match(input_string, pos):
tilde, values = match.groups()
results.append({'name': '~vector' if tilde else 'vector', 'type': 'd' if tilde else 'c', 'values': values.split(',')})
pos = match.end() + 1
elif match := dual_name.match(input_string, pos):
elif match := SPEC_PATTERNS['dual_name'].match(input_string, pos):
name, = match.groups()
results.append({'name': name, 'type': 'd'})
pos = match.end() + 1
elif match := SPEC_PATTERNS['single_letter'].match(input_string, pos):
name, = match.groups()
results.append({'name': name, 'type': 's'})
pos = match.end() + 1
elif default_type is not None and (match := SPEC_PATTERNS['name_only'].match(input_string, pos)):
name, = match.groups()
default_type_str = TYPE_BY_FUNCTION[default_type]
results.append({'name': name, 'type': default_type_str})
pos = match.end() + 1
else:
raise InvalidShapeSpec(input_string, f"Failed to parse from index {pos}: '{input_string[pos:]}'. Dims must be specified as name:type or name:type=(item_names...). Names and types may only be omitted if component names are given.")
names = [r['name'] for r in results]
Expand Down

0 comments on commit 539b7a7

Please sign in to comment.