diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl index 2a4096b63945c1..062dad34e3d351 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl @@ -30,6 +30,7 @@ ProtoLangToolchainInfo = provider( protoc_opts = "(list[str]) Options to pass to proto compiler.", progress_message = "(str) Progress message to set on the proto compiler action.", mnemonic = "(str) Mnemonic to set on the proto compiler action.", + toolchain_type = """(Label) Toolchain type that was used to obtain this info""", ), ) @@ -154,6 +155,7 @@ def _compile( use_default_shell_env = True, resource_set = resource_set, exec_group = experimental_exec_group, + toolchain = getattr(proto_lang_toolchain_info, "toolchain_type", None), ) _BAZEL_TOOLS_PREFIX = "external/bazel_tools/" @@ -290,7 +292,6 @@ toolchains = struct( use_toolchain = _use_toolchain, find_toolchain = _find_toolchain, if_legacy_toolchain = _if_legacy_toolchain, - INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(), ) proto_common_do_not_use = struct( @@ -299,4 +300,6 @@ proto_common_do_not_use = struct( experimental_should_generate_code = _experimental_should_generate_code, experimental_filter_sources = _experimental_filter_sources, ProtoLangToolchainInfo = ProtoLangToolchainInfo, + INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION = _builtins.toplevel.proto_common.incompatible_enable_proto_toolchain_resolution(), + INCOMPATIBLE_PASS_TOOLCHAIN_TYPE = True, ) diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl index 47b53a5be50098..699636983d60f4 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_lang_toolchain.bzl @@ -14,7 +14,7 @@ """A Starlark implementation of the proto_lang_toolchain rule.""" -load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains") +load(":common/proto/proto_common.bzl", "ProtoLangToolchainInfo", "toolchains", proto_common = "proto_common_do_not_use") load(":common/proto/proto_semantics.bzl", "semantics") ProtoInfo = _builtins.toplevel.ProtoInfo @@ -31,7 +31,7 @@ def _rule_impl(ctx): if ctx.attr.plugin != None: plugin = ctx.attr.plugin[DefaultInfo].files_to_run - if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: + if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: proto_compiler = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.proto_compiler protoc_opts = ctx.toolchains[semantics.PROTO_TOOLCHAIN].proto.protoc_opts else: @@ -49,6 +49,7 @@ def _rule_impl(ctx): protoc_opts = protoc_opts, progress_message = ctx.attr.progress_message, mnemonic = ctx.attr.mnemonic, + toolchain_type = ctx.attr.toolchain_type.label if ctx.attr.toolchain_type else None, ) return [ DefaultInfo(files = depset(), runfiles = ctx.runfiles()), @@ -74,7 +75,8 @@ proto_lang_toolchain = rule( "blacklisted_protos": attr.label_list( providers = [ProtoInfo], ), - } | ({} if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { + "toolchain_type": attr.label(), + } | ({} if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION else { "_proto_compiler": attr.label( cfg = "exec", executable = True, diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl index 0d1d9506a863e5..d9943d7a1ea755 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl @@ -251,7 +251,7 @@ def _write_descriptor_set(ctx, direct_sources, deps, exports, proto_info, descri args.add("--allowed_public_imports=") else: args.add_joined("--allowed_public_imports", public_import_protos, map_each = _get_import_path, join_with = ":") - if toolchains.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: + if proto_common.INCOMPATIBLE_ENABLE_PROTO_TOOLCHAIN_RESOLUTION: toolchain = ctx.toolchains[semantics.PROTO_TOOLCHAIN] if not toolchain: fail("Protocol compiler toolchain could not be resolved.") diff --git a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java index f2e3763599dce7..6b6d717807c86d 100644 --- a/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java +++ b/src/test/java/com/google/devtools/build/lib/packages/util/MockProtoSupport.java @@ -268,6 +268,7 @@ public static void setupWorkspace(MockToolsConfig config) throws IOException { " protoc_opts = ctx.fragments.proto.experimental_protoc_opts,", " progress_message = ctx.attr.progress_message,", " mnemonic = ctx.attr.mnemonic,", + " toolchain_type = '//third_party/bazel_rules/rules_proto/proto:toolchain_type'", " ),", " ),", " ]", @@ -293,7 +294,7 @@ public static void setupWorkspace(MockToolsConfig config) throws IOException { "third_party/bazel_rules/rules_proto/proto/proto_lang_toolchain.bzl", "def proto_lang_toolchain(*, name, toolchain_type = None, exec_compatible_with = [],", " target_compatible_with = [], **attrs):", - " native.proto_lang_toolchain(name = name, **attrs)", + " native.proto_lang_toolchain(name = name, toolchain_type = toolchain_type, **attrs)", " if toolchain_type:", " native.toolchain(", " name = name + '_toolchain',",