Commit
better fill_missing_cols logic
dombean committed Nov 26, 2024
1 parent c832e42 commit 6084901
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions rdsa_utils/cdp/io/output.py
@@ -5,7 +5,6 @@
from typing import Union

import boto3
-import pyspark.sql.types as T
from pyspark.sql import DataFrame as SparkDF
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
@@ -65,8 +64,13 @@ def insert_df_to_hive_table(
- True: Replaces data only in partitions present in DataFrame.
- False: Appends data to existing partitions or creates new ones.
fill_missing_cols
-If True, adds missing columns as nulls. If False, raises error
-on schema mismatch (default is False).
+If True, adds missing columns as NULL values. If False, raises an error
+on schema mismatch, default is False.
+- Explicitly casts DataFrame columns to match the Hive table schema to
+  avoid type mismatch errors.
+- Adds missing columns as NULL values when `fill_missing_cols` is True,
+  regardless of their data type (e.g., String, Integer, Double, Boolean, etc.).
repartition_data_by
Controls data repartitioning, default is None:
- int: Sets target number of partitions.
@@ -143,6 +147,7 @@ def insert_df_to_hive_table(
# Check if the table exists; if not, set flag for later creation
table_exists = True
try:
+table_schema = spark.read.table(table_name).schema
table_columns = spark.read.table(table_name).columns
except AnalysisException:
logger.info(
@@ -161,7 +166,10 @@ def insert_df_to_hive_table(
if fill_missing_cols and table_exists:
missing_columns = list(set(table_columns) - set(df.columns))
for col in missing_columns:
-df = df.withColumn(col, F.lit(None).cast(T.StringType()))
+column_type = [
+    field.dataType for field in table_schema if field.name == col
+][0]
+df = df.withColumn(col, F.lit(None).cast(column_type))
elif not fill_missing_cols and table_exists:
# Validate schema before writing
if set(table_columns) != set(df.columns):
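
As a rough illustration of the change: before this commit, the NULL fill for every missing column was cast to StringType, whereas the new logic looks each column's type up in the destination table's schema; the updated docstring notes this avoids type mismatch errors. Below is a minimal, self-contained sketch of that type-aware fill, assuming a SparkSession with Hive support; the table name my_db.sales, its columns, and the sample DataFrame are hypothetical stand-ins, not part of rdsa_utils.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Hypothetical session and inputs; any existing Hive table would do.
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
table_name = "my_db.sales"  # assumed columns: id INT, amount DOUBLE, region STRING
df = spark.createDataFrame([(1, 9.99)], ["id", "amount"])  # 'region' is missing

# Read the target table's schema once, then add each missing column as a
# NULL literal cast to the type the table expects (not just StringType).
table_schema = spark.read.table(table_name).schema
missing_columns = set(field.name for field in table_schema) - set(df.columns)
for col in missing_columns:
    column_type = next(field.dataType for field in table_schema if field.name == col)
    df = df.withColumn(col, F.lit(None).cast(column_type))

df.printSchema()  # 'region' now appears as a NULL column typed to match the table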
