Skip to content

Commit

Permalink
Merge pull request spark-redshift-community#58 from eeshugerman/column-list
Browse files Browse the repository at this point in the history

Add 'include_column_list' parameter
  • Loading branch information
lucagiovagnoli authored Dec 16, 2019
2 parents 8d08b6f + 42bf9be commit 2d5c49e
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# spark-redshift Changelog

## 4.1.0

- Add `include_column_list` parameter

## 4.0.2

- Trim SQL text for preactions and postactions, to fix empty SQL queries bug.
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,16 @@ must also set a distribution key with the <tt>distkey</tt> option.
<p>Since setting <tt>usestagingtable=false</tt> operation risks data loss / unavailability, we have chosen to deprecate it in favor of requiring users to manually drop the destination table themselves.</p>
</td>
</tr>
<tr>
<td><tt>include_column_list</tt></td>
<td>No</td>
<td>false</td>
<td>
If <tt>true</tt> then this library will automatically extract the columns from the schema
and add them to the COPY command according to the <a href="http://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-column-mapping.html">Column List docs</a>.
(e.g. <code>COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])</code>).
</td>
</tr>
<tr>
<td><tt>description</tt></td>
<td>No</td>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ private[redshift] object Parameters {
"diststyle" -> "EVEN",
"usestagingtable" -> "true",
"preactions" -> ";",
"postactions" -> ";"
"postactions" -> ";",
"include_column_list" -> "false"
)

val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")
Expand Down Expand Up @@ -285,5 +286,11 @@ private[redshift] object Parameters {
new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
}
}

/**
 * Whether the generated COPY statement should explicitly enumerate the columns
 * being loaded, e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`.
 *
 * Backed by the `include_column_list` parameter (defaults to "false").
 * Note: `toBoolean` throws IllegalArgumentException for values other than
 * "true"/"false" (case-insensitive).
 */
def includeColumnList: Boolean = {
  val rawValue = parameters("include_column_list")
  rawValue.toBoolean
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ private[redshift] class RedshiftWriter(
*/
private def copySql(
sqlContext: SQLContext,
schema: StructType,
params: MergedParameters,
creds: AWSCredentialsProvider,
manifestUrl: String): String = {
Expand All @@ -96,7 +97,13 @@ private[redshift] class RedshiftWriter(
case "AVRO" => "AVRO 'auto'"
case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
}
s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
val columns = if (params.includeColumnList) {
"(" + schema.fieldNames.map(name => s""""$name"""").mkString(",") + ") "
} else {
""
}

s"COPY ${params.table.get} ${columns}FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
s"${format} manifest ${params.extraCopyOptions}"
}

Expand Down Expand Up @@ -138,7 +145,7 @@ private[redshift] class RedshiftWriter(

manifestUrl.foreach { manifestUrl =>
// Load the temporary data into the new file
val copyStatement = copySql(data.sqlContext, params, creds, manifestUrl)
val copyStatement = copySql(data.sqlContext, data.schema, params, creds, manifestUrl)
log.info(copyStatement)
try {
jdbcWrapper.executeInterruptibly(conn.prepareStatement(copyStatement))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class ParametersSuite extends FunSuite with Matchers {
"tempdir" -> "s3://foo/bar",
"dbtable" -> "test_schema.test_table",
"url" -> "jdbc:redshift://foo/bar?user=user&password=password",
"forward_spark_s3_credentials" -> "true")
"forward_spark_s3_credentials" -> "true",
"include_column_list" -> "true")

val mergedParams = Parameters.mergeParameters(params)

Expand All @@ -37,9 +38,14 @@ class ParametersSuite extends FunSuite with Matchers {
mergedParams.jdbcUrl shouldBe params("url")
mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
assert(mergedParams.forwardSparkS3Credentials)
assert(mergedParams.includeColumnList)

// Check that the defaults have been added
(Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials").foreach {
(
Parameters.DEFAULT_PARAMETERS
- "forward_spark_s3_credentials"
- "include_column_list"
).foreach {
case (key, value) => mergedParams.parameters(key) shouldBe value
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,46 @@ class RedshiftSourceSuite
mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
}

test("include_column_list=true adds the schema columns to the COPY query") {
  // With the flag enabled, the COPY must enumerate every column of the test
  // schema, in declaration order, between the table name and FROM.
  val copyWithColumns =
    ("COPY \"PUBLIC\".\"test_table\" \\(\"testbyte\",\"testbool\",\"testdate\"," +
      "\"testdouble\",\"testfloat\",\"testint\",\"testlong\",\"testshort\",\"teststring\"," +
      "\"testtimestamp\"\\) FROM .*").r
  val anticipatedQueries = Seq(
    "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
    copyWithColumns)

  val redshiftMock = new MockRedshift(
    defaultParams("url"),
    Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> TestUtils.testSchema))

  // Write with include_column_list switched on, then verify the issued SQL.
  new DefaultSource(redshiftMock.jdbcWrapper, _ => mockS3Client)
    .createRelation(
      testSqlContext,
      SaveMode.Append,
      defaultParams ++ Map("include_column_list" -> "true"),
      expectedDataDF)

  redshiftMock.verifyThatConnectionsWereClosed()
  redshiftMock.verifyThatExpectedQueriesWereIssued(anticipatedQueries)
}

test("include_column_list=false (default) does not add the schema columns to the COPY query") {
  // With the default parameters (flag unset), COPY must go straight from the
  // table name to FROM, with no parenthesized column list.
  val anticipatedQueries = Seq(
    "CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
    "COPY \"PUBLIC\".\"test_table\" FROM .*".r)

  val redshiftMock = new MockRedshift(
    defaultParams("url"),
    Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> TestUtils.testSchema))

  // Write using defaultParams unchanged, then verify the issued SQL.
  new DefaultSource(redshiftMock.jdbcWrapper, _ => mockS3Client)
    .createRelation(testSqlContext, SaveMode.Append, defaultParams, expectedDataDF)

  redshiftMock.verifyThatConnectionsWereClosed()
  redshiftMock.verifyThatExpectedQueriesWereIssued(anticipatedQueries)
}

test("configuring maxlength on string columns") {
val longStrMetadata = new MetadataBuilder().putLong("maxlength", 512).build()
val shortStrMetadata = new MetadataBuilder().putLong("maxlength", 10).build()
Expand Down Expand Up @@ -594,4 +634,4 @@ class RedshiftSourceSuite
}
assert(e.getMessage.contains("Block FileSystem"))
}
}
}
2 changes: 1 addition & 1 deletion version.sbt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version in ThisBuild := "4.0.2"
version in ThisBuild := "4.1.0"

0 comments on commit 2d5c49e

Please sign in to comment.