Skip to content

Commit

Permalink
Add JDBC partition strategy with bucket-based implementation #236
Browse files Browse the repository at this point in the history
  • Loading branch information
sksamuel committed Jun 5, 2017
1 parent 90c66b8 commit 3f62e55
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 136 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Part

/**
  * A [[JdbcPartitionStrategy]] that splits the inclusive interval `[min, max]` of an integer
  * column into `numberOfPartitions` contiguous, non-overlapping buckets and creates one
  * [[JdbcPart]] per bucket.
  *
  * @param columnName         name of the column, as exposed by the wrapped query, that is compared
  *                           against the bucket bounds
  * @param numberOfPartitions number of buckets (and therefore parts) to create; must be positive
  * @param min                smallest column value covered by the buckets, inclusive
  * @param max                largest column value covered by the buckets, inclusive
  */
case class BucketPartitionStrategy(columnName: String,
                                   numberOfPartitions: Int,
                                   min: Int,
                                   max: Int) extends JdbcPartitionStrategy {

  require(numberOfPartitions > 0, s"numberOfPartitions must be positive but was $numberOfPartitions")
  require(min <= max, s"min ($min) must not be greater than max ($max)")

  /**
    * Returns one range per partition. Both bounds of every range are inclusive, the ranges do not
    * overlap, and together they cover exactly `min` to `max`. When there are more partitions than
    * values, the trailing ranges are empty.
    */
  def ranges: Seq[Range] = {

    // max - min + 1 because the min-max interval is inclusive
    val count = max - min + 1
    val gap = count / numberOfPartitions
    // the surplus is distributed as evenly as possible: the first `surplus` buckets get one extra value
    val surplus = count % numberOfPartitions

    List.tabulate(numberOfPartitions) { k =>
      val start = min + k * gap + Math.min(k, surplus)
      val nextStart = min + ((k + 1) * gap) + Math.min(k + 1, surplus)
      // the bucket stops one before the next bucket starts, so that a boundary value
      // belongs to exactly one bucket
      Range.inclusive(start, nextStart - 1)
    }
  }

  /**
    * Wraps the query in a filter on `columnName` for each bucket and returns one part per bucket.
    *
    * @param connFn    supplies the connection used by each part
    * @param query     the query to partition
    * @param bindFn    applied to each part's prepared statement before execution
    * @param fetchSize fetch size passed on to each part
    * @param dialect   dialect passed on to each part
    * @return one [[JdbcPart]] per range returned by [[ranges]]
    */
  override def parts(connFn: () => Connection,
                     query: String,
                     bindFn: (PreparedStatement) => Unit,
                     fetchSize: Int,
                     dialect: JdbcDialect): Seq[Part] = {

    ranges.map { range =>

      // both bounds are inclusive, matching the ranges produced above;
      // the derived tables are aliased because several databases reject unnamed subqueries
      val partitionedQuery =
        s"""|SELECT *
            |FROM (
            |  SELECT *
            |  FROM ( $query ) eel_tmp
            |) eel_bucket
            |WHERE ${range.start} <= $columnName AND $columnName <= ${range.end}
            |""".stripMargin

      new JdbcPart(connFn, partitionedQuery, bindFn, fetchSize, dialect)
    }
  }
}



Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Part

/**
  * A [[JdbcPartitionStrategy]] that assigns rows to partitions by evaluating a SQL hash
  * expression for every row and selecting, per part, the rows whose hash equals the part number.
  *
  * Part numbers run from 0 until `numberOfPartitions`; rows whose hash value falls outside
  * of that interval are not selected by any part.
  *
  * @param hashExpression     SQL expression evaluated against the wrapped query's rows
  * @param numberOfPartitions number of parts to create
  */
