diff --git a/python/cudf_polars/cudf_polars/experimental/join.py b/python/cudf_polars/cudf_polars/experimental/join.py index 265cdf1b098..3754c295b01 100644 --- a/python/cudf_polars/cudf_polars/experimental/join.py +++ b/python/cudf_polars/cudf_polars/experimental/join.py @@ -74,9 +74,10 @@ def _make_hash_join( # Record new partitioning info partitioned_on: tuple[NamedExpr, ...] = () - if ir.left_on == ir.right_on or (ir.options[0] in ("left", "semi", "anti")): + how = ir.options[0].lower() + if ir.left_on == ir.right_on or (how in ("left", "semi", "anti")): partitioned_on = ir.left_on - elif ir.options[0] == "right": # pragma: no cover + elif how == "right": # pragma: no cover partitioned_on = ir.right_on partition_info[new_node] = PartitionInfo( count=output_count, @@ -94,7 +95,6 @@ def _should_bcast_join( output_count: int, ) -> bool: # Decide if a broadcast join is appropriate. - if partition_info[left].count >= partition_info[right].count: bcast_count = partition_info[right].count other = left @@ -116,13 +116,14 @@ def _should_bcast_join( # TODO: Make this value/heuristic configurable). # We may want to account for the number of workers. # 3. The "kind" of join is compatible with a broadcast join + how = ir.options[0].lower() return ( not other_shuffled and bcast_count <= 8 # TODO: Make this configurable and ( - ir.options[0] == "inner" - or (ir.options[0] in ("left", "semi", "anti") and other == left) - or (ir.options[0] == "right" and other == right) + how == "inner" + or (how in ("left", "semi", "anti") and other == left) + or (how == "right" and other == right) ) ) @@ -134,7 +135,7 @@ def _make_bcast_join( left: IR, right: IR, ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: - how = ir.options[0] + how = ir.options[0].lower() if how != "inner": shuffle_options: dict[str, Any] = {} left_count = partition_info[left].count @@ -175,7 +176,7 @@ def _( new_node = ir.reconstruct(children) partition_info[new_node] = PartitionInfo(count=1) return new_node, partition_info - elif ir.options[0] == "cross": + elif ir.options[0].lower() == "cross": raise NotImplementedError( "cross join not support for multiple partitions." ) # pragma: no cover @@ -231,7 +232,7 @@ def _( } else: # Broadcast join - how = ir.options[0] + how = ir.options[0].lower() left_parts = partition_info[left] right_parts = partition_info[right] if left_parts.count >= right_parts.count: