Skip to content

Commit

Permalink
fix lower-case bug
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Jan 29, 2025
1 parent ff1e08f commit 906c3c9
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions python/cudf_polars/cudf_polars/experimental/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ def _make_hash_join(

# Record new partitioning info
partitioned_on: tuple[NamedExpr, ...] = ()
if ir.left_on == ir.right_on or (ir.options[0] in ("left", "semi", "anti")):
how = ir.options[0].lower()
if ir.left_on == ir.right_on or (how in ("left", "semi", "anti")):
partitioned_on = ir.left_on
elif ir.options[0] == "right": # pragma: no cover
elif how == "right": # pragma: no cover
partitioned_on = ir.right_on
partition_info[new_node] = PartitionInfo(
count=output_count,
Expand All @@ -94,7 +95,6 @@ def _should_bcast_join(
output_count: int,
) -> bool:
# Decide if a broadcast join is appropriate.

if partition_info[left].count >= partition_info[right].count:
bcast_count = partition_info[right].count
other = left
Expand All @@ -116,13 +116,14 @@ def _should_bcast_join(
# TODO: Make this value/heuristic configurable).
# We may want to account for the number of workers.
# 3. The "kind" of join is compatible with a broadcast join
how = ir.options[0].lower()
return (
not other_shuffled
and bcast_count <= 8 # TODO: Make this configurable
and (
ir.options[0] == "inner"
or (ir.options[0] in ("left", "semi", "anti") and other == left)
or (ir.options[0] == "right" and other == right)
how == "inner"
or (how in ("left", "semi", "anti") and other == left)
or (how == "right" and other == right)
)
)

Expand All @@ -134,7 +135,7 @@ def _make_bcast_join(
left: IR,
right: IR,
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
how = ir.options[0]
how = ir.options[0].lower()
if how != "inner":
shuffle_options: dict[str, Any] = {}
left_count = partition_info[left].count
Expand Down Expand Up @@ -175,7 +176,7 @@ def _(
new_node = ir.reconstruct(children)
partition_info[new_node] = PartitionInfo(count=1)
return new_node, partition_info
elif ir.options[0] == "cross":
elif ir.options[0].lower() == "cross":
raise NotImplementedError(
"cross join not support for multiple partitions."
) # pragma: no cover
Expand Down Expand Up @@ -231,7 +232,7 @@ def _(
}
else:
# Broadcast join
how = ir.options[0]
how = ir.options[0].lower()
left_parts = partition_info[left]
right_parts = partition_info[right]
if left_parts.count >= right_parts.count:
Expand Down

0 comments on commit 906c3c9

Please sign in to comment.