Refactor VW cache distribution #37
vsuthichai committed Apr 6, 2017
1 parent 64e3dfc commit b117dad
Showing 11 changed files with 282 additions and 127 deletions.
core/pom.xml (4 additions, 0 deletions)
@@ -50,6 +50,10 @@
       <groupId>ch.qos.logback</groupId>
       <artifactId>logback-classic</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.jsuereth</groupId>
+      <artifactId>scala-arm_${scala.major.version}</artifactId>
+    </dependency>
     <dependency>
       <groupId>junit</groupId>
       <artifactId>junit</artifactId>
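
The new scala-arm dependency provides automatic resource management, presumably for the VW cache file handling this commit refactors. A minimal sketch of the library's managed-resource pattern; the actual call sites are not visible in this extract, and the file names below are made up:

    import resource._

    // Copy a (hypothetical) VW cache file; scala-arm closes both streams
    // automatically, even if the copy throws part-way through.
    for {
      in  <- managed(new java.io.FileInputStream("vw.cache"))
      out <- managed(new java.io.FileOutputStream("vw.cache.copy"))
    } {
      val buf = new Array[Byte](4096)
      Iterator.continually(in.read(buf)).takeWhile(_ != -1).foreach(n => out.write(buf, 0, n))
    }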
@@ -103,8 +103,8 @@ trait AbstractOptimizer[P, L, S, R <: OptimizerResult[P, L]] extends Optimizer[P
   * @tparam L
   */
 trait OptimizerState[P, L] {
-  val bestPointSoFar: P
-  val bestLossSoFar: L
+  val bestPointSoFar: Option[P]
+  val bestLossSoFar: Option[L]
   val startTime: DateTime
   val currentTime: DateTime
   val trialsSoFar: Long
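
With both fields now Option-typed, an optimizer can start from an empty state instead of eagerly evaluating a first point just to seed bestPointSoFar. The update rule that the two searches below inline can be sketched generically like this (an illustrative helper, not code from the commit):

    // Fold a freshly evaluated (point, loss) pair into the optional best-so-far.
    // On the very first batch there is nothing to compare against, so the new
    // pair is taken as-is; afterwards the reducer picks the better of the two.
    def updateBest[P, L](soFar: Option[(P, L)],
                         candidate: (P, L),
                         reducer: ((P, L), (P, L)) => (P, L)): Option[(P, L)] =
      soFar match {
        case Some(best) => Some(reducer(best, candidate))
        case None       => Some(candidate)
      }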
@@ -47,17 +47,18 @@ abstract class GridSearch[P, L](
     (implicit c: ClassTag[P], p: ClassTag[L]): GridSearchResult[P, L] = {
     val space = new Grid[P](paramSpace)
     val startTime = DateTime.now()
+    /*
     val firstPoint = space(0)
     val firstLoss = objective(firstPoint)
     val currentTime = DateTime.now()
+    */
 
-    val gridSearchContext = GridSearchContext(
-      bestPointSoFar = firstPoint,
-      bestLossSoFar = firstLoss,
+    val gridSearchContext = GridSearchContext[P, L](
+      bestPointSoFar = None,
+      bestLossSoFar = None,
       startTime = startTime,
-      currentTime = currentTime,
-      trialsSoFar = 1L,
-      optimizerFinished = 1L >= space.size)
+      currentTime = startTime,
+      trialsSoFar = 0L,
+      optimizerFinished = 0L >= space.size)
 
     // Last three arguments maintain the best point and loss and the trial count
     gridSearch(objective, space, reducer, gridSearchContext)
@@ -72,11 +73,11 @@
 
     info(s"Best point and loss after ${gsc.trialsSoFar} trials and ${DurationUtils.format(gsc.elapsedTime)} : ${gsc.bestPointSoFar} loss: ${gsc.bestLossSoFar}")
 
-    if (stopStrategy.shouldStop(gsc)) {
+    if (stopStrategy.shouldStop(gsc) && gsc.bestLossSoFar.isDefined && gsc.bestPointSoFar.isDefined) {
       // Base case, end recursion, return the result
       GridSearchResult[P, L](
-        bestPoint = gsc.bestPointSoFar,
-        bestLoss = gsc.bestLossSoFar,
+        bestPoint = gsc.bestPointSoFar.get,
+        bestLoss = gsc.bestLossSoFar.get,
         startTime = gsc.startTime,
         endTime = gsc.currentTime,
         elapsedTime = gsc.elapsedTime,
@@ -87,12 +88,19 @@
 
     val batchSize = scala.math.min(space.size - gsc.trialsSoFar, trialBatchSize)
 
-    val (bestPoint, bestLoss) = reducer((gsc.bestPointSoFar, gsc.bestLossSoFar), bestGridPointAndLoss(gsc.trialsSoFar, batchSize, objective, space, reducer))
+    val (point, loss) = bestGridPointAndLoss(gsc.trialsSoFar, batchSize, objective, space, reducer)
+
+    val (bestPoint, bestLoss) =
+      if (gsc.bestPointSoFar.isDefined && gsc.bestLossSoFar.isDefined)
+        reducer((gsc.bestPointSoFar.get, gsc.bestLossSoFar.get), (point, loss))
+      else
+        (point, loss)
+
     val trialsSoFar = gsc.trialsSoFar + batchSize
 
     val gridSearchContext = GridSearchContext(
-      bestPointSoFar = bestPoint,
-      bestLossSoFar = bestLoss,
+      bestPointSoFar = Option(bestPoint),
+      bestLossSoFar = Option(bestLoss),
       startTime = gsc.startTime,
       currentTime = currentTime,
       trialsSoFar = trialsSoFar,
@@ -104,8 +112,8 @@
 }
 
 case class GridSearchContext[P, L](
-  bestPointSoFar: P,
-  bestLossSoFar: L,
+  bestPointSoFar: Option[P],
+  bestLossSoFar: Option[L],
   startTime: DateTime,
   currentTime: DateTime,
   trialsSoFar: Long,
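
For reference, a Reducer[(P, L)] is just a binary function that keeps the better (point, loss) pair; a minimizing reducer might look like the sketch below (hypothetical example, since Spotz's actual reducers are defined elsewhere in the codebase). With trialsSoFar now starting at 0 and an empty best-so-far, the first batch seeds the state through the else branch above.

    // Hypothetical minimizing reducer: keep whichever pair has the smaller loss.
    def minReducer[P]: ((P, Double), (P, Double)) => (P, Double) =
      (a, b) => if (a._2 <= b._2) a else b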
@@ -6,7 +6,7 @@ import com.eharmony.spotz.optimizer._
 import com.eharmony.spotz.optimizer.hyperparam.RandomSampler
 import com.eharmony.spotz.util.{DurationUtils, Logging}
 import org.apache.spark.SparkContext
-import org.joda.time.{DateTime, Duration}
+import org.joda.time.DateTime
 
 import scala.annotation.tailrec
 import scala.math.Ordering
@@ -61,16 +61,13 @@ abstract class RandomSearch[P, L](
     reducer: Reducer[(P, L)])
     (implicit c: ClassTag[P], p: ClassTag[L]): RandomSearchResult[P, L] = {
     val startTime = DateTime.now()
-    val firstPoint = sample(paramSpace, seed)
-    val firstLoss = objective(firstPoint)
-    val currentTime = DateTime.now()
 
-    val randomSearchContext = RandomSearchContext(
-      bestPointSoFar = firstPoint,
-      bestLossSoFar = firstLoss,
+    val randomSearchContext = RandomSearchContext[P, L](
+      bestPointSoFar = None,
+      bestLossSoFar = None,
       startTime = startTime,
-      currentTime = currentTime,
-      trialsSoFar = 1,
+      currentTime = startTime,
+      trialsSoFar = 0,
       optimizerFinished = false)
 
     randomSearch(objective, reducer, paramSpace, randomSearchContext)
@@ -97,11 +94,11 @@
 
     info(s"Best point and loss after ${rsc.trialsSoFar} trials and ${DurationUtils.format(rsc.elapsedTime)} : ${rsc.bestPointSoFar} loss: ${rsc.bestLossSoFar}")
 
-    if (stopStrategy.shouldStop(rsc)) {
+    if (stopStrategy.shouldStop(rsc) && rsc.bestPointSoFar.isDefined && rsc.bestLossSoFar.isDefined) {
       // Base case, end recursion, return the result
       new RandomSearchResult[P, L](
-        bestPoint = rsc.bestPointSoFar,
-        bestLoss = rsc.bestLossSoFar,
+        bestPoint = rsc.bestPointSoFar.get,
+        bestLoss = rsc.bestLossSoFar.get,
         startTime = rsc.startTime,
         endTime = rsc.currentTime,
         elapsedTime = rsc.elapsedTime,
@@ -111,14 +108,19 @@
       // TODO: Adaptive batch sizing
       //val batchSize = nextBatchSize(None, elapsedTime, currentBatchSize, trialsSoFar, null, stopStrategy.getMaxTrials)
 
-      val (bestPoint, bestLoss) = reducer((rsc.bestPointSoFar, rsc.bestLossSoFar),
-        bestRandomPointAndLoss(rsc.trialsSoFar, batchSize, objective, reducer, paramSpace, sample, seed))
+      val (point, loss) = bestRandomPointAndLoss(rsc.trialsSoFar, batchSize, objective, reducer, paramSpace, sample, seed)
+
+      val (bestPoint, bestLoss) =
+        if (rsc.bestPointSoFar.isDefined && rsc.bestLossSoFar.isDefined)
+          reducer((rsc.bestPointSoFar.get, rsc.bestLossSoFar.get), (point, loss))
+        else
+          (point, loss)
 
       val currentTime = DateTime.now()
 
       val randomSearchContext = RandomSearchContext(
-        bestPointSoFar = bestPoint,
-        bestLossSoFar = bestLoss,
+        bestPointSoFar = Option(bestPoint),
+        bestLossSoFar = Option(bestLoss),
         startTime = rsc.startTime,
         currentTime = currentTime,
         trialsSoFar = rsc.trialsSoFar + batchSize,
@@ -145,8 +147,8 @@
  * @tparam L loss type
  */
 case class RandomSearchContext[P, L](
-  bestPointSoFar: P,
-  bestLossSoFar: L,
+  bestPointSoFar: Option[P],
+  bestLossSoFar: Option[L],
   startTime: DateTime,
   currentTime: DateTime,
   trialsSoFar: Long,
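
The extra isDefined guards on the stop check matter because the result types still carry a plain best point and loss. A condensed sketch of the new termination condition (illustrative only, mirroring the guard in the diff):

    // Recursion may only end once at least one trial has produced a best pair,
    // since RandomSearchResult requires concrete bestPoint/bestLoss values.
    def canFinish[P, L](rsc: RandomSearchContext[P, L], budgetExhausted: Boolean): Boolean =
      budgetExhausted && rsc.bestPointSoFar.isDefined && rsc.bestLossSoFar.isDefined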
core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala (16 additions, 23 deletions)
@@ -1,6 +1,6 @@
 package com.eharmony.spotz.util
 
-import java.io.{File, PrintWriter}
+import java.io.File
 
 import org.apache.spark.{SparkContext, SparkFiles}
 
@@ -13,38 +13,27 @@ import org.apache.spark.{SparkContext, SparkFiles}
  * Later when the objective function is being parallelized, the file can be retrieved with the
  * <code>get</code> method inside the <code>apply</code> method.
  */
-trait FileFunctions {
-  def save(inputPath: String): String = save(FileUtil.loadFile(inputPath))
-
-  def save(inputIterable: Iterable[String]): String = save(inputIterable.toIterator)
-
-  def save(inputIterator: Iterator[String]): String = {
-    val tempFile = FileUtil.tempFile("file.temp")
-    val printWriter = new PrintWriter(tempFile)
-    inputIterator.foreach(line => printWriter.println(line))
-    printWriter.close()
-    save(tempFile)
-  }
-
-  def save(file: File): String
-  def get(name: String): File
-}
+// trait FileSystemFunctions extends LocalFileSystemFunctions with SparkFileFunctions
 
 /**
  * This trait is intended for handling files when parallel collections are used to do the computation.
  * It interacts directly with the file system since parallel collections run on a single node. Calling
  * <code>save</code> on a file will return a key that can later be used to retrieve that same file
  * later inside the <code>apply</code> method of the objective function.
  */
-trait FileSystemFunctions extends FileFunctions {
+trait LocalFileSystemFunctions {
   private lazy val nameToAbsPath = scala.collection.mutable.Map[String, String]()
 
-  override def save(file: File): String = {
+  def saveLocally(inputPath: String): String = saveLocally(new File(inputPath))
+  def saveLocally(inputIterable: Iterable[String]): String = saveLocally(inputIterable.toIterator)
+  def saveLocally(inputIterator: Iterator[String]): String = saveLocally(FileUtil.tempFile(inputIterator))
+
+  def saveLocally(file: File): String = {
     nameToAbsPath += ((file.getName, file.getAbsolutePath))
     file.getName
   }
 
-  override def get(name: String): File = new File(nameToAbsPath(name))
+  def getLocally(name: String): File = new File(nameToAbsPath(name))
 }
 
 /**
@@ -64,13 +53,17 @@ trait FileSystemFunctions extends FileFunctions {
  * accessed from the <code>apply</code> method of the objective function as it's executing on the worker
  * through this same trait's <code>get</code> method.
  */
-trait SparkFileFunctions extends FileFunctions {
+trait SparkFileFunctions {
   val sc: SparkContext
 
-  override def save(file: File): String = {
+  def saveToSparkFiles(inputPath: String): String = saveToSparkFiles(new File(inputPath))
+  def saveToSparkFiles(inputIterable: Iterable[String]): String = saveToSparkFiles(inputIterable.toIterator)
+  def saveToSparkFiles(inputIterator: Iterator[String]): String = saveToSparkFiles(FileUtil.tempFile(inputIterator))
+
+  def saveToSparkFiles(file: File): String = {
     sc.addFile(file.getAbsolutePath)
     file.getName
   }
 
-  override def get(name: String): File = new File(SparkFiles.get(name))
+  def getFromSparkFiles(name: String): File = new File(SparkFiles.get(name))
 }