Skip to content

Commit

Permalink
Support automatic exec groups in proto_common.compile
Browse files Browse the repository at this point in the history
Pass toolchain_type through ProtoLangToolchainInfo into proto_common.compile and use it on ctx.actions.run. Automatic exec groups require that toolchain type is set on the `ctx.actions.run`. This information is used to select correct execution platform.

Expose INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION in proto_common. This will be needed to support lang_proto_libraries that are not part of Bazel. For example py_proto_library. Other methods in `toolchains` struct in proto_common.bzl, are both temporary and can be written in Starlark, so don't expose them. It's possible to access the value in backward compatible manner (that is with `getattr(proto_common, ...)`).

Expose INCOMPATIBLE_PASS_TOOLCHAIN_TYPE in proto_common. Second "flag" is here to mark, that builtin `proto_lang_toolchain` rule has a `toolchain_type` attribute. This way `proto_lang_toolchain` macro can pass the value in a compatible fashion with older Bazel. This should make toolchainisation work with older versions of Bazel that don't know about automatic exec groups and don't need to pass in the value.

Issue: bazelbuild/rules_proto#179
PiperOrigin-RevId: 571876657
Change-Id: I543ab862b318c9062d40430160e33ad197973094
  • Loading branch information
comius committed Jan 17, 2024
1 parent f30e59a commit 29b1f95
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ ProtoLangToolchainInfo = provider(
protoc_opts = "(list[str]) Options to pass to proto compiler.",
progress_message = "(str) Progress message to set on the proto compiler action.",
mnemonic = "(str) Mnemonic to set on the proto compiler action.",
toolchain_type = """(Label) Toolchain type that was used to obtain this info""",
),
)

Expand Down Expand Up @@ -154,6 +155,7 @@ def _compile(
use_default_shell_env = True,
resource_set = resource_set,
exec_group = experimental_exec_group,
toolchain = getattr(proto_lang_toolchain_info, "toolchain_type", None),
)

_BAZEL_TOOLS_PREFIX = "external/bazel_tools/"
Expand Down Expand Up @@ -290,7 +292,6 @@ toolchains = struct(
use_toolchain = _use_toolchain,
find_toolchain = _find_toolchain,
if_legacy_toolchain = _if_legacy_toolchain,
INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(),
)

proto_common_do_not_use = struct(
Expand All @@ -299,4 +300,6 @@ proto_common_do_not_use = struct(
experimental_should_generate_code = _experimental_should_generate_code,
experimental_filter_sources = _experimental_filter_sources,
ProtoLangToolchainInfo = ProtoLangToolchainInfo,
INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(),
INCOMPATIBLE_PASS_TOOLCHAIN_TYPE = True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""A Starlark implementation of the proto_lang_toolchain rule."""

load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains")
load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains", proto_common = "proto_common_do_not_use")
load(":common/proto/proto_semantics.bzl", "semantics")

ProtoInfo = _builtins.toplevel.ProtoInfo
Expand All @@ -31,7 +31,7 @@ def _rule_impl(ctx):
if ctx.attr.plugin != None:
plugin = ctx.attr.plugin[DefaultInfo].files_to_run

if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
proto_compiler = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.proto_compiler
protoc_opts = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.protoc_opts
else:
Expand All @@ -49,6 +49,7 @@ def _rule_impl(ctx):
protoc_opts = protoc_opts,
progress_message = ctx.attr.progress_message,
mnemonic = ctx.attr.mnemonic,
toolchain_type = ctx.attr.toolchain_type.label if ctx.attr.toolchain_type else None,
)
return [
DefaultInfo(files = depset(), runfiles = ctx.runfiles()),
Expand All @@ -74,7 +75,8 @@ proto_lang_toolchain = rule(
"blacklisted_protos": attr.label_list(
providers = [ProtoInfo],
),
} | ({} if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else {
"toolchain_type": attr.label(),
} | ({} if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else {
"_proto_compiler": attr.label(
cfg = "exec",
executable = True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _write_descriptor_set(ctx, direct_sources, deps, exports, proto_info, descri
args.add("--allowed_public_imports=")
else:
args.add_joined("--allowed_public_imports", public_import_protos, map_each = _get_import_path, join_with = ":")
if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION:
toolchain = ctx.toolchains[semantics.PROTO_TOOLCHAIN]
if not toolchain:
fail("Protocol compiler toolchain could not be resolved.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ public static void setupWorkspace(MockToolsConfig config) throws IOException {
" protoc_opts = ctx.fragments.proto.experimental_protoc_opts,",
" progress_message = ctx.attr.progress_message,",
" mnemonic = ctx.attr.mnemonic,",
" toolchain_type = '//third_party/bazel_rules/rules_proto/proto:toolchain_type'",
" ),",
" ),",
" ]",
Expand All @@ -293,7 +294,7 @@ public static void setupWorkspace(MockToolsConfig config) throws IOException {
"third_party/bazel_rules/rules_proto/proto/proto_lang_toolchain.bzl",
"def proto_lang_toolchain(*, name, toolchain_type = None, exec_compatible_with = [],",
" target_compatible_with = [], **attrs):",
" native.proto_lang_toolchain(name = name, **attrs)",
" native.proto_lang_toolchain(name = name, toolchain_type = toolchain_type, **attrs)",
" if toolchain_type:",
" native.toolchain(",
" name = name + '_toolchain',",
Expand Down

0 comments on commit 29b1f95

Please sign in to comment.