case class HashPartitionStrategy(hashExpression: String,
                                 numberOfPartitions: Int) extends JdbcPartitionStrategy {

  /**
    * Builds the SQL for a single partition.
    *
    * @param partNum the hash value selected by this partition
    * @param query   the query being partitioned
    * @return the query wrapped so that only rows hashing to `partNum` are returned
    */
  def partitionedQuery(partNum: Int, query: String): String = {
    val wrapped =
      s"""|SELECT *
          |FROM (
          | SELECT eel_tmp.*, $hashExpression AS eel_hash_col
          | FROM ( $query ) eel_tmp
          |)
          |WHERE eel_hash_col = $partNum
          |""".stripMargin
    wrapped
  }

  /** Creates one [[JdbcPart]] per hash value in 0 until `numberOfPartitions`. */
  override def parts(connFn: () => Connection,
                     query: String,
                     bindFn: (PreparedStatement) => Unit,
                     fetchSize: Int,
                     dialect: JdbcDialect): Seq[Part] =
    (0 until numberOfPartitions).map { partNum =>
      val sql = partitionedQuery(partNum, query)
      new JdbcPart(connFn, sql, bindFn, fetchSize, dialect)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import com.sksamuel.exts.metrics.Timed
import io.eels.{CloseableIterator, Part, Row}

import scala.util.Try

/**
  * A [[Part]] that runs a single JDBC query and exposes the result set as batches of rows.
  *
  * @param connFn    supplies the connection; invoked once per call to `iterator()`, and the
  *                  connection it returns is closed when that iterator is closed
  * @param query     the SQL to execute
  * @param bindFn    applied to the prepared statement before it is executed
  * @param fetchSize passed to the statement as its fetch size; when positive it is also the
  *                  number of rows per emitted batch
  * @param dialect   used together with the result set to derive the schema of the rows
  */
class JdbcPart(connFn: () => Connection,
               query: String,
               bindFn: (PreparedStatement) => Unit = stmt => (),
               fetchSize: Int = 100,
               dialect: JdbcDialect
              ) extends Part with Timed with JdbcPrimitives {

  // Iterator.grouped rejects sizes below 1, but zero (and for some drivers negative values) is an
  // accepted fetch size hint, so fall back to a fixed batch size in that case.
  private val batchSize: Int = if (fetchSize > 0) fetchSize else 100

  // Runs the thunk; if it fails, closes the given resources (best effort) before rethrowing, so that
  // nothing is leaked when the iterator cannot be fully constructed.
  private def closingOnFailure[T](resources: AutoCloseable*)(thunk: => T): T =
    try thunk
    catch {
      case scala.util.control.NonFatal(e) =>
        resources.foreach(resource => Try(resource.close()))
        throw e
    }

  /**
    * Opens a connection, executes the query and returns the rows in batches.
    * The returned iterator must be closed to release the result set, statement and connection.
    */
  override def iterator(): CloseableIterator[Seq[Row]] = new CloseableIterator[Seq[Row]] {

    private val conn = connFn()
    private val stmt = closingOnFailure(conn) {
      conn.prepareStatement(query)
    }

    private val rs = closingOnFailure(stmt, conn) {
      stmt.setFetchSize(fetchSize)
      bindFn(stmt)
      timed(s"Executing query $query") {
        stmt.executeQuery()
      }
    }

    private val schema = closingOnFailure(rs, stmt, conn) {
      schemaFor(dialect, rs)
    }

    // best effort: each resource is closed independently so that one failure does not leak the others
    override def close(): Unit = {
      Try { super.close() }
      Try { rs.close() }
      Try { stmt.close() }
      Try { conn.close() }
    }

    override val iterator: Iterator[Seq[Row]] = new Iterator[Row] {

      // true when the result set cursor is positioned on a row that has not been returned yet
      var _hasnext = false

      override def hasNext(): Boolean = _hasnext || {
        _hasnext = rs.next()
        _hasnext
      }

      override def next(): Row = {
        // advances the cursor if the caller did not call hasNext, and fails properly once exhausted
        if (!hasNext()) throw new NoSuchElementException("next on exhausted result set")
        _hasnext = false
        val values = schema.fieldNames().map(name => rs.getObject(name))
        Row(schema, values)
      }

    }.grouped(batchSize).withPartial(true)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Part

/**
  * Strategy that decides how a JDBC query is turned into one or more parts.
  */
trait JdbcPartitionStrategy {
/**
  * Creates the parts for the given query.
  *
  * @param connFn    function that supplies a JDBC connection
  * @param query     the SQL query to be read
  * @param bindFn    function applied to a prepared statement, e.g. to bind parameters
  * @param fetchSize the fetch size to use when reading
  * @param dialect   the JDBC dialect to use
  * @return the parts that together read the query
  */
def parts(connFn: () => Connection,
query: String,
bindFn: (PreparedStatement) => Unit,
fetchSize: Int,
dialect: JdbcDialect): Seq[Part]
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,43 @@ import java.sql.{Connection, DriverManager, PreparedStatement}
import com.sksamuel.exts.Logging
import com.sksamuel.exts.io.Using
import com.sksamuel.exts.metrics.Timed
import io.eels.{Part, Source}
import io.eels.schema.StructType
import io.eels.{CloseableIterator, Part, Row, Source}

/** Factory methods for [[JdbcSource]]. */
object JdbcSource {

  /**
    * Creates a source for the given query whose connections are obtained from
    * `DriverManager` using the supplied url.
    */
  def apply(url: String, query: String): JdbcSource = {
    val connect: () => Connection = () => DriverManager.getConnection(url)
    JdbcSource(connect, query)
  }
}

case class JdbcSource(connFn: () => Connection,
query: String,
bind: (PreparedStatement) => Unit = stmt => (),
bindFn: (PreparedStatement) => Unit = stmt => (),
fetchSize: Int = 100,
providedSchema: Option[StructType] = None,
providedDialect: Option[JdbcDialect] = None,
bucketing: Option[Bucketing] = None)
partitionStrategy: JdbcPartitionStrategy = SinglePartitionStrategy)
extends Source with JdbcPrimitives with Logging with Using with Timed {

override lazy val schema: StructType = providedSchema.getOrElse(fetchSchema())

def withBind(bind: (PreparedStatement) => Unit): JdbcSource = copy(bind = bind)
def withBind(bind: (PreparedStatement) => Unit): JdbcSource = copy(bindFn = bind)
def withFetchSize(fetchSize: Int): JdbcSource = copy(fetchSize = fetchSize)
def withProvidedSchema(schema: StructType): JdbcSource = copy(providedSchema = Option(schema))
def withProvidedDialect(dialect: JdbcDialect): JdbcSource = copy(providedDialect = Option(dialect))
def withPartitionStrategy(strategy: JdbcPartitionStrategy): JdbcSource = copy(partitionStrategy = strategy)

private def dialect(): JdbcDialect = providedDialect.getOrElse(new GenericJdbcDialect())

override def parts(): List[JdbcPart] = List(new JdbcPart(connFn, query, bind, fetchSize, dialect()))
override def parts(): Seq[Part] = partitionStrategy.parts(connFn, query, bindFn, fetchSize, dialect())

def fetchSchema(): StructType = {
using(connFn()) { conn =>
val schemaQuery = s"SELECT * FROM ($query) tmp WHERE 1=0"
using(conn.prepareStatement(schemaQuery)) { stmt =>

stmt.setFetchSize(fetchSize)
bind(stmt)
bindFn(stmt)

val rs = timed("Executing query $query") {
val rs = timed(s"Executing query $query") {
stmt.executeQuery()
}

Expand All @@ -50,55 +51,4 @@ case class JdbcSource(connFn: () => Connection,
}
}
}
}

/**
  * Bucketing settings: a column name and a number of buckets.
  * NOTE(review): presumably the column used to split a query and the number of splits - confirm against callers.
  */
case class Bucketing(columnName: String, numberOfBuckets: Int)

/**
  * A [[Part]] that runs a single JDBC query and exposes the result set as batches of rows.
  *
  * @param connFn    supplies the connection; invoked once per call to `iterator()`
  * @param query     the SQL to execute
  * @param bind      applied to the prepared statement before it is executed
  * @param fetchSize passed to the statement as its fetch size
  * @param dialect   used together with the result set to derive the schema of the rows
  */
class JdbcPart(connFn: () => Connection,
               query: String,
               bind: (PreparedStatement) => Unit = stmt => (),
               fetchSize: Int = 100,
               dialect: JdbcDialect
              ) extends Part with Timed with JdbcPrimitives {

  /**
    * Returns the data contained in this part in the form of an iterator. This function should return a new
    * iterator on each invocation. The iterator can be lazily initialized to the first read if required.
    */
  override def iterator(): CloseableIterator[Seq[Row]] = new CloseableIterator[Seq[Row]] {

    private val conn = connFn()
    private val stmt = conn.prepareStatement(query)
    stmt.setFetchSize(fetchSize)
    bind(stmt)

    private val rs = timed(s"Executing query $query") {
      stmt.executeQuery()
    }

    private val schema = schemaFor(dialect, rs)

    // Nested finally blocks make sure that the result set, the statement and the connection are
    // all released even when an earlier close fails; the failure itself is still propagated.
    override def close(): Unit = {
      try super.close()
      finally {
        try rs.close()
        finally {
          try stmt.close()
          finally conn.close()
        }
      }
    }

    override val iterator: Iterator[Seq[Row]] = new Iterator[Row] {

      // true when the result set cursor is positioned on a row that has not been returned yet
      var _hasnext = false

      override def hasNext(): Boolean = _hasnext || {
        _hasnext = rs.next()
        _hasnext
      }

      override def next(): Row = {
        _hasnext = false
        val values = schema.fieldNames().map(name => rs.getObject(name))
        Row(schema, values)
      }

    }.grouped(100).withPartial(true)
  }
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.eels.component.jdbc

import java.sql.{Connection, PreparedStatement}

import io.eels.Part

/**
  * A [[JdbcPartitionStrategy]] that does not partition at all: the query is read,
  * unmodified, by a single [[JdbcPart]].
  */
case object SinglePartitionStrategy extends JdbcPartitionStrategy {

  /** Returns a one element list holding a part for the query as given. */
  override def parts(connFn: () => Connection,
                     query: String,
                     bindFn: (PreparedStatement) => Unit,
                     fetchSize: Int,
                     dialect: JdbcDialect): List[Part] = {
    val onlyPart: Part = new JdbcPart(connFn, query, bindFn, fetchSize, dialect)
    onlyPart :: Nil
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.eels.component.jdbc

import java.sql.DriverManager

import org.scalatest.{Matchers, WordSpec}

import scala.util.Random

/**
  * Tests for [[BucketPartitionStrategy]], run against an in-memory H2 table
  * holding 20 random integers below 10000.
  */
class BucketPartitionTest extends WordSpec with Matchers {

  private val conn = DriverManager.getConnection("jdbc:h2:mem:bucket_test")
  conn.createStatement().executeUpdate("create table mytable (a integer)")
  (0 until 20).foreach { _ =>
    val value = Random.nextInt(10000)
    conn.createStatement().executeUpdate(s"insert into mytable (a) values ($value)")
  }

  // builds the expected ranges from (first, last) pairs, both bounds inclusive
  private def inclusiveRanges(bounds: (Int, Int)*): List[Range] =
    bounds.toList.map { case (first, last) => Range.inclusive(first, last) }

  // the test table read through four buckets spanning 0 to 10000
  private def bucketedSource(): JdbcSource =
    JdbcSource(() => conn, "select * from mytable")
      .withPartitionStrategy(BucketPartitionStrategy("a", 4, 0, 10000))

  "BucketPartitionStrategy" should {
    "generate evenly spaced ranges" in {
      BucketPartitionStrategy("a", 10, 2, 29).ranges shouldBe
        inclusiveRanges(2 -> 4, 5 -> 7, 8 -> 10, 11 -> 13, 14 -> 16, 17 -> 19, 20 -> 22, 23 -> 25, 26 -> 27, 28 -> 29)
      BucketPartitionStrategy("a", 2, 2, 30).ranges shouldBe inclusiveRanges(2 -> 16, 17 -> 30)
      BucketPartitionStrategy("a", 1, 4, 5).ranges shouldBe inclusiveRanges(4 -> 5)
      BucketPartitionStrategy("a", 1, 4, 4).ranges shouldBe inclusiveRanges(4 -> 4)
      BucketPartitionStrategy("a", 6, 1, 29).ranges shouldBe
        inclusiveRanges(1 -> 5, 6 -> 10, 11 -> 15, 16 -> 20, 21 -> 25, 26 -> 29)
    }
    "return correct number of ranges" in {
      bucketedSource().parts().size shouldBe 4
    }
    "return full and non overlapping data" in {
      bucketedSource().toFrame().collect().size shouldBe 20
    }
  }
}

This file was deleted.

0 comments on commit 3f62e55

Please sign in to comment.