Skip to content

Commit

Permalink
Improve RDC (#167)
Browse files Browse the repository at this point in the history
- Remove rdc output from cuda_library

  It is incorrect in the case of whole archive linking and prevent us from creating a shared library.

- Better handling of rdc objects archiving logic

- Make transitive rdc cuda_library correct
  • Loading branch information
cloudhan authored Oct 6, 2023
1 parent d3b9ab7 commit 894603f
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 42 deletions.
19 changes: 18 additions & 1 deletion cuda/private/cuda_helper.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,31 @@ def _create_common(ctx):
transitive_linking_contexts = transitive_linking_contexts,
)

def _create_cuda_info(defines = None, objects = None, rdc_objects = None, pic_objects = None, rdc_pic_objects = None):
def _create_cuda_info(
defines = None,
objects = None,
rdc_objects = None,
pic_objects = None,
rdc_pic_objects = None,
archive_objects = None,
archive_rdc_objects = None,
archive_pic_objects = None,
archive_rdc_pic_objects = None,
dlink_rdc_objects = None,
dlink_rdc_pic_objects = None):
"""Constructor for `CudaInfo`. See the providers documentation for detail."""
ret = CudaInfo(
defines = defines if defines != None else depset([]),
objects = objects if objects != None else depset([]),
rdc_objects = rdc_objects if rdc_objects != None else depset([]),
pic_objects = pic_objects if pic_objects != None else depset([]),
rdc_pic_objects = rdc_pic_objects if rdc_pic_objects != None else depset([]),
archive_objects = archive_objects if archive_objects != None else depset([]),
archive_rdc_objects = archive_rdc_objects if archive_rdc_objects != None else depset([]),
archive_pic_objects = archive_pic_objects if archive_pic_objects != None else depset([]),
archive_rdc_pic_objects = archive_rdc_pic_objects if archive_rdc_pic_objects != None else depset([]),
dlink_rdc_objects = dlink_rdc_objects if dlink_rdc_objects != None else depset([]),
dlink_rdc_pic_objects = dlink_rdc_pic_objects if dlink_rdc_pic_objects != None else depset([]),
)
return ret

Expand Down
23 changes: 18 additions & 5 deletions cuda/private/providers.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,26 @@ CudaInfo = provider(
"""Provides cuda build artifacts that can be consumed by device linking or linking process.
This provider is analog to [CcInfo](https://bazel.build/rules/lib/CcInfo) but only contains necessary information for
linking in a flat structure.""",
linking in a flat structure. Objects are grouped by direct and transitive, because we have no way to split them again
if merged a single depset.
""",
fields = {
"defines": "A depset of strings. It is used for the compilation during device linking.",
"objects": "A depset of objects.", # but not rdc and pic
"rdc_objects": "A depset of relocatable device code objects.", # but not pic
"pic_objects": "A depset of position indepentent code objects.", # but not rdc
"rdc_pic_objects": "A depset of relocatable device code and position indepentent code objects.",
# direct only:
"objects": "A depset of objects. Direct artifacts of the rule.", # but not rdc and pic
"pic_objects": "A depset of position indepentent code objects. Direct artifacts of the rule.", # but not rdc
"rdc_objects": "A depset of relocatable device code objects. Direct artifacts of the rule.", # but not pic
"rdc_pic_objects": "A depset of relocatable device code and position indepentent code objects. Direct artifacts of the rule.",
# transitive archive only (cuda_objects):
"archive_objects": "A depset of rdc objects. cuda_objects only. Gathered from the transitive dependencies for archiving.",
"archive_pic_objects": "A depset of rdc pic objects. cuda_objects only. Gathered from the transitive dependencies for archiving.",
"archive_rdc_objects": "A depset of rdc objects. cuda_objects only. Gathered from the transitive dependencies for archiving or device linking.",
"archive_rdc_pic_objects": "A depset of rdc pic objects. cuda_objects only. Gathered from the transitive dependencies for archiving or device linking.",

# transitive dlink only (cuda_library):
# NOTE: ideally, we can use the archived library to do the device linking, but the nvlink is not happy with library with *_dlink.o included
"dlink_rdc_objects": "A depset of rdc objects. cuda_library only. Gathered from the transitive dependencies for device linking.",
"dlink_rdc_pic_objects": "A depset of rdc pic objects. cuda_library only. Gathered from the transitive dependencies for device linking.",
},
)

Expand Down
87 changes: 61 additions & 26 deletions cuda/private/rules/cuda_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,59 @@ def _cuda_library_impl(ctx):
for src in ctx.attr.srcs:
src_files.extend(src[DefaultInfo].files.to_list())

# outputs
objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = use_rdc))
pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = use_rdc))
rdc_objects = depset([])
rdc_pic_objects = depset([])
# merge deps' direct objects and archive objects as our archive objects
archive_objects = depset(transitive = [dep[CudaInfo].objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_objects for dep in attr.deps if CudaInfo in dep])
archive_pic_objects = depset(transitive = [dep[CudaInfo].pic_objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_pic_objects for dep in attr.deps if CudaInfo in dep])
archive_rdc_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_rdc_objects for dep in attr.deps if CudaInfo in dep])
archive_rdc_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_rdc_pic_objects for dep in attr.deps if CudaInfo in dep])

# if rdc is enabled for this cuda_library, then we need futher do a pass of device link
# Gather transitive dlink objects that may come from other `cuda_library`s
dlink_rdc_objects = depset(transitive = [dep[CudaInfo].dlink_rdc_objects for dep in attr.deps if CudaInfo in dep])
dlink_rdc_pic_objects = depset(transitive = [dep[CudaInfo].dlink_rdc_pic_objects for dep in attr.deps if CudaInfo in dep])

# direct outputs
objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = False)) if not use_rdc else depset([])
pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = False)) if not use_rdc else depset([])
rdc_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = True)) if use_rdc else depset([])
rdc_pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = True)) if use_rdc else depset([])

# if rdc is enabled for this `cuda_library`, then we need to do a pass of device link further.
rdc_dlink_inputs = None
rdc_pic_dlink_inputs = None
if use_rdc:
transitive_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep])
transitive_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep])
objects = depset(transitive = [objects, transitive_objects])
rdc_objects = objects
pic_objects = depset(transitive = [pic_objects, transitive_pic_objects])
rdc_pic_objects = pic_objects
dlink_object = depset([device_link(ctx, cuda_toolchain, cc_toolchain, objects, common, pic = False, rdc = use_rdc)])
dlink_pic_object = depset([device_link(ctx, cuda_toolchain, cc_toolchain, pic_objects, common, pic = True, rdc = use_rdc)])
objects = depset(transitive = [objects, dlink_object])
pic_objects = depset(transitive = [pic_objects, dlink_pic_object])
# TODO: Switch to explicit dlink with attr `dlink=True`, then add support dlink with libraries. At the moment,
# all libraries produced by this rule with `rdc=True` will have an <name>_dlink.<infix>.o archived, and nvlink
# refuses to consume such libraries and ignores them silently.

# prepare inputs for device_link, take use_rdc=True and non-pic as an example:
# rdc_objects: produce with this rule
# archive_rdc_objects: propagate from other `cuda_objects`
# dlink_rdc_objects: propagate from other `cuda_library`s
rdc_dlink_inputs = depset(transitive = [rdc_objects, archive_rdc_objects, dlink_rdc_objects])
rdc_pic_dlink_inputs = depset(transitive = [rdc_pic_objects, archive_rdc_pic_objects, dlink_rdc_pic_objects])

rdc_dlink_output = depset([device_link(ctx, cuda_toolchain, cc_toolchain, rdc_dlink_inputs, common, pic = False, rdc = True)])
rdc_pic_dlink_output = depset([device_link(ctx, cuda_toolchain, cc_toolchain, rdc_pic_dlink_inputs, common, pic = True, rdc = True)])

# update the **direct** outputs
rdc_objects = depset(transitive = [rdc_objects, rdc_dlink_output])
rdc_pic_objects = depset(transitive = [rdc_pic_objects, rdc_pic_dlink_output])

# objects to archive: objects directly outputed by this rule and all objects transitively from deps,
# take use_rdc=True and non-pic as an example:
# rdc_objects: produce with this rule, thus it must be archived in the library produced by this rule
# archive_rdc_objects: propagate from other `cuda_objects`, so this rule is in charge of archiving them
# dlink_rdc_objects is NOT included!
if not use_rdc:
archive_content = depset(transitive = [objects, archive_objects])
pic_archive_content = depset(transitive = [pic_objects, archive_pic_objects])
else:
archive_content = depset(transitive = [rdc_objects, archive_rdc_objects])
pic_archive_content = depset(transitive = [rdc_pic_objects, archive_rdc_pic_objects])

compilation_ctx = cc_common.create_compilation_context(
headers = common.headers,
Expand All @@ -67,7 +102,7 @@ def _cuda_library_impl(ctx):
actions = ctx.actions,
feature_configuration = cc_feature_config,
cc_toolchain = cc_toolchain,
compilation_outputs = cc_common.create_compilation_outputs(objects = objects, pic_objects = pic_objects),
compilation_outputs = cc_common.create_compilation_outputs(objects = archive_content, pic_objects = pic_archive_content),
user_link_flags = common.host_link_flags,
alwayslink = attr.alwayslink,
linking_contexts = common.transitive_linking_contexts,
Expand All @@ -82,7 +117,10 @@ def _cuda_library_impl(ctx):
libs = [] if lib == None else [lib]
pic_libs = [] if pic_lib == None else [pic_lib]

cc_info = cc_common.merge_cc_infos(direct_cc_infos = [CcInfo(compilation_context = compilation_ctx, linking_context = linking_ctx)], cc_infos = [common.transitive_cc_info])
cc_info = cc_common.merge_cc_infos(
direct_cc_infos = [CcInfo(compilation_context = compilation_ctx, linking_context = linking_ctx)],
cc_infos = [common.transitive_cc_info],
)

return [
DefaultInfo(files = depset(libs + pic_libs)),
Expand All @@ -100,10 +138,9 @@ def _cuda_library_impl(ctx):
),
cuda_helper.create_cuda_info(
defines = depset(common.defines),
objects = objects,
pic_objects = pic_objects,
rdc_objects = rdc_objects,
rdc_pic_objects = rdc_pic_objects,
# all objects from cuda_objects should be properly archived, thus, the transitivity of archive is cut off here.
dlink_rdc_objects = rdc_dlink_inputs,
dlink_rdc_pic_objects = rdc_pic_dlink_inputs,
),
]

Expand All @@ -118,10 +155,8 @@ cuda_library = rule(
"alwayslink": attr.bool(default = False),
"rdc": attr.bool(
default = False,
doc = ("Whether to produce and consume relocateable device code. " +
"Transitive deps that contain device code must all either be cuda_objects or cuda_library(rdc = True). " +
"If False, all device code must be in the same translation unit. May have performance implications. " +
"See https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#using-separate-compilation-in-cuda."),
doc = ("Whether to perform device linking for relocateable device code. " +
"Transitive deps that contain device code must all either be cuda_objects or cuda_library(rdc = True)."),
),
"includes": attr.string_list(doc = "List of include dirs to be added to the compile line."),
"host_copts": attr.string_list(doc = "Add these options to the CUDA host compilation command."),
Expand Down
35 changes: 25 additions & 10 deletions cuda/private/rules/cuda_objects.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@ def _cuda_objects_impl(ctx):
for src in ctx.attr.srcs:
src_files.extend(src[DefaultInfo].files.to_list())

transitive_objects = depset(transitive = [dep[CudaInfo].objects for dep in attr.deps if CudaInfo in dep])
transitive_rdc_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep])
transitive_pic_objects = depset(transitive = [dep[CudaInfo].pic_objects for dep in attr.deps if CudaInfo in dep])
transitive_rdc_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep])
# merge deps' direct objects and archive objects as our archive objects
archive_objects = depset(transitive = [dep[CudaInfo].objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_objects for dep in attr.deps if CudaInfo in dep])
archive_pic_objects = depset(transitive = [dep[CudaInfo].pic_objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_pic_objects for dep in attr.deps if CudaInfo in dep])
archive_rdc_objects = depset(transitive = [dep[CudaInfo].rdc_objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_rdc_objects for dep in attr.deps if CudaInfo in dep])
archive_rdc_pic_objects = depset(transitive = [dep[CudaInfo].rdc_pic_objects for dep in attr.deps if CudaInfo in dep] +
[dep[CudaInfo].archive_rdc_pic_objects for dep in attr.deps if CudaInfo in dep])

# outputs
objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = False), transitive = [transitive_objects])
rdc_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = True), transitive = [transitive_rdc_objects])
pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = False), transitive = [transitive_pic_objects])
rdc_pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = True), transitive = [transitive_rdc_pic_objects])
# direct outputs
objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = False))
pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = False))
rdc_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = False, rdc = True))
rdc_pic_objects = depset(compile(ctx, cuda_toolchain, cc_toolchain, src_files, common, pic = True, rdc = True))

compilation_ctx = cc_common.create_compilation_context(
headers = common.headers,
Expand All @@ -39,6 +44,11 @@ def _cuda_objects_impl(ctx):
local_defines = depset(common.host_local_defines),
)

cc_info = cc_common.merge_cc_infos(
direct_cc_infos = [CcInfo(compilation_context = compilation_ctx)],
cc_infos = [common.transitive_cc_info],
)

return [
# default output is only enabled for rdc_objects, otherwise, when you build with
#
Expand All @@ -60,14 +70,19 @@ def _cuda_objects_impl(ctx):
rdc_pic_objects = rdc_pic_objects,
),
CcInfo(
compilation_context = compilation_ctx,
compilation_context = cc_info.compilation_context,
linking_context = cc_info.linking_context,
),
cuda_helper.create_cuda_info(
defines = depset(common.defines),
objects = objects,
pic_objects = pic_objects,
rdc_objects = rdc_objects,
rdc_pic_objects = rdc_pic_objects,
archive_objects = archive_objects,
archive_pic_objects = archive_pic_objects,
archive_rdc_objects = archive_rdc_objects,
archive_rdc_pic_objects = archive_rdc_pic_objects,
),
]

Expand Down

0 comments on commit 894603f

Please sign in to comment.