diff --git a/CHANGELOG.md b/CHANGELOG.md
index bade680..bc1911c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 ### Changed
 
+- Modified `insert_df_to_hive_table` function in `cdp/io/output.py`. Added support
+  for creating non-existent Hive tables, repartitioning by column or partition count,
+  and handling missing columns with explicit type casting.
 
 ### Deprecated
diff --git a/rdsa_utils/cdp/io/output.py b/rdsa_utils/cdp/io/output.py
index 3f7231a..98e5d35 100644
--- a/rdsa_utils/cdp/io/output.py
+++ b/rdsa_utils/cdp/io/output.py
@@ -5,6 +5,7 @@ from typing import Union
 
 import boto3
+import pyspark.sql.types as T
 from pyspark.sql import DataFrame as SparkDF
 from pyspark.sql import SparkSession
 from pyspark.sql import functions as F
@@ -36,26 +37,52 @@ def insert_df_to_hive_table(
     table_name: str,
     overwrite: bool = False,
     fill_missing_cols: bool = False,
+    repartition_column: Union[int, str, None] = None,
 ) -> None:
-    """Write the SparkDF contents to a Hive table.
+    """Write SparkDF to Hive table with optional configuration.
 
-    This function writes data from a SparkDF into a Hive table, allowing
-    optional handling of missing columns. The table's column order is ensured to
-    match that of the DataFrame.
+    This function writes data from a SparkDF into a Hive table, handling missing
+    columns and optional repartitioning. It ensures the table's column order matches
+    the DataFrame and manages different overwrite behaviors for partitioned and
+    non-partitioned data.
 
     Parameters
     ----------
-    spark
+    spark : SparkSession
         Active SparkSession.
-    df
+    df : SparkDF
         SparkDF containing data to be written.
-    table_name
+    table_name : str
         Name of the Hive table to write data into.
-    overwrite
-        If True, existing data in the table will be overwritten,
-        by default False.
-    fill_missing_cols
-        If True, missing columns will be filled with nulls, by default False.
+    overwrite : bool, optional
+        Controls how existing data is handled, default is False:
+
+        For non-partitioned data:
+        - True: Replaces entire table with DataFrame data
+        - False: Appends DataFrame data to existing table
+
+        For partitioned data:
+        - True: Replaces data only in partitions present in DataFrame
+        - False: Appends data to existing partitions or creates new ones
+    fill_missing_cols : bool, optional
+        If True, adds missing columns as nulls. If False, raises error
+        on schema mismatch (default is False).
+    repartition_column : Union[int, str, None], optional
+        Controls data repartitioning, default is None:
+        - int: Sets target number of partitions
+        - str: Specifies column to repartition by
+        - None: No repartitioning performed
+
+    Notes
+    -----
+    When using repartition with a number:
+    - Affects physical file structure but preserves Hive partitioning scheme.
+    - Controls number of output files per write operation per Hive partition.
+    - Maintains partition-based query optimization.
+
+    When repartitioning by column:
+    - Helps balance file sizes across Hive partitions.
+    - Reduces creation of small files.
 
     Raises
     ------
