From 643015102ff45106740bbf4bc17154ed6f8cfedd Mon Sep 17 00:00:00 2001 From: jsj Date: Sat, 4 Jan 2025 13:10:24 +0100 Subject: [PATCH] add except cols Signed-off-by: jsj --- python/deltalake/table.py | 82 ++++++++++++++++++++++++++++++++++++-- python/tests/test_merge.py | 77 +++++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+), 4 deletions(-) diff --git a/python/deltalake/table.py b/python/deltalake/table.py index e8fc7d866b..d4e4dd192e 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1522,9 +1522,12 @@ def when_matched_update( self._builder.when_matched_update(updates, predicate) return self - def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": + def when_matched_update_all( + self, predicate: Optional[str] = None, except_cols: Optional[List[str]] = None + ) -> "TableMerger": """Updating all source fields to target fields, source and target are required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + If ``except_cols`` is specified, then the columns in the exclude list will not be updated. Note: Column names with special characters, such as numbers or spaces should be encapsulated @@ -1532,11 +1535,13 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg Args: predicate: SQL like predicate on when to update all columns. - + except_cols: List of columns to exclude from update. Returns: TableMerger: TableMerger Object Example: + ** Update all columns ** + ```python from deltalake import DeltaTable, write_deltalake import pyarrow as pa @@ -1563,6 +1568,35 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg 1 2 5 2 3 6 ``` + + ** Update all columns except `bar` ** + + ```python + from deltalake import DeltaTable, write_deltalake + import pyarrow as pa + + data = pa.table({"foo": [1, 2, 3], "bar": [4, 5, 6]}) + write_deltalake("tmp", data) + dt = DeltaTable("tmp") + new_data = pa.table({"foo": [1], "bar": [7]}) + + ( + dt.merge( + source=new_data, + predicate="target.foo = source.foo", + source_alias="source", + target_alias="target") + .when_matched_update_all(except_cols=["bar"]) + .execute() + ) + {'num_source_rows': 1, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 1, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 2, 'num_output_rows': 3, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...} + + dt.to_pandas() + foo bar + 0 1 4 + 1 2 5 + 2 3 6 + ``` """ maybe_source_alias = self._builder.source_alias maybe_target_alias = self._builder.target_alias @@ -1572,9 +1606,12 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg (maybe_target_alias + ".") if maybe_target_alias is not None else "" ) + except_columns = except_cols or [] + updates = { f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" for col in self._builder.arrow_schema + if col.name not in except_columns } self._builder.when_matched_update(updates, predicate) @@ -1700,11 +1737,11 @@ def when_not_matched_insert( return self def when_not_matched_insert_all( - self, predicate: Optional[str] = None + self, predicate: Optional[str] = None, except_cols: Optional[List[str]] = None ) -> "TableMerger": """Insert a new row to the target table, updating all source fields to target fields. Source and target are required to have the same field names. 
If a ``predicate`` is specified, then it must evaluate to true for - the new row to be inserted. + the new row to be inserted. If ``except_cols`` is specified, then the columns in the exclude list will not be inserted. Note: Column names with special characters, such as numbers or spaces should be encapsulated @@ -1712,11 +1749,14 @@ def when_not_matched_insert_all( Args: predicate: SQL like predicate on when to insert. + except_cols: List of columns to exclude from insert. Returns: TableMerger: TableMerger Object Example: + ** Insert all columns ** + ```python from deltalake import DeltaTable, write_deltalake import pyarrow as pa @@ -1744,6 +1784,36 @@ def when_not_matched_insert_all( 2 3 6 3 4 7 ``` + + ** Insert all columns except `bar` ** + + ```python + from deltalake import DeltaTable, write_deltalake + import pyarrow as pa + + data = pa.table({"foo": [1, 2, 3], "bar": [4, 5, 6]}) + write_deltalake("tmp", data) + dt = DeltaTable("tmp") + new_data = pa.table({"foo": [4], "bar": [7]}) + + ( + dt.merge( + source=new_data, + predicate='target.foo = source.foo', + source_alias='source', + target_alias='target') + .when_not_matched_insert_all(except_cols=["bar"]) + .execute() + ) + {'num_source_rows': 1, 'num_target_rows_inserted': 1, 'num_target_rows_updated': 0, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 3, 'num_output_rows': 4, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...} + + dt.to_pandas().sort_values("foo", ignore_index=True) + foo bar + 0 1 4 + 1 2 5 + 2 3 6 + 3 4 NaN + ``` """ maybe_source_alias = self._builder.source_alias maybe_target_alias = self._builder.target_alias @@ -1752,9 +1822,13 @@ def when_not_matched_insert_all( trgt_alias = ( (maybe_target_alias + ".") if maybe_target_alias is not None else "" ) + + except_columns = except_cols or [] + updates = { f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" for col in self._builder.arrow_schema + if col.name not in except_columns } self._builder.when_not_matched_insert(updates, predicate) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index b90eecae88..b9ac935753 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -168,6 +168,45 @@ def test_merge_when_matched_update_all_wo_predicate( assert result == expected +def test_merge_when_matched_update_all_with_exclude( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([15, 25], pa.int32()), + "deleted": pa.array([True, True]), + "weight": pa.array([10, 15], pa.int64()), + } + ) + + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_matched_update_all(except_cols=["sold"]).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4], pa.int32()), + "deleted": pa.array([False, False, False, True, True]), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + def test_merge_when_matched_update_with_predicate( tmp_path: pathlib.Path, sample_table: pa.Table ): @@ -340,6 +379,44 @@ def test_merge_when_not_matched_insert_all_with_predicate( 
assert result == expected +def test_merge_when_not_matched_insert_all_with_exclude( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "9"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([None, None], pa.bool_()), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_insert_all(except_cols=["sold"]).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "6", "9"]), + "price": pa.array([0, 1, 2, 3, 4, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, None, None], pa.int32()), + "deleted": pa.array([False, False, False, False, False, None, None]), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + def test_merge_when_not_matched_insert_all_with_predicate_special_column_names( tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table ):
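
For reviewers who want to exercise the new keyword end to end, the sketch below chains both clauses in a single merge so that a matched row keeps its existing `bar` value and an unmatched row is inserted with `bar` left null. This is a minimal usage sketch, not part of the patch: it assumes the changes above are applied, and the table path `tmp_table` plus the `foo`/`bar` column names are illustrative only.

```python
# Minimal usage sketch (assumes this patch is applied).
# The table path and column names below are illustrative only.
import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

data = pa.table({"foo": [1, 2, 3], "bar": [4, 5, 6]})
write_deltalake("tmp_table", data)
dt = DeltaTable("tmp_table")

# Row foo=1 matches an existing target row; row foo=4 does not.
new_data = pa.table({"foo": [1, 4], "bar": [70, 70]})

(
    dt.merge(
        source=new_data,
        predicate="target.foo = source.foo",
        source_alias="source",
        target_alias="target",
    )
    # Matched row (foo=1) is updated, but `bar` is excluded, so it keeps 4.
    .when_matched_update_all(except_cols=["bar"])
    # Unmatched row (foo=4) is inserted, but `bar` is excluded, so it is null.
    .when_not_matched_insert_all(except_cols=["bar"])
    .execute()
)

print(dt.to_pandas().sort_values("foo", ignore_index=True))
# Expected result (bar comes back as a float column because of the inserted null):
#    foo  bar
# 0    1  4.0
# 1    2  5.0
# 2    3  6.0
# 3    4  NaN
```

Note on the approach taken by the patch: the excluded columns are simply dropped from the generated column mapping (`updates`) rather than being assigned explicitly, so on update they retain their current target values and on insert they fall back to null.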