JdbcSource to partition queries for potential performance improvements #236

Open
hannesmiller opened this issue Jan 30, 2017 · 14 comments

@hannesmiller
Contributor

hannesmiller commented Jan 30, 2017

  • Spark can partition data on a JDBC data frame by specifying the following binding parameters, which are all longs: lowerBound, upperBound and numPartitions, plus a partition key column
  • To take advantage of this in JdbcSource, the data has to be divided into multiple partitions (read in multiple threads). These binding parameters are then used to alter the original query for each partition, e.g. by adding a predicate that restricts the rows each partition reads. For Oracle the partition key could be rownum, which is available on every table. For example, for a population of 100 rows with lowerBound=1, upperBound=100, numPartitions=5 and query=select col1, col2 from table where blah, the following queries would be run, one per partition:

Part 1:

select * from (select col1, col2 from table where blah)
where rownum between 1 and 20

Part 2:

select * from (select col1, col2 from table where blah)
where rownum between 21 and 40

Part 3:

select * from (select col1, col2 from table where blah)
where rownum between 41 and 60

Part 4:

select * from (select col1, col2 from table where blah)
where rownum between 61 and 80

Part 5:

select * from (select col1, col2 from table where blah)
where rownum >= 81
  • For partition 5 just return the remainder of rows.

  • Note it may not be necessary to create N connections for each partition - simply return N JDBC result sets - one for each partition - investigating this...
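The range split described above can be sketched as a small helper. This is only an illustration: the object and method names (RangePredicates, predicates) are hypothetical and not part of EEL, and it assumes the key range divides into at least one value per partition.

```scala
object RangePredicates {
  // Hypothetical helper (not an EEL API): splits [lowerBound, upperBound] into
  // numPartitions contiguous ranges and renders one SQL predicate per partition.
  def predicates(column: String, lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[String] = {
    require(numPartitions > 0, "numPartitions must be positive")
    require(upperBound - lowerBound + 1 >= numPartitions, "range is smaller than the partition count")
    val size = (upperBound - lowerBound + 1) / numPartitions
    (0 until numPartitions).map { i =>
      val from = lowerBound + i * size
      if (i == numPartitions - 1) s"$column >= $from" // last partition takes the remainder
      else s"$column between $from and ${from + size - 1}"
    }
  }
}
```

Calling RangePredicates.predicates("rownum", 1, 100, 5) yields the five predicates listed above, ending with rownum >= 81.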

Proposal

  • withPartition(lowerBound, upperBound, partitionColumn)
.withPartition(1, 100, rownum)
  • rownum is the partition key, which is internal to Oracle. However, you can't rely on it across the board with a function like withPartition(1,100): SQLServer, for example, doesn't have rownum, although the same effect can be achieved using the windowing function ROW_NUMBER():
where RowNumber = ROW_NUMBER() OVER (ORDER BY CustomerID ASC)
  • where CustomerID is the primary key and ROW_NUMBER() is the SQLServer windowing function.
  • One idea is to provide an upper bound function like so:
withPartition(lowerBound:Long, upperBoundFn(query:String) => Long, partitionColumn:String)
  • the upperBoundFn() could do anything you like such as execute a separate query or just supply the count.
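One possible shape for the upper bound function idea, as a sketch only. PartitionSpec, resolveBounds and countSql are hypothetical names invented for illustration; the proposal does not fix a final signature.

```scala
object UpperBoundSketch {
  // Hypothetical container for the proposed withPartition arguments.
  final case class PartitionSpec(lowerBound: Long,
                                 upperBoundFn: String => Long,
                                 partitionColumn: String)

  // A count query that an upperBoundFn implementation could choose to execute.
  def countSql(query: String): String = s"select count(*) from ($query)"

  // Resolves the (lower, upper) bounds by invoking the user-supplied function
  // with the original query.
  def resolveBounds(spec: PartitionSpec, query: String): (Long, Long) =
    (spec.lowerBound, spec.upperBoundFn(query))
}
```

A caller who already knows the row count could pass a constant function such as `_ => 100L`; another could run countSql(query) over JDBC inside the function.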

Example of returning N JDBC result sets

	
public static void executeProcedure(Connection con) {
   try {
      CallableStatement stmt = con.prepareCall(...);
      .....  //Set call parameters, if you have IN,OUT, or IN/OUT parameters

      boolean results = stmt.execute();
      int rsCount = 0;

      //Loop through the available result sets.
     while (results) {
           ResultSet rs = stmt.getResultSet();
           //Retrieve data from the result set.
           while (rs.next()) {
        .... // use the rs.getXxx() methods to retrieve data
           }
           rs.close();

        //Check for next result set
        results = stmt.getMoreResults();
      } 
      stmt.close();
   }
   catch (Exception e) {
      e.printStackTrace();
   }
}
@hannesmiller
Contributor Author

There's a flaw in my proposal using rownum: each rownum predicate performs a table scan. Spark does this differently, in a more efficient manner.

You can, for example, apply a hash function to a column (e.g. the primary key) to return a long; most databases come with some kind of hash function.

@sksamuel
Contributor

sksamuel commented Feb 6, 2017

I think Spark just says x < col < y, doesn't it? So it might not be uniformly distributed.
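For reference, a simplified sketch of stride-based predicates, modeled loosely on how Spark's JDBC reader is generally understood to split a numeric column. The names here are hypothetical and the details in Spark vary by version. The point it illustrates is that the bounds only set the stride, so skewed key values give uneven partitions.

```scala
object StridePredicates {
  // Hypothetical helper: the bounds decide the stride only; the first and last
  // partitions are open-ended, so rows outside the bounds are still read.
  def predicates(column: String, lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[String] = {
    require(numPartitions > 0 && lowerBound <= upperBound)
    val stride = upperBound / numPartitions - lowerBound / numPartitions
    (0 until numPartitions).map { i =>
      val lower = if (i == 0) None else Some(s"$column >= ${lowerBound + i * stride}")
      val upper = if (i == numPartitions - 1) None else Some(s"$column < ${lowerBound + (i + 1) * stride}")
      (lower, upper) match {
        case (Some(l), Some(u)) => s"$l AND $u"
        case (None, Some(u))    => s"$u OR $column IS NULL" // first partition also picks up NULL keys
        case (Some(l), None)    => l                         // last partition has no upper limit
        case (None, None)       => "1 = 1"                   // single partition: no restriction
      }
    }
  }
}
```

With column ID, bounds 1 and 100 and 4 partitions, this gives ID < 26 OR ID IS NULL, ID >= 26 AND ID < 51, ID >= 51 AND ID < 76 and ID >= 76.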

@sksamuel
Contributor

sksamuel commented Feb 8, 2017

You mentioned switching this to use a mod?

@hannesmiller
Contributor Author

Yeah, will explain more tomorrow. It's not that straightforward, but doable.

Basically you need to use hash and mod functions together on the select column. Oracle supports both, e.g.:

Select *
From (
Select blah,
mod(hash(primary_col),8) + 1 as part
From table)
Where part = 1

  • The above query can support up to 8 parallel queries, one per partition
  • I have been testing by hashing the primary key, which is an Oracle number and hashes really well using the above
  • Converting the primary key to a varchar did not hash nicely and missed out some rows

Therefore I think the onus should be on the user to tell EEL what the partition number column name is, so that you can literally wrap the SQL and add the where predicate, e.g.: Where part_num = 1
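To see how a hash-then-mod expression spreads keys over partitions, here is a small client-side simulation. It is a sketch under assumptions: Scala's byteswap32 stands in for the database hash (it is not ORA_HASH and will not produce the same buckets), and the object and method names are hypothetical.

```scala
import scala.util.hashing.byteswap32

object HashBuckets {
  // Stand-in for the database expression MOD(HASH(key), n) + 1.
  // floorMod keeps the result non-negative even when the hash is negative.
  def bucketOf(key: Long, numPartitions: Int): Int =
    Math.floorMod(byteswap32(java.lang.Long.hashCode(key)), numPartitions) + 1

  // Counts how many keys land in each bucket.
  def bucketSizes(keys: Seq[Long], numPartitions: Int): Map[Int, Int] =
    keys.groupBy(bucketOf(_, numPartitions)).map { case (bucket, ks) => bucket -> ks.size }
}
```

Every key maps to exactly one bucket in 1 to numPartitions, so the bucket sizes always add up to the number of keys. That is the property the wrapped queries rely on in order not to miss or duplicate rows.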

@hannesmiller
Contributor Author

The hardest bit I suppose is kicking off the parallel threads and joining the results as and when each thread finishes.
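One way to consume results as each thread finishes is java.util.concurrent.ExecutorCompletionService. The sketch below is only an illustration of that JDK class, not EEL's implementation; ParallelParts and runAll are hypothetical names.

```scala
import java.util.concurrent.{Callable, ExecutorCompletionService, Executors}

object ParallelParts {
  // Runs one task per partition and passes each result to onResult in
  // completion order rather than submission order. onResult runs on the
  // calling thread, so it needs no extra synchronization.
  def runAll[A](numPartitions: Int)(task: Int => A)(onResult: A => Unit): Unit = {
    val pool = Executors.newFixedThreadPool(numPartitions)
    try {
      val completion = new ExecutorCompletionService[A](pool)
      (1 to numPartitions).foreach { i =>
        completion.submit(new Callable[A] { override def call(): A = task(i) })
      }
      // take() blocks until the next task finishes; get() throws an
      // ExecutionException if that task failed.
      (1 to numPartitions).foreach(_ => onResult(completion.take().get()))
    } finally pool.shutdown()
  }
}
```

In a JDBC setting the task would run the partitioned query for partition i and return its rows or row count.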

@hannesmiller
Contributor Author

hannesmiller commented Feb 15, 2017

More analysis

  • I have been doing some JDBC performance tests on an Oracle table of 50 million rows.
  • If I run with a single thread it takes an average of 175 seconds over 5 runs.
  • If I break up the query into multiple queries using a hash to partition each query then it dramatically drops down to 37 seconds.

Multi-query tests performed on an 8 core PC

  • The idea was to assign each core to a query running in a separate thread
  • I use the same fetchsize as the original single query
  • Each query thread obtains a connection from a JDBC connection pool

Code

package hannesmiller

import java.io.{File, PrintWriter}
import java.util.concurrent.{Callable, Executors}

import com.sksamuel.exts.Logging
import org.apache.commons.dbcp2.BasicDataSource

import scala.collection.mutable.ListBuffer

object MultiHashDcfQuery extends App with Logging {

  private def generateStatsFile(fileName: String, stats: ListBuffer[String]): Unit = {
    val statsFile = new File(fileName)
    println(s"Generating ${statsFile.getAbsolutePath} ...")
    val statsFileWriter = new PrintWriter(statsFile)
    stats.foreach { s => statsFileWriter.write(s + "\n"); statsFileWriter.flush() }
    statsFileWriter.close()
    println(s"${statsFile.getAbsolutePath} done!")
  }

  val recordCount = 49510353L
  val partitionsStartNumber = 2
  val numberOfPartitions = 8
  val numberOfRuns = 1

  val sql =
    s"""SELECT MY_PRIMARY_KEY, COL2, COL3, COL4, COL5
       FROM MY_TABLE
       WHERE COL2 in (8682)"""

  def buildPartitionSql(bindExpression: String, bindExpressionAlias: String): String = {
    s"""
       |SELECT *
       |FROM (
       |  SELECT eel_tmp.*, $bindExpression AS $bindExpressionAlias
       |  FROM ( $sql ) eel_tmp
       |)
       |WHERE $bindExpressionAlias = ?
       |""".stripMargin
  }

  // Setup the database connection pool equal to the number of partitions - could be less depending on your connection
  // resource limit on the Database server.
  val dataSource = new BasicDataSource()
  dataSource.setDriverClassName("oracle.jdbc.OracleDriver")
  dataSource.setUrl("jdbc:oracle:thin:@//myhost:1901/myservice")
  dataSource.setUsername("username")
  dataSource.setPassword("username1234")
  dataSource.setPoolPreparedStatements(false)
  dataSource.setInitialSize(numberOfPartitions)
  dataSource.setDefaultAutoCommit(false)
  dataSource.setMaxOpenPreparedStatements(numberOfPartitions)

  val stats = ListBuffer[String]()
  for (numPartitions <- partitionsStartNumber to numberOfPartitions) {
    for (runNumber <- 1 to numberOfRuns) {

      // Kick off a number of threads equal to the number of partitions so each partitioned query is executed in parallel.
      // Use the loop's numPartitions here so each run really uses the partition count recorded in the stats.
      val threadPool = Executors.newFixedThreadPool(numPartitions)
      val startTime = System.currentTimeMillis()
      val fetchSize = 100600
      val futures = for (i <- 1 to numPartitions) yield {
        threadPool.submit(new Callable[(Long, Long, Long, Long)] {
          override def call(): (Long, Long, Long, Long) = {
            var rowCount = 0L

            // Capture metrics about acquiring connection
            val connectionIdleTimeStart = System.currentTimeMillis()
            val connection = dataSource.getConnection
            val connectionIdleTime = System.currentTimeMillis() - connectionIdleTimeStart

            val partSql = buildPartitionSql(s"MOD(ORA_HASH(MY_PRIMARY_KEY),$numPartitions) + 1", "PARTITION_NUMBER")
            val prepareStatement = connection.prepareStatement(partSql)
            prepareStatement.setFetchSize(fetchSize)
            prepareStatement.setLong(1, i)


            // Capture metrics for query execution
            val executeQueryTimeStart = System.currentTimeMillis()
            val rs = prepareStatement.executeQuery()
            val executeQueryTime = (System.currentTimeMillis() - executeQueryTimeStart) / 1000

            // Capture metrics for fetching data
            val fetchTimeStart = System.currentTimeMillis()
            while (rs.next()) {
              rowCount += 1
              if (rowCount % fetchSize == 0) logger.info(s"RowCount = $rowCount")
            }
            val fetchTime = (System.currentTimeMillis() - fetchTimeStart) / 1000

            // Close resources in reverse order of creation
            rs.close()
            prepareStatement.close()
            connection.close()
            (connectionIdleTime, executeQueryTime, fetchTime, rowCount)
          }
        })
      }

      // Total up all the rows
      var totalRowCount = 0L
      var totalConnectionIdleTime = 0L
      futures.foreach { f =>
        val (connectionIdleTime, executeQueryTime, fetchTime, rowCount) = f.get
        logger.info(s"connectionIdleTime=$connectionIdleTime, executeQueryTime=$executeQueryTime, fetchTime=$fetchTime, rowCount=$rowCount")
        totalConnectionIdleTime += connectionIdleTime
        totalRowCount += rowCount
      }
      val elapsedTime = (System.currentTimeMillis() - startTime) / 1000.0
      logger.info(s"Run $runNumber with $numPartitions partition(s): Took $elapsedTime second(s) for RowCount = $totalRowCount, totalConnectionIdleTime = $totalConnectionIdleTime")
      threadPool.shutdownNow()
      stats += s"$numPartitions\t$runNumber\t$elapsedTime"
    }
  }
  generateStatsFile("multi_partition_stats.csv", stats)

}
  • For each thread I create the partitioned SQL using the Oracle MOD and ORA_HASH functions on the primary key column:
  val partSql = buildPartitionSql(s"MOD(ORA_HASH(MY_PRIMARY_KEY),$numberOfPartitions) + 1", "PARTITION_NUMBER")
  ...
  ...
  def buildPartitionSql(bindExpression: String, bindExpressionAlias: String): String = {
    s"""
       |SELECT *
       |FROM (
       |  SELECT eel_tmp.*, $bindExpression AS $bindExpressionAlias
       |  FROM ( $sql ) eel_tmp
       |)
       |WHERE $bindExpressionAlias = ?
       |""".stripMargin
  }
  • The returned SQL augments the original query with the bindExpression argument and aliases it to the column PARTITION_NUMBER
  • The subsequent lines create a JDBC prepared statement and bind the desired partition number:
val prepareStatement = connection.prepareStatement(partSql)
prepareStatement.setFetchSize(fetchSize)
prepareStatement.setLong(1, i)

Solution

  • Can we implement this in EEL on the JdbcSource?
  • It's very difficult to generalize, as this mechanism may not behave the same way on another DBMS like SqlServer (they do have hash and mod functions though).
  • That's why I am proposing to pass in an expression and an alias for the column

@sksamuel sksamuel added this to the 1.3 milestone Apr 24, 2017
@sksamuel sksamuel self-assigned this Apr 24, 2017
@hannesmiller
Contributor Author

Overview

I have experimented with my own custom JdbcSource, based on the original, with some new arguments for supporting 2 different partition strategies.

case class JdbcSource(connFn: () => Connection,
                                 query: String,
                                 partHashFuncExpr: Option[String] = None,
                                 partColumnAlias: Option[String] = None,
                                 partRangeColumn: Option[String] = None,
                                 minVal: Option[Long] = None,
                                 maxVal: Option[Long] = None,
                                 numberOfParts: Int = 1,
                                 bind: (PreparedStatement) => Unit = stmt => (),
                                 fetchSize: Int = 100,
                                 providedSchema: Option[StructType] = None,
                                 providedDialect: Option[JdbcDialect] = None,
                                 bucketing: Option[Bucketing] = None)

Strategy 1 – Use a SQL hash function expression aliased to a synthetic column (e.g. PARTITION_NUMBER)

  val numberOfPartitions = 4
  JdbcSource(() => dataSource.getConnection(), query)
    .withHashPartitioning(s"MOD(ORA_HASH(ID),$numberOfPartitions) + 1", "PARTITION_NUMBER", numberOfPartitions)

Note this example is using Oracle Modulus and Hash functions – you can use an equivalent function for another Database dialect, e.g. SQLServer.

Strategy 2 – Pass in a numeric id with a minimum and maximum range

  val numberOfPartitions = 4
  JdbcSource(() => dataSource.getConnection(), query)
    .withRangePartitioning("ID", 1, 201786, numberOfPartitions)

For these partition strategies the JdbcSource will generate N JdbcPart objects, each containing its own partitioned query:

  override def parts(): List[JdbcPart] = {
    if (partHashFuncExpr.nonEmpty || partRangeColumn.nonEmpty) {
      val jdbcParts: Seq[JdbcPart] = for (i <- 1 to numberOfParts) yield new JdbcPart(connFn, buildPartSql(i), bind, fetchSize, dialect())
      jdbcParts.toList
    } else List(new JdbcPart(connFn, query, bind, fetchSize, dialect()))
  }

The buildPartSql method generates the SQL for each partition for both strategies:

  private def buildPartSql(partitionNumber: Int): String = {
    if (partHashFuncExpr.nonEmpty) {
      s"""
         |SELECT *
         |FROM (
         |  SELECT eel_tmp.*, ${partHashFuncExpr.get} AS ${partColumnAlias.get}
         |  FROM ( $query ) eel_tmp
         |)
         |WHERE ${partColumnAlias.get} = $partitionNumber
         |""".stripMargin
    } else if (partRangeColumn.nonEmpty) {
      val partitionRanges = generatePartRanges(minVal.get, maxVal.get, numberOfParts)(partitionNumber - 1)
      s"""
         |SELECT *
         |FROM (
         |  SELECT *
         |  FROM ( $query )
         |)
         |WHERE ${partRangeColumn.get} BETWEEN ${partitionRanges.min} AND ${partitionRanges.max}
         |""".stripMargin
    }
    else query
  }

The generatePartRanges method generates the ranges for Strategy 2:

  case class PartRange(min: Long, max: Long)

  private def generatePartRanges(min: Long, max: Long, numberOfPartitions: Int): Array[PartRange] = {
    val partitionRanges = new Array[PartRange](numberOfPartitions)
    val bucketSizes = new Array[Long](numberOfPartitions)
    val evenLength = (max - min + 1) / numberOfPartitions
    for (i <- 0 until numberOfPartitions) bucketSizes(i) = evenLength

    // distribute surplus as evenly as possible across buckets
    var surplus = (max - min + 1) % numberOfPartitions
    var i: Int = 0
    while (surplus > 0) {
      bucketSizes(i) += 1
      surplus -= 1
      i = (i + 1) % numberOfPartitions
    }

    i = 0
    var n = 0
    var k = min
    while (i < numberOfPartitions && k <= max) {
      partitionRanges(i) = PartRange(k, k + bucketSizes(i) - 1)
      k += bucketSizes(i)
      i += 1
      n += 1
    }
    partitionRanges
  }
}
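For comparison, the same range split can be written without mutable arrays. This is a sketch rather than the code used in the custom JdbcSource; the PartRanges object name is hypothetical, and it assumes the range holds at least one value per partition.

```scala
object PartRanges {
  final case class PartRange(min: Long, max: Long)

  // Splits [min, max] into numberOfPartitions contiguous ranges whose sizes
  // differ by at most one.
  def ranges(min: Long, max: Long, numberOfPartitions: Int): Seq[PartRange] = {
    val total = max - min + 1
    require(numberOfPartitions > 0 && total >= numberOfPartitions)
    val base = total / numberOfPartitions
    val surplus = total % numberOfPartitions
    // The first `surplus` buckets get one extra value each.
    val sizes = (0 until numberOfPartitions).map(i => if (i < surplus) base + 1 else base)
    // Running totals give each bucket's first value (plus one unused trailing entry).
    val starts = sizes.scanLeft(min)(_ + _)
    starts.zip(sizes).map { case (start, size) => PartRange(start, start + size - 1) }
  }
}
```

For the Strategy 2 example, PartRanges.ranges(1, 201786, 4) gives the ranges 1 to 50447, 50448 to 100894, 100895 to 151340 and 151341 to 201786.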

Finally, a slight change to fetchSchema, which covers both strategies and the default:

  def fetchSchema(): StructType = {
    using(connFn()) { conn =>
      val schemaQuery = s"SELECT * FROM (${buildPartSql(1)}) tmp WHERE 1=0"
      using(conn.prepareStatement(schemaQuery)) { stmt =>

        stmt.setFetchSize(fetchSize)
        bind(stmt)

        val rs = timed(s"Executing query $query") {
          stmt.executeQuery()
        }

        val schema = schemaFor(dialect(), rs)
        rs.close()
        schema
      }
    }
  }

@vennapuc

vennapuc commented Feb 4, 2018

Hi,
I have 40 million records in an Oracle table. I want to use Spark JDBC to write this data to CSV files.
Can you please help me with lowerBound, upperBound, numPartitions and a hashing function to split the data equally across all the partitions?

@hannesmiller
Contributor Author

Hi Vennapuc,

Let me look into this and I will get back to you with an answer - there are a few JdbcSource partitioning strategies you can use - I have actually used the HashPartitioning strategy on Oracle with a similar population to yours...

If your table has a primary key that is a number (i.e. a sequence) then with the hash strategy you can specify:

  1. A sql expression that looks something like: mod(hash(primary_key), 4)
  2. Specify the hash column - primary_key
  3. Number of partitions - 4
  4. And of course your main query

What the above does is split your main query into multiple part queries where each part is assigned to a separate thread.

For the RangeBound strategy you will need to do something similar to the above, but you MUST know up front what the MAX count is, which can be determined with a SQL query beforehand.

I think this strategy requires some kind of SQL expression like a row_number analytical function.

I think for the next major release (1.3) we should put some examples together.

Regards,
Hannes

@hannesmiller hannesmiller reopened this Feb 4, 2018
@vennapuc

vennapuc commented Feb 4, 2018

Thanks hannesmiller for your response.
I see some hash functions above, e.g. "MOD(ORA_HASH(ID),$numberOfPartitions) + 1".
You are adding 1 to the hash code. Is there any reason for this?
I was able to do this in Oracle with the hash function (1 + mod(hash(fa_id), %(5)s)) as hash_code, where fa_id is a numeric column.
I chose 5 as the bucket number above. Any suggestions on how to choose this number?
My concern is how to choose the number of partitions and the lower and upper bounds for ~40 million records. Can you please help with choosing the right parameters?

Thanks,

@hannesmiller
Contributor Author

hannesmiller commented Feb 4, 2018

Hi Vennapuc,
Unfortunately there isn't an exact science for how many partitions you should have, as there are too many factors to consider (profile of your Oracle server, how many cores you have available on the client machine, etc.) – the best way is trial and error, i.e. experiment with this on your target environment.

Can you tell me what version of EEL you are using? I ask because in the latest alpha release you can simply specify the Hash partition strategy on the JdbcSource, e.g.:

val partitions = 4
JdbcSource(connFn, query)
	.withPartitionStrategy(HashPartitionStrategy(s"MOD(ORA_HASH(key), $partitions)", partitions))
            .withProvidedDialect(new OracleJdbcDialect)
            .withFetchSize(...)

@vennapuc

vennapuc commented Feb 5, 2018

I am using spark 1.6.2 and scala 2.10.5...

@hannesmiller
Contributor Author

Hi Vennapuc, I suggest you try posting your query to a Spark forum.

Our product EEL is a lightweight big data Scala library for data ingest into environments such as Hadoop.

Regards,
Hannes

@garyfrost garyfrost modified the milestones: 1.2, 1.3 Feb 5, 2018
@garyfrost garyfrost assigned garyfrost and hannesmiller and unassigned garyfrost Feb 5, 2018
@MaxStepQ

Hello, Hannesmiller!
I am experimenting with connecting to Oracle 12g via the Spark JDBC driver.
The table in Oracle has 100 million rows. I launch the Spark application in local mode; the laptop has 4 CPUs.
Code fragment:
val df = spark.read
.format("jdbc")
.option("url", "jdbc:oracle:thin:login/[email protected]:1521:ORCLCDB")
.option("dbtable","C##BUSER.CUSTOMERS")
.option("driver", "oracle.jdbc.OracleDriver")
.option("numPartitions", 4)
.option("partitionColumn", "CUST_ID")
.option("lowerBound", 1)
.option("upperBound", 100000000)
.load()
df.write.csv("/path_to_file_to_save")

I tried launching this application with numPartitions=1, numPartitions=4 and numPartitions=10.
Reading data from Oracle and writing it locally with partitions (4 or 10) takes 3 times longer than without partitions.

Could you please help me understand how to increase the read speed using the "numPartitions" param?