@@ -65,36 +92,36 @@ def insert_df_to_hive_table(
     ValueError
         If the SparkDF schema does not match the Hive table schema and
         'fill_missing_cols' is set to False.
+    DataframeEmptyError
+        If input DataFrame is empty.
     Exception
         For other general exceptions when writing data to the table.
""" - logger.info(f"Preparing to write data to {table_name}.") - - # Validate SparkDF before writing - if is_df_empty(df): - msg = f"Cannot write an empty SparkDF to {table_name}" - raise DataframeEmptyError( - msg, - ) + logger.info(f"Preparing to write data to {table_name} with overwrite={overwrite}.") + # Check if the table exists; if not, set flag for later creation + table_exists = True try: table_columns = spark.read.table(table_name).columns except AnalysisException: - logger.error( - ( - f"Error reading table {table_name}. " - f"Make sure the table exists and you have access to it." - ), + logger.info( + f"Table {table_name} does not exist and will be " + "created after transformations.", ) + table_exists = False + table_columns = df.columns # Use DataFrame columns as initial schema - raise + # Validate SparkDF before writing + if is_df_empty(df): + msg = f"Cannot write an empty SparkDF to {table_name}" + raise DataframeEmptyError(msg) - if fill_missing_cols: + # Handle missing columns if specified + if fill_missing_cols and table_exists: missing_columns = list(set(table_columns) - set(df.columns)) - for col in missing_columns: - df = df.withColumn(col, F.lit(None)) - else: + df = df.withColumn(col, F.lit(None).cast(T.StringType())) + elif not fill_missing_cols and table_exists: # Validate schema before writing if set(table_columns) != set(df.columns): msg = ( @@ -103,10 +130,32 @@ def insert_df_to_hive_table( ) raise ValueError(msg) - df = df.select(table_columns) + # Ensure column order + df = df.select(table_columns) if table_exists else df + + # Apply repartitioning if specified + if repartition_column is not None: + if isinstance(repartition_column, int): + logger.info(f"Repartitioning data into {repartition_column} partitions.") + df = df.repartition(repartition_column) + elif isinstance(repartition_column, str): + logger.info(f"Repartitioning data by column {repartition_column}.") + df = df.repartition(repartition_column) + # Write DataFrame to Hive table based on existence and overwrite parameter try: - df.write.insertInto(table_name, overwrite) + if table_exists: + if overwrite: + logger.info(f"Overwriting existing table {table_name}.") + df.write.mode("overwrite").saveAsTable(table_name) + else: + logger.info( + f"Inserting into existing table {table_name} without overwrite.", + ) + df.write.insertInto(table_name) + else: + df.write.saveAsTable(table_name) + logger.info(f"Table {table_name} created successfully.") logger.info(f"Successfully wrote data to {table_name}.") except Exception: logger.error(f"Error writing data to {table_name}.") diff --git a/tests/cdp/io/test_cdsw_output.py b/tests/cdp/io/test_cdsw_output.py index 0175d31..9c9725c 100644 --- a/tests/cdp/io/test_cdsw_output.py +++ b/tests/cdp/io/test_cdsw_output.py @@ -106,22 +106,102 @@ def test_insert_df_to_hive_table_without_missing_columns( fill_missing_cols=False, ) + @patch("pyspark.sql.DataFrameWriter.saveAsTable") @patch("pyspark.sql.DataFrameReader.table") - def test_insert_df_to_hive_table_with_non_existing_table( + def test_insert_df_to_hive_table_creates_non_existing_table( self, mock_table, + mock_save_as_table, spark_session: SparkSession, test_df: SparkDF, ) -> None: - """Test that insert_df_to_hive_table raises an AnalysisException when - the table doesn't exist. 
- """ - table_name = "non_existing_table" - # Create an AnalysisException with a stack trace - exc = AnalysisException(f"Table {table_name} not found.") - mock_table.side_effect = exc - with pytest.raises(AnalysisException): - insert_df_to_hive_table(spark_session, test_df, table_name) + """Test that the function creates the table if it does not exist.""" + table_name = "new_table" + # Simulate non-existing table by raising AnalysisException + mock_table.side_effect = AnalysisException(f"Table {table_name} not found.") + mock_save_as_table.return_value = None + insert_df_to_hive_table( + spark_session, + test_df, + table_name, + overwrite=True, + fill_missing_cols=True, + ) + # Assert that saveAsTable was called + mock_save_as_table.assert_called_with(table_name) + + @patch("pyspark.sql.DataFrameWriter.insertInto") + @patch("pyspark.sql.DataFrameReader.table") + def test_insert_df_to_hive_table_with_repartition_column( + self, + mock_table, + mock_insert_into, + spark_session: SparkSession, + test_df: SparkDF, + ) -> None: + """Test that the DataFrame is repartitioned by a specified column.""" + table_name = "test_table" + mock_table.return_value.columns = ["id", "name", "age"] + mock_insert_into.return_value = None + with patch.object(test_df, "repartition") as mock_repartition: + mock_repartition.return_value = test_df + insert_df_to_hive_table( + spark_session, + test_df, + table_name, + repartition_column="id", + overwrite=True, + ) + # Assert repartition was called with the column name + mock_repartition.assert_called_with("id") + # Assert insertInto was called + mock_insert_into.assert_called_with(table_name, True) + + @patch("pyspark.sql.DataFrameWriter.insertInto") + @patch("pyspark.sql.DataFrameReader.table") + def test_insert_df_to_hive_table_with_repartition_num_partitions( + self, + mock_table, + mock_insert_into, + spark_session: SparkSession, + test_df: SparkDF, + ) -> None: + """Test that the DataFrame is repartitioned into a specific number of partitions.""" + table_name = "test_table" + mock_table.return_value.columns = ["id", "name", "age"] + mock_insert_into.return_value = None + with patch.object(test_df, "repartition") as mock_repartition: + mock_repartition.return_value = test_df + insert_df_to_hive_table( + spark_session, + test_df, + table_name, + repartition_column=5, + overwrite=True, + ) + # Assert repartition was called with the number of partitions + mock_repartition.assert_called_with(5) + # Assert insertInto was called + mock_insert_into.assert_called_with(table_name, True) + + def test_insert_df_to_hive_table_with_empty_dataframe( + self, + spark_session: SparkSession, + ) -> None: + """Test that an empty DataFrame raises DataframeEmptyError.""" + from rdsa_utils.exceptions import DataframeEmptyError + + table_name = "test_table" + empty_df = spark_session.createDataFrame( + [], + schema="id INT, name STRING, age INT", + ) + with pytest.raises(DataframeEmptyError): + insert_df_to_hive_table( + spark_session, + empty_df, + table_name, + ) class TestWriteAndReadHiveTable: