Skip to content

Commit

Permalink
More complete YAML provider implementation. (#33716)
Browse files Browse the repository at this point in the history
Transforms defined in a provider in YAML are now expanded in
context, which means in particular that they can reference
other transforms (including themselves) that have been
exported.
  • Loading branch information
robertwb authored Jan 22, 2025
1 parent 6c73723 commit 5a2a250
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 11 deletions.
23 changes: 18 additions & 5 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ def create_transform(
yaml_create_transform: Callable[
[Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
) -> beam.PTransform:
from apache_beam.yaml.yaml_transform import SafeLineLoader, YamlTransform
from apache_beam.yaml.yaml_transform import expand_jinja, preprocess
from apache_beam.yaml.yaml_transform import SafeLineLoader
spec = self._transforms[type]
try:
import jsonschema
Expand All @@ -440,10 +441,17 @@ def create_transform(
'Please install jsonschema '
f'for better provider validation of "{type}"')
body = spec['body']
if not isinstance(body, str):
body = yaml.safe_dump(SafeLineLoader.strip_metadata(body))
from apache_beam.yaml.yaml_transform import expand_jinja
return YamlTransform(expand_jinja(body, args))
# Stringify to apply jinja.
if isinstance(body, str):
body_str = body
else:
body_str = yaml.safe_dump(SafeLineLoader.strip_metadata(body))
# Now re-parse resolved templatization.
body = yaml.load(expand_jinja(body_str, args), Loader=SafeLineLoader)
if (body.get('type') == 'chain' and 'input' not in body and
spec.get('requires_inputs', True)):
body['input'] = 'input'
return yaml_create_transform(preprocess(body), [])


# This is needed because type inference can't handle *args, **kwargs forwarding.
Expand Down Expand Up @@ -724,6 +732,11 @@ def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
redistribute the work) if there is more than one element in the
collection. Defaults to True.
"""
# Though str and dict are technically iterable, we disallow them
# as using the characters or keys respectively is almost certainly
# not the intent.
if not isinstance(elements, Iterable) or isinstance(elements, (dict, str)):
raise TypeError('elements must be a list of elements')
return beam.Create([element_to_rows(e) for e in elements],
reshuffle=reshuffle is not False)

Expand Down
44 changes: 44 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,50 @@ def test_yaml_define_provider(self):
result | beam.Map(lambda x: (x.element, x.power)),
equal_to([(0, 0), (1, 1), (2, 4), (3, 9)]))

def test_recursive(self):
providers = '''
- type: yaml
transforms:
Factorial:
config_schema:
properties:
n: {type: integer}
requires_inputs: false
body: |
{% if n <= 1 %}
type: Create
config:
elements:
- {value: 1}
{% else %}
type: chain
transforms:
- type: Factorial
config:
n: {{n-1}}
- type: MapToFields
name: Multiply
config:
language: python
fields:
value: value * {{n}}
{% endif %}
'''

pipeline = '''
type: Factorial
config:
n: 5
'''

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
result = p | YamlTransform(
pipeline,
providers=yaml_provider.parse_providers(
'', yaml.load(providers, Loader=SafeLineLoader)))
assert_that(result | beam.Map(lambda x: x.value), equal_to([120]))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
44 changes: 38 additions & 6 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,34 @@ def create_ptransform(self, spec, input_pcolls):
if 'type' not in spec:
raise ValueError(f'Missing transform type: {identify_object(spec)}')

if spec['type'] == 'composite':

class _CompositeTransformStub(beam.PTransform):
@staticmethod
def expand(pcolls):
if isinstance(pcolls, beam.PCollection):
pcolls = {'input': pcolls}
elif isinstance(pcolls, beam.pvalue.PBegin):
pcolls = {}

inner_scope = Scope(
self.root,
pcolls,
spec['transforms'],
self.providers,
self.input_providers)
inner_scope.compute_all()
if '__implicit_outputs__' in spec['output']:
return inner_scope.get_outputs(
spec['output']['__implicit_outputs__'])
else:
return {
key: inner_scope.get_pcollection(value)
for (key, value) in spec['output'].items()
}

return _CompositeTransformStub()

if spec['type'] not in self.providers:
raise ValueError(
'Unknown transform type %r at %s' %
Expand Down Expand Up @@ -344,11 +372,14 @@ def create_ptransform(self, spec, input_pcolls):
spec['type'], config, self.create_ptransform)
# TODO(robertwb): Should we have a better API for adding annotations
# than this?
annotations = dict(
yaml_type=spec['type'],
yaml_args=json.dumps(config),
yaml_provider=json.dumps(provider.to_json()),
**ptransform.annotations())
annotations = {
**{
'yaml_type': spec['type'],
'yaml_args': json.dumps(config),
'yaml_provider': json.dumps(provider.to_json())
},
**ptransform.annotations()
}
ptransform.annotations = lambda: annotations
original_expand = ptransform.expand

Expand Down Expand Up @@ -387,7 +418,8 @@ def unique_name(self, spec, ptransform, strictness=0):
if 'name' in spec:
name = spec['name']
strictness += 1
elif 'ExternalTransform' not in ptransform.label:
elif ('ExternalTransform' not in ptransform.label and
not ptransform.label.startswith('_')):
# The label may have interesting information.
name = ptransform.label
else:
Expand Down

0 comments on commit 5a2a250

Please sign in to comment.