add features to insert_df_to_hive_table
dombean committed Nov 15, 2024
1 parent 780233d commit dc7edab
Showing 3 changed files with 174 additions and 42 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ and this project adheres to [semantic versioning](https://semver.org/spec/v2.0.0
### Added

### Changed
- Modified `insert_df_to_hive_table` function in `cdp/io/output.py`. Added support
for creating non-existent Hive tables, repartitioning by column or partition count,
and handling missing columns with explicit type casting.

### Deprecated

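As a quick illustration of the changelog entry above, the snippet below is a minimal, hypothetical usage sketch of the updated function. The SparkSession setup, table name, and data are invented for illustration; it assumes a Hive-enabled Spark environment with `rdsa_utils` installed.

```python
from pyspark.sql import SparkSession

from rdsa_utils.cdp.io.output import insert_df_to_hive_table

# Hypothetical session and data; all names are illustrative only.
spark = SparkSession.builder.appName("example").enableHiveSupport().getOrCreate()
sales_df = spark.createDataFrame(
    [(1, "north", 10.0), (2, "south", 12.5)],
    schema="id INT, region STRING, amount DOUBLE",
)

insert_df_to_hive_table(
    spark,
    sales_df,
    "analytics.sales",            # created if it does not already exist
    overwrite=True,               # replace existing data rather than append
    fill_missing_cols=True,       # add table columns missing from sales_df as nulls
    repartition_column="region",  # a column name; an int would set a partition count instead
)
```
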
113 changes: 81 additions & 32 deletions rdsa_utils/cdp/io/output.py
@@ -5,6 +5,7 @@
from typing import Union

import boto3
import pyspark.sql.types as T
from pyspark.sql import DataFrame as SparkDF
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
@@ -36,26 +37,52 @@ def insert_df_to_hive_table(
table_name: str,
overwrite: bool = False,
fill_missing_cols: bool = False,
repartition_column: Union[int, str, None] = None,
) -> None:
"""Write the SparkDF contents to a Hive table.
"""Write SparkDF to Hive table with optional configuration.
This function writes data from a SparkDF into a Hive table, allowing
optional handling of missing columns. The table's column order is ensured to
match that of the DataFrame.
This function writes data from a SparkDF into a Hive table, handling missing
columns and optional repartitioning. It aligns the DataFrame's column order with
the table's schema and manages different overwrite behaviors for partitioned and
non-partitioned data.
Parameters
----------
spark
spark : SparkSession
Active SparkSession.
df
df : SparkDF
SparkDF containing data to be written.
table_name
table_name : str
Name of the Hive table to write data into.
overwrite
If True, existing data in the table will be overwritten,
by default False.
fill_missing_cols
If True, missing columns will be filled with nulls, by default False.
overwrite : bool, optional
Controls how existing data is handled, default is False:
For non-partitioned data:
- True: Replaces entire table with DataFrame data
- False: Appends DataFrame data to existing table
For partitioned data:
- True: Replaces data only in partitions present in DataFrame
- False: Appends data to existing partitions or creates new ones
fill_missing_cols : bool, optional
If True, adds missing columns as nulls. If False, raises error
on schema mismatch (default is False).
repartition_column : Union[int, str, None], optional
Controls data repartitioning, default is None:
- int: Sets target number of partitions
- str: Specifies column to repartition by
- None: No repartitioning performed
Notes
-----
When using repartition with a number:
- Affects physical file structure but preserves Hive partitioning scheme.
- Controls number of output files per write operation per Hive partition.
- Maintains partition-based query optimization.
When repartitioning by column:
- Helps balance file sizes across Hive partitions.
- Reduces creation of small files.
Raises
------
@@ -65,36 +92,36 @@
ValueError
If the SparkDF schema does not match the Hive table schema and
'fill_missing_cols' is set to False.
DataframeEmptyError
If input DataFrame is empty.
Exception
For other general exceptions when writing data to the table.
"""
logger.info(f"Preparing to write data to {table_name}.")

# Validate SparkDF before writing
if is_df_empty(df):
msg = f"Cannot write an empty SparkDF to {table_name}"
raise DataframeEmptyError(
msg,
)
logger.info(f"Preparing to write data to {table_name} with overwrite={overwrite}.")

# Check if the table exists; if not, set flag for later creation
table_exists = True
try:
table_columns = spark.read.table(table_name).columns
except AnalysisException:
logger.error(
(
f"Error reading table {table_name}. "
f"Make sure the table exists and you have access to it."
),
logger.info(
f"Table {table_name} does not exist and will be "
"created after transformations.",
)
table_exists = False
table_columns = df.columns # Use DataFrame columns as initial schema

raise
# Validate SparkDF before writing
if is_df_empty(df):
msg = f"Cannot write an empty SparkDF to {table_name}"
raise DataframeEmptyError(msg)

if fill_missing_cols:
# Handle missing columns if specified
if fill_missing_cols and table_exists:
missing_columns = list(set(table_columns) - set(df.columns))

for col in missing_columns:
df = df.withColumn(col, F.lit(None))
else:
df = df.withColumn(col, F.lit(None).cast(T.StringType()))
elif not fill_missing_cols and table_exists:
# Validate schema before writing
if set(table_columns) != set(df.columns):
msg = (
@@ -103,10 +130,32 @@
)
raise ValueError(msg)

df = df.select(table_columns)
# Ensure column order
df = df.select(table_columns) if table_exists else df

# Apply repartitioning if specified
if repartition_column is not None:
if isinstance(repartition_column, int):
logger.info(f"Repartitioning data into {repartition_column} partitions.")
df = df.repartition(repartition_column)
elif isinstance(repartition_column, str):
logger.info(f"Repartitioning data by column {repartition_column}.")
df = df.repartition(repartition_column)

# Write DataFrame to Hive table based on existence and overwrite parameter
try:
df.write.insertInto(table_name, overwrite)
if table_exists:
if overwrite:
logger.info(f"Overwriting existing table {table_name}.")
df.write.mode("overwrite").saveAsTable(table_name)
else:
logger.info(
f"Inserting into existing table {table_name} without overwrite.",
)
df.write.insertInto(table_name)
else:
df.write.saveAsTable(table_name)
logger.info(f"Table {table_name} created successfully.")
logger.info(f"Successfully wrote data to {table_name}.")
except Exception:
logger.error(f"Error writing data to {table_name}.")
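
The Notes in the new docstring distinguish repartitioning by partition count from repartitioning by column. The standalone sketch below illustrates the underlying PySpark behaviour the function relies on; it is not part of the commit, and the DataFrame and column names are made up.

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("repartition-demo").getOrCreate()

# 1,000 rows with a low-cardinality grouping column.
df = spark.range(1_000).withColumn("region", (F.col("id") % 4).cast("string"))

# Repartition by number: fixes how many output files each write produces.
by_count = df.repartition(5)
print(by_count.rdd.getNumPartitions())   # 5

# Repartition by column: co-locates rows sharing a value, which balances file
# sizes and avoids many small files; the partition count typically falls back
# to spark.sql.shuffle.partitions (200 by default).
by_column = df.repartition("region")
print(by_column.rdd.getNumPartitions())
```

Repartitioning by a low-cardinality column before the write typically yields a few larger files per Hive partition rather than many small ones, which is the behaviour the Notes section describes.
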
100 changes: 90 additions & 10 deletions tests/cdp/io/test_cdsw_output.py
@@ -106,22 +106,102 @@ def test_insert_df_to_hive_table_without_missing_columns(
fill_missing_cols=False,
)

@patch("pyspark.sql.DataFrameWriter.saveAsTable")
@patch("pyspark.sql.DataFrameReader.table")
def test_insert_df_to_hive_table_with_non_existing_table(
def test_insert_df_to_hive_table_creates_non_existing_table(
self,
mock_table,
mock_save_as_table,
spark_session: SparkSession,
test_df: SparkDF,
) -> None:
"""Test that insert_df_to_hive_table raises an AnalysisException when
the table doesn't exist.
"""
table_name = "non_existing_table"
# Create an AnalysisException with a stack trace
exc = AnalysisException(f"Table {table_name} not found.")
mock_table.side_effect = exc
with pytest.raises(AnalysisException):
insert_df_to_hive_table(spark_session, test_df, table_name)
"""Test that the function creates the table if it does not exist."""
table_name = "new_table"
# Simulate non-existing table by raising AnalysisException
mock_table.side_effect = AnalysisException(f"Table {table_name} not found.")
mock_save_as_table.return_value = None
insert_df_to_hive_table(
spark_session,
test_df,
table_name,
overwrite=True,
fill_missing_cols=True,
)
# Assert that saveAsTable was called
mock_save_as_table.assert_called_with(table_name)

@patch("pyspark.sql.DataFrameWriter.insertInto")
@patch("pyspark.sql.DataFrameReader.table")
def test_insert_df_to_hive_table_with_repartition_column(
self,
mock_table,
mock_insert_into,
spark_session: SparkSession,
test_df: SparkDF,
) -> None:
"""Test that the DataFrame is repartitioned by a specified column."""
table_name = "test_table"
mock_table.return_value.columns = ["id", "name", "age"]
mock_insert_into.return_value = None
with patch.object(test_df, "repartition") as mock_repartition:
mock_repartition.return_value = test_df
insert_df_to_hive_table(
spark_session,
test_df,
table_name,
repartition_column="id",
overwrite=True,
)
# Assert repartition was called with the column name
mock_repartition.assert_called_with("id")
# Assert insertInto was called
mock_insert_into.assert_called_with(table_name, True)

@patch("pyspark.sql.DataFrameWriter.insertInto")
@patch("pyspark.sql.DataFrameReader.table")
def test_insert_df_to_hive_table_with_repartition_num_partitions(
self,
mock_table,
mock_insert_into,
spark_session: SparkSession,
test_df: SparkDF,
) -> None:
"""Test that the DataFrame is repartitioned into a specific number of partitions."""
table_name = "test_table"
mock_table.return_value.columns = ["id", "name", "age"]
mock_insert_into.return_value = None
with patch.object(test_df, "repartition") as mock_repartition:
mock_repartition.return_value = test_df
insert_df_to_hive_table(
spark_session,
test_df,
table_name,
repartition_column=5,
overwrite=True,
)
# Assert repartition was called with the number of partitions
mock_repartition.assert_called_with(5)
# Assert insertInto was called
mock_insert_into.assert_called_with(table_name, True)

def test_insert_df_to_hive_table_with_empty_dataframe(
self,
spark_session: SparkSession,
) -> None:
"""Test that an empty DataFrame raises DataframeEmptyError."""
from rdsa_utils.exceptions import DataframeEmptyError

table_name = "test_table"
empty_df = spark_session.createDataFrame(
[],
schema="id INT, name STRING, age INT",
)
with pytest.raises(DataframeEmptyError):
insert_df_to_hive_table(
spark_session,
empty_df,
table_name,
)
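
A hypothetical companion test for the new fill-and-cast path is sketched below in the same mock-based style as the tests above. It is not part of this commit; it would slot into the test class above, relying on the module's existing imports plus the `spark_session` and `test_df` fixtures (with `test_df` holding only `id`, `name`, and `age`), and it patches both write paths so no Hive metastore is required.

```python
@patch("pyspark.sql.DataFrameWriter.saveAsTable")
@patch("pyspark.sql.DataFrameWriter.insertInto")
@patch("pyspark.sql.DataFrameReader.table")
def test_insert_df_to_hive_table_fills_missing_columns(
    self,
    mock_table,
    mock_insert_into,
    mock_save_as_table,
    spark_session: SparkSession,
    test_df: SparkDF,
) -> None:
    """Test that a column present in the table but absent from the
    DataFrame is tolerated when fill_missing_cols=True."""
    table_name = "test_table"
    # The existing table has one more column ('email') than test_df.
    mock_table.return_value.columns = ["id", "name", "age", "email"]
    mock_insert_into.return_value = None
    mock_save_as_table.return_value = None
    # Should complete without raising a schema-mismatch ValueError.
    insert_df_to_hive_table(
        spark_session,
        test_df,
        table_name,
        fill_missing_cols=True,
    )
    # One of the two write paths must have been used.
    assert mock_insert_into.called or mock_save_as_table.called
```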


class TestWriteAndReadHiveTable:
