Skip to content

Commit

Permalink
Refactor #22 to allow user to specify their own stop strategy predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
vsuthichai committed Aug 14, 2016
1 parent 5f1404e commit 49aa133
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 91 deletions.
11 changes: 11 additions & 0 deletions core/src/main/scala/com/eharmony/spotz/optimizer/Optimizer.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.eharmony.spotz.optimizer

import com.eharmony.spotz.objective.Objective
import org.joda.time.{DateTime, Duration}

import scala.math.Ordering
import scala.reflect.ClassTag
Expand Down Expand Up @@ -95,6 +96,16 @@ trait AbstractOptimizer[P, L, S, R <: OptimizerResult[P, L]] extends Optimizer[P
(implicit c: ClassTag[P], p: ClassTag[L]): R
}

trait OptimizerState[P, L] {
val bestPointSoFar: P
val bestLossSoFar: L
val startTime: DateTime
val currentTime: DateTime
val elapsedTime: Duration
val trialsSoFar: Long
val optimizerFinished: Boolean
}

/**
* Result of an optimizer. All other optimization algorithms' results should inherit from this.
* Minimally, this result contains the best point and the best loss.
Expand Down
41 changes: 17 additions & 24 deletions core/src/main/scala/com/eharmony/spotz/optimizer/StopStrategy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,9 @@ sealed trait StopStrategy extends Serializable {

def getMaxTrials: Long = UNLIMITED
def getMaxDuration: Duration = FOREVER
def shouldStop(trialsSoFar: Long, timeSinceFirstTrial: Duration): Boolean
}

// TODO
/**
* A context object describing the current state of an optimizer. It keeps track of state
* such as the best point, the best loss, elapsed time, trials executed so far, and other
* important information that could be used to specify some stopping criteria for the
* optimizer.
*
* @param foo
* @tparam P
* @tparam L
*/
case class StopContext[P, L](foo: Any)
def shouldStop[P, L](optimizerState: OptimizerState[P, L]): Boolean
}

/**
* Stop after a maximum number of executed trials.
Expand All @@ -35,8 +23,8 @@ case class StopContext[P, L](foo: Any)
class MaxTrialsStop(maxTrials: Long) extends StopStrategy {
assert(maxTrials > 0, "Must specify greater than 0 trials.")
override def getMaxTrials: Long = maxTrials
override def shouldStop(trialsSoFar: Long, durationSinceFirstTrial: Duration): Boolean = {
trialsSoFar >= maxTrials
override def shouldStop[P, L](optimizerState: OptimizerState[P, L]): Boolean = {
optimizerState.trialsSoFar >= maxTrials
}
}

Expand All @@ -49,8 +37,8 @@ class TimedStop(maxDuration: Duration) extends StopStrategy {
assert(maxDuration.toStandardSeconds.getSeconds > 0, "Must specify a longer duration")

override def getMaxDuration: Duration = maxDuration
override def shouldStop(trialsSoFar: Long, durationSinceFirstTrial: Duration): Boolean = {
durationSinceFirstTrial.getMillis >= maxDuration.getMillis
override def shouldStop[P, L](optimizerState: OptimizerState[P, L]): Boolean = {
optimizerState.elapsedTime.getMillis >= maxDuration.getMillis
}
}

Expand All @@ -63,25 +51,30 @@ class TimedStop(maxDuration: Duration) extends StopStrategy {
class MaxTrialsOrMaxDurationStop(maxTrials: Long, maxDuration: Duration) extends StopStrategy {
override def getMaxTrials: Long = maxTrials
override def getMaxDuration: Duration = maxDuration
override def shouldStop(trialsSoFar: Long, durationSinceFirstTrial: Duration): Boolean = {
trialsSoFar >= maxTrials || durationSinceFirstTrial.getMillis >= maxDuration.getMillis
override def shouldStop[P, L](optimizerState: OptimizerState[P, L]): Boolean = {
optimizerState.trialsSoFar >= maxTrials || optimizerState.elapsedTime.getMillis >= maxDuration.getMillis
}
}

/**
* Stop after an optimizer has finished running. This should never be used for RandomSearch because
* it will never complete without some specific stopping criteria.
*/
object OptimizerFinishes extends StopStrategy {
override def shouldStop(trialsSoFar: Long, durationSinceFirstTrial: Duration): Boolean = false
override def shouldStop[P, L](optimizerState: OptimizerState[P, L]): Boolean = {
optimizerState.optimizerFinished
}
}

// TODO
/**
* Stop after some criteria defined by the user.
*
* @param f
* @tparam P
* @tparam L
*/
class StopStrategyPredicate[P, L](f: (StopContext[P, L]) => Boolean) {
def shouldStop(stopContext: StopContext[P, L]) = f(stopContext)
class StopStrategyPredicate[P, L](f: OptimizerState[P, L] => Boolean) {
def shouldStop(stopContext: OptimizerState[P, L]) = f(stopContext)
}

/**
Expand Down
39 changes: 24 additions & 15 deletions core/src/main/scala/com/eharmony/spotz/optimizer/grid/Grid.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,43 @@ class Grid[P](
assert(gridParams.nonEmpty, "No grid parameters have been specified")

/** Expand the each grid row. The memory required is linear in the sum of lengths of all the iterables */
private val gridSpace = gridParams.map { case (label, it) => (label, it.toSeq) } toSeq

/** pre-compute the length of each grid iterable, indexed into a Seq */
private val gridLengths = gridSpace.map { case (label, seq) => seq.length.toLong }

/** pre-compute the divisible factor and store along with the length inside GridProperty */
private val gridProperties = gridLengths.foldRight(ArrayBuffer[GridProperty]()) { case (l, b) =>
if (b.isEmpty) GridProperty(1L, l) +=: b
else GridProperty(b.head.length * b.head.factor, l) +=: b
}
private val gridRows = gridParams.map { case (label, it) => (label, it.toSeq) }
.foldRight(ArrayBuffer[GridRow]()) { case ((label, it), b) =>
if (b.isEmpty) GridRow(label, it, 1L, it.length.toLong) +=: b
else GridRow(label, it, b.head.length * b.head.factor, it.length.toLong) +=: b
}.toIndexedSeq

val length = gridLengths.product
val length = gridRows.foldLeft(1L)((product, gridRow) => product * gridRow.length)
val size = length

info(s"$size hyper parameter tuples found in GridSpace")

/**
*
* @param idx
* @return
*/
def apply(idx: Long): P = {
if (idx < 0 || idx >= size)
throw new IndexOutOfBoundsException(idx.toString)

val gridIndices = gridProperties.map { case GridProperty(factor, this.length) => (idx / factor) % length }
val hyperParamValues = gridIndices.zipWithIndex.map { case (columnIndex, rowIndex) =>
(gridSpace(rowIndex)._1, gridSpace(rowIndex)._2(columnIndex.toInt))
val gridRowIndices = gridRows.map(gridRow => (idx / gridRow.factor) % gridRow.length)

val hyperParamValues = gridRowIndices.zipWithIndex.map { case (columnIndex, rowIndex) =>
(gridRows(rowIndex).label, gridRows(rowIndex).values(columnIndex.toInt))
}.toMap

factory(hyperParamValues)
}
}

case class GridRow(label: String, values: Seq[_])
case class GridProperty(factor: Long, length: Long)
/**
* This case class contains the label name and the values of any specific row within the grid.
*
* @param label
* @param values
* @param factor
* @param length
*/
case class GridRow(label: String, values: Seq[_], factor: Long, length: Long)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.eharmony.spotz.optimizer.grid

import com.eharmony.spotz.backend.{BackendFunctions, ParallelFunctions, SparkFunctions}
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.optimizer.AbstractOptimizer
import com.eharmony.spotz.optimizer.{AbstractOptimizer, OptimizerState, StopStrategy}
import com.eharmony.spotz.util.{DurationUtils, Logging}
import org.apache.spark.SparkContext
import org.joda.time.{DateTime, Duration}
Expand All @@ -29,7 +29,8 @@ import scala.reflect.ClassTag
* @author vsuthichai
*/
abstract class GridSearch[P, L]
(trialBatchSize: Int)
(trialBatchSize: Int,
stopStrategy: StopStrategy = StopStrategy.stopWhenOptimizerFinishes)
(implicit ord: Ordering[(P, L)], factory: Map[String, _] => P)
extends AbstractOptimizer[P, L, Map[String, Iterable[Any]], GridSearchResult[P, L]]
with BackendFunctions
Expand All @@ -43,31 +44,75 @@ abstract class GridSearch[P, L]
val startTime = DateTime.now()
val firstPoint = space(0)
val firstLoss = objective(firstPoint)
val currentTime = DateTime.now()

val gridSearchContext = GridSearchContext(
bestPointSoFar = firstPoint,
bestLossSoFar = firstLoss,
startTime = startTime,
currentTime = currentTime,
elapsedTime = new Duration(startTime, currentTime),
trialsSoFar = 1L,
optimizerFinished = 1L >= space.size)

// Last three arguments maintain the best point and loss and the trial count
gridSearch(objective, space, reducer, startTime, firstPoint, firstLoss, 1)
gridSearch(objective, space, reducer, gridSearchContext)
}

@tailrec
private def gridSearch(objective: Objective[P, L], space: Grid[P], reducer: Reducer[(P, L)],
startTime: DateTime, bestPointSoFar: P, bestLossSoFar: L, trialsSoFar: Long)
private def gridSearch(objective: Objective[P, L],
space: Grid[P],
reducer: Reducer[(P, L)],
gsc: GridSearchContext[P, L])
(implicit c: ClassTag[P], p: ClassTag[L]): GridSearchResult[P, L] = {
val endTime = DateTime.now()
val elapsedTime = new Duration(startTime, endTime)

info(s"Best point and loss after $trialsSoFar trials and ${DurationUtils.format(elapsedTime)} : $bestPointSoFar loss: $bestLossSoFar")
info(s"Best point and loss after ${gsc.trialsSoFar} trials and ${DurationUtils.format(gsc.elapsedTime)} : ${gsc.bestPointSoFar} loss: ${gsc.bestLossSoFar}")

stopStrategy.shouldStop(gsc) match {

trialsSoFar >= space.size match {
case true =>
GridSearchResult(bestPointSoFar, bestLossSoFar, startTime, endTime, trialsSoFar, elapsedTime)
// Base case, end recursion, return the result
GridSearchResult[P, L](
bestPoint = gsc.bestPointSoFar,
bestLoss = gsc.bestLossSoFar,
startTime = gsc.startTime,
endTime = gsc.currentTime,
elapsedTime = gsc.elapsedTime,
totalTrials = gsc.trialsSoFar)

case false =>
val batchSize = scala.math.min(space.size - trialsSoFar, trialBatchSize)
val (bestPoint, bestLoss) = reducer((bestPointSoFar, bestLossSoFar), bestGridPointAndLoss(trialsSoFar, batchSize, objective, space, reducer))
gridSearch(objective, space, reducer, startTime, bestPoint, bestLoss, trialsSoFar + batchSize)
val currentTime = DateTime.now()
val elapsedTime = new Duration(gsc.startTime, currentTime)


val batchSize = scala.math.min(space.size - gsc.trialsSoFar, trialBatchSize)

val (bestPoint, bestLoss) = reducer((gsc.bestPointSoFar, gsc.bestLossSoFar), bestGridPointAndLoss(gsc.trialsSoFar, batchSize, objective, space, reducer))
val trialsSoFar = gsc.trialsSoFar + batchSize

val gridSearchContext = GridSearchContext(
bestPointSoFar = bestPoint,
bestLossSoFar = bestLoss,
startTime = gsc.startTime,
currentTime = currentTime,
elapsedTime = elapsedTime,
trialsSoFar = trialsSoFar,
optimizerFinished = trialsSoFar >= space.size)

gridSearch(objective, space, reducer, gridSearchContext)
}
}
}

case class GridSearchContext[P, L](
bestPointSoFar: P,
bestLossSoFar: L,
startTime: DateTime,
currentTime: DateTime,
elapsedTime: Duration,
trialsSoFar: Long,
optimizerFinished: Boolean) extends OptimizerState[P, L]

/**
* Grid search backed by parallel collections.
*
Expand All @@ -78,9 +123,10 @@ abstract class GridSearch[P, L]
* @tparam L loss type representation
*/
class ParGridSearch[P, L](
trialBatchSize: Int = 1000000)
trialBatchSize: Int = 1000000,
stopStrategy: StopStrategy = StopStrategy.stopWhenOptimizerFinishes)
(implicit val ord: Ordering[(P, L)], factory: Map[String, _] => P)
extends GridSearch[P, L](trialBatchSize)(ord, factory)
extends GridSearch[P, L](trialBatchSize, stopStrategy)(ord, factory)
with ParallelFunctions

/**
Expand All @@ -95,7 +141,8 @@ class ParGridSearch[P, L](
*/
class SparkGridSearch[P, L](
@transient val sc: SparkContext,
trialBatchSize: Int = 1000000)
trialBatchSize: Int = 1000000,
stopStrategy: StopStrategy = StopStrategy.stopWhenOptimizerFinishes)
(implicit val ord: Ordering[(P, L)], factory: Map[String, _] => P)
extends GridSearch[P, L](trialBatchSize)(ord, factory)
extends GridSearch[P, L](trialBatchSize, stopStrategy)(ord, factory)
with SparkFunctions
Loading

0 comments on commit 49aa133

Please sign in to comment.