diff --git a/phiml/math/_shape.py b/phiml/math/_shape.py index 7429859..31d6a1b 100644 --- a/phiml/math/_shape.py +++ b/phiml/math/_shape.py @@ -1784,7 +1784,7 @@ def dual(*args, **dims: Union[int, str, tuple, list, Shape, 'Tensor']) -> Shape: raise AssertionError(f"dual() must be called either as a selector dual(Shape) or dual(Tensor) or as a constructor dual(*names, **dims). Got *args={args}, **dims={dims}") -def auto(spec: Union[str, Shape], default_type=None) -> Shape: +def auto(spec: Union[str, Shape], default_type: Callable = None) -> Shape: """ Create a single-dimension `Shape` from a specification string. @@ -1803,55 +1803,38 @@ def auto(spec: Union[str, Shape], default_type=None) -> Shape: return spec # allow multi-dim Shapes as well, as the main application is stacking assert isinstance(spec, str), f"spec must be a Shape or str but got {type(spec)}" assert ',' not in spec, f"auto dim only supported for single dimensions" - dim_type = None - dim_name = spec - for dt, char in SUPERSCRIPT.items(): - if spec.endswith(char): - dim_type = dt - dim_name = spec[:-1] - break - else: - for dt, char in CHAR.items(): - if spec.endswith(':' + char): - dim_type = dt - dim_name = spec[:-2] - break - if spec.startswith('~'): - assert dim_type == DUAL_DIM or dim_type is None, f"Inconsistent dim types for '{spec}'. '~' indicates dual dimension but declared type is {dim_type}" - dim_type = DUAL_DIM - dim_name = dim_name[1:] - if dim_type is None: - assert default_type is not None, f"No dim type specified: '{spec}'" - if callable(default_type): - return default_type(spec) - else: - assert isinstance(default_type, str), f"default_type must be a dimension generator or str but got {type(default_type)}" - dim_type = default_type - return _construct_shape(dim_type, dim_name) + return parse_shape_spec(spec, default_type=default_type) class InvalidShapeSpec(ValueError): pass -def parse_shape_spec(input_string) -> Shape: - name_type_items = re.compile(r'(~?)(\w+):(\w+)=\(([^)]*)\)') - name_type = re.compile(r'(~?)(\w+):(\w+)') - name_items = re.compile(r'(~?)(\w+)=\(([^)]*)\)') - items = re.compile(r'(~?)\(([^)]*)\)') - dual_name = re.compile(r'(~\w+)') +SPEC_PATTERNS = { + 'name_type_items': re.compile(r'(~?)(\w+):(\w+)=\(([^)]*)\)'), + 'name_type': re.compile(r'(~?)(\w+):(\w+)'), + 'name_items': re.compile(r'(~?)(\w+)=\(([^)]*)\)'), + 'items': re.compile(r'(~?)\(([^)]*)\)'), + 'dual_name': re.compile(r'(~\w+)'), + 'single_letter': re.compile(r'(\w)(?=,|$)'), + 'name_only': re.compile(r'(\w+)(?=,|$)') +} + + +def parse_shape_spec(input_string, default_type: Callable = None) -> Shape: results = [] pos = 0 while pos < len(input_string): - if match := name_type_items.match(input_string, pos): + if match := SPEC_PATTERNS['name_type_items'].match(input_string, pos): tilde, name, type_, values = match.groups() if tilde and type_ not in ('d', 'dual'): raise InvalidShapeSpec(input_string, f"Dimension names starting with ~ must be of type dual. Failed at index {pos}: {input_string[pos:]}") elif not tilde and type_ in ('d', 'dual'): raise InvalidShapeSpec(input_string, f"Dual dims must start with ~. Failed at index {pos}: {input_string[pos:]}") - results.append({'name': '~' + name if tilde else name, 'type': type_, 'values': values.split(',')}) + items = [n.strip() for n in values.split(',') if n.strip()] + results.append({'name': '~' + name if tilde else name, 'type': type_, 'values': items}) pos = match.end() + 1 - elif match := name_type.match(input_string, pos): + elif match := SPEC_PATTERNS['name_type'].match(input_string, pos): tilde, name, type_ = match.groups() if tilde and type_ not in ('d', 'dual'): raise InvalidShapeSpec(input_string, f"Dimension names starting with ~ must be of type dual. Failed at index {pos}: {input_string[pos:]}") @@ -1863,18 +1846,28 @@ def parse_shape_spec(input_string) -> Shape: raise ValueError(f"Invalid format at position {pos}: values must be inside parentheses") results.append({'name': '~' + name if tilde else name, 'type': type_}) pos = match.end() + 1 - elif match := name_items.match(input_string, pos): + elif match := SPEC_PATTERNS['name_items'].match(input_string, pos): tilde, name, values = match.groups() - results.append({'name': '~' + name if tilde else name, 'type': 'd' if tilde else 'c', 'values': values.split(',')}) + items = [n.strip() for n in values.split(',') if n.strip()] + results.append({'name': '~' + name if tilde else name, 'type': 'd' if tilde else 'c', 'values': items}) pos = match.end() + 1 - elif match := items.match(input_string, pos): + elif match := SPEC_PATTERNS['items'].match(input_string, pos): tilde, values = match.groups() results.append({'name': '~vector' if tilde else 'vector', 'type': 'd' if tilde else 'c', 'values': values.split(',')}) pos = match.end() + 1 - elif match := dual_name.match(input_string, pos): + elif match := SPEC_PATTERNS['dual_name'].match(input_string, pos): name, = match.groups() results.append({'name': name, 'type': 'd'}) pos = match.end() + 1 + elif match := SPEC_PATTERNS['single_letter'].match(input_string, pos): + name, = match.groups() + results.append({'name': name, 'type': 's'}) + pos = match.end() + 1 + elif default_type is not None and (match := SPEC_PATTERNS['name_only'].match(input_string, pos)): + name, = match.groups() + default_type_str = TYPE_BY_FUNCTION[default_type] + results.append({'name': name, 'type': default_type_str}) + pos = match.end() + 1 else: raise InvalidShapeSpec(input_string, f"Failed to parse from index {pos}: '{input_string[pos:]}'. Dims must be specified as name:type or name:type=(item_names...). Names and types may only be omitted if component names are given.") names = [r['name'] for r in results]