From b117dad053a6b358a399e5504bae7f0c4252f818 Mon Sep 17 00:00:00 2001
From: vsuthichai
Date: Thu, 6 Apr 2017 11:27:31 -0700
Subject: [PATCH] Refactor VW cache distribution #37

---
 core/pom.xml                                  |   4 +
 .../eharmony/spotz/optimizer/Optimizer.scala  |   4 +-
 .../spotz/optimizer/grid/GridSearch.scala     |  38 +++---
 .../spotz/optimizer/random/RandomSearch.scala |  38 +++---
 .../eharmony/spotz/util/FileFunctions.scala   |  39 +++----
 .../com/eharmony/spotz/util/FileUtil.scala    | 105 +++++++++++++++--
 pom.xml                                       |   6 +
 .../vw/VwCrossValidationObjective.scala       | 108 ++++++++++++++----
 .../objective/vw/VwHoldoutObjective.scala     |   8 +-
 .../spotz/objective/vw/VwProcess.scala        |  10 +-
 .../vw/util/VwDatasetFunctions.scala          |  49 +++-----
 11 files changed, 282 insertions(+), 127 deletions(-)

diff --git a/core/pom.xml b/core/pom.xml
index 07c0b34..8c39a67 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -50,6 +50,10 @@
             <groupId>ch.qos.logback</groupId>
             <artifactId>logback-classic</artifactId>
         </dependency>
+        <dependency>
+            <groupId>com.jsuereth</groupId>
+            <artifactId>scala-arm_${scala.major.version}</artifactId>
+        </dependency>
         <dependency>
             <groupId>junit</groupId>
             <artifactId>junit</artifactId>
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/Optimizer.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/Optimizer.scala
index 464b810..d515175 100644
--- a/core/src/main/scala/com/eharmony/spotz/optimizer/Optimizer.scala
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/Optimizer.scala
@@ -103,8 +103,8 @@ trait AbstractOptimizer[P, L, S, R <: OptimizerResult[P, L]] extends Optimizer[P
  * @tparam L
  */
 trait OptimizerState[P, L] {
-  val bestPointSoFar: P
-  val bestLossSoFar: L
+  val bestPointSoFar: Option[P]
+  val bestLossSoFar: Option[L]
   val startTime: DateTime
   val currentTime: DateTime
   val trialsSoFar: Long
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/grid/GridSearch.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/grid/GridSearch.scala
index 4e16930..84f66a5 100644
--- a/core/src/main/scala/com/eharmony/spotz/optimizer/grid/GridSearch.scala
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/grid/GridSearch.scala
@@ -47,17 +47,18 @@ abstract class GridSearch[P, L](
     (implicit c: ClassTag[P], p: ClassTag[L]): GridSearchResult[P, L] = {
     val space = new Grid[P](paramSpace)
     val startTime = DateTime.now()
+    /*
     val firstPoint = space(0)
     val firstLoss = objective(firstPoint)
-    val currentTime = DateTime.now()
+    */

-    val gridSearchContext = GridSearchContext(
-      bestPointSoFar = firstPoint,
-      bestLossSoFar = firstLoss,
+    val gridSearchContext = GridSearchContext[P, L](
+      bestPointSoFar = None,
+      bestLossSoFar = None,
       startTime = startTime,
-      currentTime = currentTime,
-      trialsSoFar = 1L,
-      optimizerFinished = 1L >= space.size)
+      currentTime = startTime,
+      trialsSoFar = 0L,
+      optimizerFinished = 0L >= space.size)
     // Last three arguments maintain the best point and loss and the trial count
     gridSearch(objective, space, reducer, gridSearchContext)
   }
@@ -72,11 +73,11 @@
     info(s"Best point and loss after ${gsc.trialsSoFar} trials and ${DurationUtils.format(gsc.elapsedTime)} : ${gsc.bestPointSoFar} loss: ${gsc.bestLossSoFar}")

-    if (stopStrategy.shouldStop(gsc)) {
+    if (stopStrategy.shouldStop(gsc) && gsc.bestLossSoFar.isDefined && gsc.bestPointSoFar.isDefined) {
       // Base case, end recursion, return the result
       GridSearchResult[P, L](
-        bestPoint = gsc.bestPointSoFar,
-        bestLoss = gsc.bestLossSoFar,
+        bestPoint = gsc.bestPointSoFar.get,
+        bestLoss = gsc.bestLossSoFar.get,
         startTime = gsc.startTime,
         endTime = gsc.currentTime,
         elapsedTime = gsc.elapsedTime,
@@ -87,12 +88,19 @@ abstract class GridSearch[P, L](
       val batchSize = scala.math.min(space.size - gsc.trialsSoFar, trialBatchSize)

-      val (bestPoint, bestLoss) = reducer((gsc.bestPointSoFar, gsc.bestLossSoFar), bestGridPointAndLoss(gsc.trialsSoFar, batchSize, objective, space, reducer))
+      val (point, loss) = bestGridPointAndLoss(gsc.trialsSoFar, batchSize, objective, space, reducer)
+
+      val (bestPoint, bestLoss) =
+        if (gsc.bestPointSoFar.isDefined && gsc.bestLossSoFar.isDefined)
+          reducer((gsc.bestPointSoFar.get, gsc.bestLossSoFar.get), (point, loss))
+        else
+          (point, loss)
+
       val trialsSoFar = gsc.trialsSoFar + batchSize

       val gridSearchContext = GridSearchContext(
-        bestPointSoFar = bestPoint,
-        bestLossSoFar = bestLoss,
+        bestPointSoFar = Option(bestPoint),
+        bestLossSoFar = Option(bestLoss),
         startTime = gsc.startTime,
         currentTime = currentTime,
         trialsSoFar = trialsSoFar,
@@ -104,8 +112,8 @@ abstract class GridSearch[P, L](
 }

 case class GridSearchContext[P, L](
-  bestPointSoFar: P,
-  bestLossSoFar: L,
+  bestPointSoFar: Option[P],
+  bestLossSoFar: Option[L],
   startTime: DateTime,
   currentTime: DateTime,
   trialsSoFar: Long,
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala
index 6fca4a0..b4e09ec 100644
--- a/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala
@@ -6,7 +6,7 @@ import com.eharmony.spotz.optimizer._
 import com.eharmony.spotz.optimizer.hyperparam.RandomSampler
 import com.eharmony.spotz.util.{DurationUtils, Logging}
 import org.apache.spark.SparkContext
-import org.joda.time.{DateTime, Duration}
+import org.joda.time.DateTime

 import scala.annotation.tailrec
 import scala.math.Ordering
@@ -61,16 +61,13 @@ abstract class RandomSearch[P, L](
                       reducer: Reducer[(P, L)])
                      (implicit c: ClassTag[P], p: ClassTag[L]): RandomSearchResult[P, L] = {
     val startTime = DateTime.now()
-    val firstPoint = sample(paramSpace, seed)
-    val firstLoss = objective(firstPoint)
-    val currentTime = DateTime.now()

-    val randomSearchContext = RandomSearchContext(
-      bestPointSoFar = firstPoint,
-      bestLossSoFar = firstLoss,
+    val randomSearchContext = RandomSearchContext[P, L](
+      bestPointSoFar = None,
+      bestLossSoFar = None,
       startTime = startTime,
-      currentTime = currentTime,
-      trialsSoFar = 1,
+      currentTime = startTime,
+      trialsSoFar = 0,
       optimizerFinished = false)

     randomSearch(objective, reducer, paramSpace, randomSearchContext)
@@ -97,11 +94,11 @@
     info(s"Best point and loss after ${rsc.trialsSoFar} trials and ${DurationUtils.format(rsc.elapsedTime)} : ${rsc.bestPointSoFar} loss: ${rsc.bestLossSoFar}")

-    if (stopStrategy.shouldStop(rsc)) {
+    if (stopStrategy.shouldStop(rsc) && rsc.bestPointSoFar.isDefined && rsc.bestLossSoFar.isDefined) {
       // Base case, end recursion, return the result
       new RandomSearchResult[P, L](
-        bestPoint = rsc.bestPointSoFar,
-        bestLoss = rsc.bestLossSoFar,
+        bestPoint = rsc.bestPointSoFar.get,
+        bestLoss = rsc.bestLossSoFar.get,
         startTime = rsc.startTime,
         endTime = rsc.currentTime,
         elapsedTime = rsc.elapsedTime,
@@ -111,14 +108,19 @@
       // TODO: Adaptive batch sizing
       //val batchSize = nextBatchSize(None, elapsedTime, currentBatchSize, trialsSoFar, null, stopStrategy.getMaxTrials)

-      val (bestPoint, bestLoss) = reducer((rsc.bestPointSoFar, rsc.bestLossSoFar),
-        bestRandomPointAndLoss(rsc.trialsSoFar, batchSize, objective, reducer, paramSpace, sample, seed))
+      val (point, loss) = bestRandomPointAndLoss(rsc.trialsSoFar, batchSize, objective, reducer, paramSpace, sample, seed)
+
+      val (bestPoint, bestLoss) =
+        if (rsc.bestPointSoFar.isDefined && rsc.bestLossSoFar.isDefined)
+          reducer((rsc.bestPointSoFar.get, rsc.bestLossSoFar.get), (point, loss))
+        else
+          (point, loss)

       val currentTime = DateTime.now()

       val randomSearchContext = RandomSearchContext(
-        bestPointSoFar = bestPoint,
-        bestLossSoFar = bestLoss,
+        bestPointSoFar = Option(bestPoint),
+        bestLossSoFar = Option(bestLoss),
         startTime = rsc.startTime,
         currentTime = currentTime,
         trialsSoFar = rsc.trialsSoFar + batchSize,
@@ -145,8 +147,8 @@
  * @tparam L loss type
  */
 case class RandomSearchContext[P, L](
-  bestPointSoFar: P,
-  bestLossSoFar: L,
+  bestPointSoFar: Option[P],
+  bestLossSoFar: Option[L],
   startTime: DateTime,
   currentTime: DateTime,
   trialsSoFar: Long,
diff --git a/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala b/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala
index 719dc01..79a6280 100644
--- a/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala
+++ b/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala
@@ -1,6 +1,6 @@
 package com.eharmony.spotz.util

-import java.io.{File, PrintWriter}
+import java.io.File

 import org.apache.spark.{SparkContext, SparkFiles}

@@ -13,22 +13,7 @@ import org.apache.spark.{SparkContext, SparkFiles}
  * Later when the objection function is being parallelized, the file can be retrieved with the
  * get inside the apply method.
  */
-trait FileFunctions {
-  def save(inputPath: String): String = save(FileUtil.loadFile(inputPath))
-
-  def save(inputIterable: Iterable[String]): String = save(inputIterable.toIterator)
-
-  def save(inputIterator: Iterator[String]): String = {
-    val tempFile = FileUtil.tempFile("file.temp")
-    val printWriter = new PrintWriter(tempFile)
-    inputIterator.foreach(line => printWriter.println(line))
-    printWriter.close()
-    save(tempFile)
-  }
-
-  def save(file: File): String
-  def get(name: String): File
-}
+// trait FileSystemFunctions extends LocalFileSystemFunctions with SparkFileFunctions

 /**
  * This trait is intended for handling files when parallel collections are used to do the computation.
@@ -36,15 +21,19 @@
  * save on a file will return a key that can later be used to retrieve that same file
  * later inside the apply method of the objective function.
  */
-trait FileSystemFunctions extends FileFunctions {
+trait LocalFileSystemFunctions {
   private lazy val nameToAbsPath = scala.collection.mutable.Map[String, String]()

-  override def save(file: File): String = {
+  def saveLocally(inputPath: String): String = saveLocally(new File(inputPath))
+  def saveLocally(inputIterable: Iterable[String]): String = saveLocally(inputIterable.toIterator)
+  def saveLocally(inputIterator: Iterator[String]): String = saveLocally(FileUtil.tempFile(inputIterator))
+
+  def saveLocally(file: File): String = {
     nameToAbsPath += ((file.getName, file.getAbsolutePath))
     file.getName
   }

-  override def get(name: String): File = new File(nameToAbsPath(name))
+  def getLocally(name: String): File = new File(nameToAbsPath(name))
 }

 /**
@@ -64,13 +53,17 @@
  * accessed from the apply method of the objective function as it's executing on the worker
  * through this same trait's get method.
  */
-trait SparkFileFunctions extends FileFunctions {
+trait SparkFileFunctions {
   val sc: SparkContext

-  override def save(file: File): String = {
+  def saveToSparkFiles(inputPath: String): String = saveToSparkFiles(new File(inputPath))
+  def saveToSparkFiles(inputIterable: Iterable[String]): String = saveToSparkFiles(inputIterable.toIterator)
+  def saveToSparkFiles(inputIterator: Iterator[String]): String = saveToSparkFiles(FileUtil.tempFile(inputIterator))
+
+  def saveToSparkFiles(file: File): String = {
     sc.addFile(file.getAbsolutePath)
     file.getName
   }

-  override def get(name: String): File = new File(SparkFiles.get(name))
+  def getFromSparkFiles(name: String): File = new File(SparkFiles.get(name))
 }
diff --git a/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala b/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala
index 8d7145d..2205ab2 100644
--- a/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala
+++ b/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala
@@ -1,11 +1,20 @@
 package com.eharmony.spotz.util

-import java.io.{File, InputStream}
+import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}

 import org.apache.commons.io.FilenameUtils
+import org.apache.commons.lang3.StringUtils
 import org.apache.commons.vfs2.{FileNotFoundException, FileSystemManager, VFS}
+import org.apache.spark.rdd.RDD
+import org.joda.time.DateTime
+import org.joda.time.format.DateTimeFormat
+import resource.managed

+import scala.io.Codec
 import scala.io.Source
+import scala.util.{Success, Try}
+import sys.process._

 /**
  * @author vsuthichai
@@ -14,6 +23,66 @@ object FileUtil {
   private val vfs2: FileSystemManager = VFS.getManager
   private val pwd = new File(System.getProperty("user.dir"))

+  /**
+    *
+    * @param absolutePath
+    * @return
+    */
+  def gzip(absolutePath: String): String = {
+    val is = loadFileInputStream(absolutePath)
+    val fos = new FileOutputStream(absolutePath + ".gz")
+    val gzipfos = new GZIPOutputStream(fos)
+
+    val buffer = new Array[Byte](4096)
+    var bytes_read = is.read(buffer)
+    while (bytes_read > 0) {
+      gzipfos.write(buffer, 0, bytes_read)
+      bytes_read = is.read(buffer)
+    }
+
+    is.close()
+    gzipfos.finish()
+    gzipfos.close()
+
+    absolutePath + ".gz"
+  }
+
+  /**
+    *
+    * @param absolutePath
+    * @return
+    */
+  def gunzip(absolutePath: String): String = {
+    val gzis = new GZIPInputStream(new FileInputStream(absolutePath))
+
+    val outputFilename = if (absolutePath.endsWith(".gz")) {
+      val filename = absolutePath.substring(0, absolutePath.length - 3)
+      val ext = FilenameUtils.getExtension(filename)
+      if (StringUtils.isEmpty(ext))
+        tempFile(filename, "txt", true).getAbsolutePath
+      else {
+        val filenameWithoutExt = filename.substring(0, filename.indexOf('.'))
+        tempFile(filenameWithoutExt, ext, true).getAbsolutePath
+      }
+    } else {
+      tempFile(absolutePath, "txt", true).getAbsolutePath
+    }
+
+    val out = new FileOutputStream(outputFilename)
+
+    val buffer = new Array[Byte](4096)
+    var bytes_read = gzis.read(buffer)
+    while (bytes_read > 0) {
+      out.write(buffer, 0, bytes_read)
+      bytes_read = gzis.read(buffer)
+    }
+
+    gzis.close()
+    out.close()
+
+    outputFilename
+  }
+
   /**
    * Return a file with a filename guaranteed not to be used on the file system. This is
    * mainly used for files with a lifetime of a jvm run.
@@ -42,18 +111,23 @@ object FileUtil {
     tempFile(FilenameUtils.getBaseName(filename), FilenameUtils.getExtension(filename), deleteOnExit)
   }

+  def tempFile(inputIterator: Iterator[String]): File = {
+    val tempFile = FileUtil.tempFile("file.temp")
+    val printWriter = new PrintWriter(tempFile)
+    inputIterator.foreach(line => printWriter.println(line))
+    printWriter.close()
+    tempFile
+  }
+
   /**
    * Load the lines of a file as an iterator.
    *
    * @param path input path
    * @return lines of the file as an Iterator[String]
    */
-  def loadFile(path: String): Iterator[String] = {
+  def fileLinesIterator(path: String): Iterator[String] = {
     val is = loadFileInputStream(path)
-    // Force reading the entire file instead of reading it lazily
-    val lines = Source.fromInputStream(is).getLines().toSeq.toIterator
-    is.close()
-    lines
+    Source.fromInputStream(is)(Codec("UTF-8")).getLines()
   }

   /**
@@ -80,7 +154,7 @@ object SparkFileUtil {
    */
   def loadFile(sc: SparkContext, path: String): Iterator[String] = {
     try {
-      FileUtil.loadFile(path)
+      FileUtil.fileLinesIterator(path)
     } catch {
       case e: FileNotFoundException =>
         try {
@@ -90,4 +164,21 @@
         }
     }
   }
+
+  /**
+    * TODO Fix this. It's too inefficient to save an RDD to hdfs and merge it to a local file.
+    * @param sc
+    * @param rdd
+    * @return
+    */
+  def saveToLocalFile(sc: SparkContext, rdd: RDD[String]): String = {
+    val dtf = DateTimeFormat.forPattern("yyyyMMdd-HHmmss")
+    val dt = DateTime.now()
+    val hdfsPath = s"hdfs:///tmp/spotz-vw-dataset-${dtf.print(dt)}"
+    val tempFile = FileUtil.tempFile(s"spotz-vw-dataset-${dtf.print(dt)}", "txt", deleteOnExit = true)
+    rdd.saveAsTextFile(hdfsPath)
+    s"hdfs dfs -getmerge $hdfsPath ${tempFile.getAbsolutePath}".!
+    s"hdfs dfs -rm -r $hdfsPath".!
+    tempFile.getAbsolutePath
+  }
 }
diff --git a/pom.xml b/pom.xml
index 6894d7f..02a1c11 100644
--- a/pom.xml
+++ b/pom.xml
@@ -90,6 +90,7 @@
         2.1
         1.7.21
         1.1.7
+        <scala.arm.version>1.4</scala.arm.version>
         4.11
         1.3
         3.0.0
@@ -173,6 +174,11 @@
             <artifactId>combinatoricslib3</artifactId>
             <version>${combinatoricslib3.version}</version>
         </dependency>
+        <dependency>
+            <groupId>com.jsuereth</groupId>
+            <artifactId>scala-arm_${scala.major.version}</artifactId>
+            <version>${scala.arm.version}</version>
+        </dependency>
         <dependency>
             <groupId>org.slf4j</groupId>
             <artifactId>slf4j-api</artifactId>
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala
index 7b9222d..4bcbf86 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala
@@ -1,22 +1,25 @@
 package com.eharmony.spotz.objective.vw

+import java.io.File
+
 import com.eharmony.spotz.Preamble.Point
 import com.eharmony.spotz.objective.Objective
-import com.eharmony.spotz.objective.vw.util.{FSVwDatasetFunctions, SparkVwDatasetFunctions, VwCrossValidation}
-import com.eharmony.spotz.util.{FileUtil, Logging, SparkFileUtil}
+import com.eharmony.spotz.objective.vw.util.VwCrossValidation
+import com.eharmony.spotz.util._
 import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD

 /**
  * Perform K Fold cross validation given a dataset formatted for Vowpal Wabbit.
  *
  * @param numFolds
- * @param vwDataset
+ * @param vwDatasetPath
  * @param vwTrainParamsString
  * @param vwTestParamsString
  */
 abstract class AbstractVwCrossValidationObjective(
     val numFolds: Int,
-    @transient val vwDataset: Iterator[String],
+    val vwDatasetPath: String,
     vwTrainParamsString: Option[String],
     vwTestParamsString: Option[String])
   extends Objective[Point, Double]
@@ -24,10 +27,47 @@
   with VwCrossValidation
   with Logging {

+  val localMode: Boolean
+
   val vwTrainParamsMap = parseVwArgs(vwTrainParamsString)
   val vwTestParamsMap = parseVwArgs(vwTestParamsString)

-  val foldToVwCacheFiles = kFold(vwDataset, numFolds, vwTrainParamsMap)
+  // Gzip VW dataset
+  info(s"Gzip $vwDatasetPath")
+  val gzipVwDatasetFilename = FileUtil.gzip(vwDatasetPath)
+
+  // Save VW dataset to executor or locally
+  info(s"Saving gzipped VW dataset to executors $gzipVwDatasetFilename")
+  val gzippedVwDatasetFilenameOnExecutor = save(gzipVwDatasetFilename)
+
+  // Lazily initialize K-fold if utilizing cluster
+  lazy val lazyFoldToVwCacheFiles = Option(initKFold())
+
+  // Initialize K-fold non-lazily in local mode
+  val nonLazyFoldToVwCacheFiles = {
+    if (localMode) {
+      info("Operating in local mode")
+      val kFoldMap = initKFold()
+      Option(kFoldMap)
+    } else {
+      info("Operating in non-local mode")
+      None
+    }
+  }
+
+  def initKFold(): Map[Int, (String, String)] = {
+    info(s"Retrieving gzipped VW dataset on executor $gzippedVwDatasetFilenameOnExecutor")
+    val file = get(gzippedVwDatasetFilenameOnExecutor)
+
+    val unzippedFilename = FileUtil.gunzip(file.getAbsolutePath)
+    info(s"Unzipped ${file.getAbsolutePath} to $unzippedFilename")
+
+    info(s"Creating K Fold cache files from $unzippedFilename")
+    kFold(unzippedFilename, numFolds, vwTrainParamsMap)
+  }
+
+  def save(filename: String): String
+  def get(filename: String): File

   /**
    * This method can run on the driver and/or the executor. It performs a k-fold cross validation
@@ -44,9 +84,17 @@
     info(s"Vw Training Params: $vwTrainParams")
     info(s"Vw Testing Params: $vwTestParams")

+    val foldToVwCacheFiles = if (localMode) {
+      nonLazyFoldToVwCacheFiles
+    } else {
+      lazyFoldToVwCacheFiles
+    }
+
+    assert(foldToVwCacheFiles.isDefined, "Unable to initialize K Fold cross validation")
+
     val avgLosses = (0 until numFolds).map { fold =>
       // Retrieve the training and test set cache for this fold.
-      val (vwTrainFilename, vwTestFilename) = foldToVwCacheFiles(fold)
+      val (vwTrainFilename, vwTestFilename) = foldToVwCacheFiles.get(fold)
       val vwTrainFile = getCache(vwTrainFilename)
       val vwTestFile = getCache(vwTestFilename)

@@ -85,48 +133,62 @@ abstract class AbstractVwCrossValidationObjective(
 class SparkVwCrossValidationObjective(
     @transient val sc: SparkContext,
     numFolds: Int,
-    vwDataset: Iterator[String],
+    vwDatasetPath: String,
     vwTrainParamsString: Option[String],
     vwTestParamsString: Option[String])
-  extends AbstractVwCrossValidationObjective(numFolds, vwDataset, vwTrainParamsString, vwTestParamsString)
-  with SparkVwDatasetFunctions {
+  extends AbstractVwCrossValidationObjective(numFolds, vwDatasetPath, vwTrainParamsString, vwTestParamsString)
+  with SparkFileFunctions {
+
+  override lazy val localMode = sc.isLocal

   def this(sc: SparkContext,
            numFolds: Int,
-           vwDataset: Iterable[String],
+           vwDatasetIterator: Iterator[String],
            vwTrainParamsString: Option[String],
            vwTestParamsString: Option[String]) = {
-    this(sc, numFolds, vwDataset.toIterator, vwTrainParamsString, vwTestParamsString)
+    this(sc, numFolds, FileUtil.tempFile(vwDatasetIterator).getAbsolutePath, vwTrainParamsString, vwTestParamsString)
   }
+
   def this(sc: SparkContext,
            numFolds: Int,
-           vwDatasetPath: String,
+           @transient vwDataset: RDD[String],
            vwTrainParamsString: Option[String],
            vwTestParamsString: Option[String]) = {
-    this(sc, numFolds, SparkFileUtil.loadFile(sc, vwDatasetPath), vwTrainParamsString, vwTestParamsString)
+    this(sc, numFolds, SparkFileUtil.saveToLocalFile(sc, vwDataset), vwTrainParamsString, vwTestParamsString)
+  }
+
+  override def save(filename: String): String = {
+    saveToSparkFiles(filename)
+  }
+
+  override def get(filename: String): File = {
+    getFromSparkFiles(filename)
   }
 }

 class VwCrossValidationObjective(
     numFolds: Int,
-    vwDataset: Iterator[String],
+    vwDatasetPath: String,
     vwTrainParamsString: Option[String],
     vwTestParamsString: Option[String])
-  extends AbstractVwCrossValidationObjective(numFolds, vwDataset, vwTrainParamsString, vwTestParamsString)
-  with FSVwDatasetFunctions {
+  extends AbstractVwCrossValidationObjective(numFolds, vwDatasetPath, vwTrainParamsString, vwTestParamsString)
+  with LocalFileSystemFunctions {
+
+  override lazy val localMode = true

   def this(numFolds: Int,
-           vwDataset: Iterable[String],
+           vwDatasetIterator: Iterator[String],
            vwTrainParamsString: Option[String],
            vwTestParamsString: Option[String]) = {
-    this(numFolds, vwDataset.toIterator, vwTrainParamsString, vwTestParamsString)
+    this(numFolds, FileUtil.tempFile(vwDatasetIterator).getAbsolutePath, vwTrainParamsString, vwTestParamsString)
   }

-  def this(numFolds: Int,
-           vwDatasetPath: String,
-           vwTrainParamsString: Option[String],
-           vwTestParamsString: Option[String]) = {
-    this(numFolds, FileUtil.loadFile(vwDatasetPath), vwTrainParamsString, vwTestParamsString)
+  override def save(filename: String): String = {
+    saveLocally(filename)
+  }
+
+  override def get(filename: String): File = {
+    getLocally(filename)
   }
 }
\ No newline at end of file
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala
index 348c22e..ff2a7f5 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala
@@ -2,7 +2,7 @@ package com.eharmony.spotz.objective.vw

 import com.eharmony.spotz.Preamble.Point
 import com.eharmony.spotz.objective.Objective
-import com.eharmony.spotz.objective.vw.util.{FSVwDatasetFunctions, SparkVwDatasetFunctions, VwDatasetFunctions}
+import com.eharmony.spotz.objective.vw.util.VwDatasetFunctions
 import com.eharmony.spotz.util.{FileUtil, Logging, SparkFileUtil}
 import org.apache.spark.SparkContext

@@ -61,7 +61,6 @@ class SparkVwHoldoutObjective(
     vwTestSetIterator: Iterator[String],
     vwTestParamsString: Option[String])
   extends AbstractVwHoldoutObjective(vwTrainSetIterator, vwTrainParamsString, vwTestSetIterator, vwTestParamsString)
-  with SparkVwDatasetFunctions
   with Logging {

   def this(sc: SparkContext,
@@ -86,8 +85,7 @@ class VwHoldoutObjective(
     vwTrainParamsString: Option[String],
     vwTestSetIterator: Iterator[String],
     vwTestParamsString: Option[String])
-  extends AbstractVwHoldoutObjective(vwTrainSetIterator, vwTrainParamsString, vwTestSetIterator, vwTestParamsString)
-  with FSVwDatasetFunctions {
+  extends AbstractVwHoldoutObjective(vwTrainSetIterator, vwTrainParamsString, vwTestSetIterator, vwTestParamsString) {

   def this(vwTrainSetIterable: Iterable[String],
            vwTrainParamsString: Option[String],
@@ -100,7 +98,7 @@ class VwHoldoutObjective(
            vwTrainParamsString: Option[String],
            vwTestSetPath: String,
            vwTestParamsString: Option[String]) = {
-    this(FileUtil.loadFile(vwTrainSetPath), vwTrainParamsString, FileUtil.loadFile(vwTestSetPath), vwTestParamsString)
+    this(FileUtil.fileLinesIterator(vwTrainSetPath), vwTrainParamsString, FileUtil.fileLinesIterator(vwTestSetPath), vwTestParamsString)
   }
 }
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwProcess.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwProcess.scala
index 5ac5c8d..7895d1a 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwProcess.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwProcess.scala
@@ -56,15 +56,17 @@ case class VwResult(
 object VwProcess {
   val avgLossRegex = s"average\\s+loss\\s+=\\s+($floatingPointRegex)".r

-  def generateCache(inputStream: InputStream, cachePath: String, cacheParams: String) {
+  def generateCache(inputStream: InputStream, cachePath: String, cacheParams: String): VwResult = {
     val vwCacheProcess = VwProcess(s"-k --cache_file $cachePath $cacheParams", Option(inputStream))
     val vwCacheResult = vwCacheProcess()

     assert(vwCacheResult.exitCode == 0, s"VW Training cache exited with non-zero exit code ${vwCacheResult.exitCode}")
+
+    vwCacheResult
   }

-  def generateCache(vwDatasetIterator: Iterator[String], cachePath: String, cacheParams: String) {
+  def generateCache(vwDatasetIterator: Iterator[String], cachePath: String, cacheParams: String): VwResult = {
     val pos = new PipedOutputStream
     val pis = new PipedInputStream(pos)
     val pw = new PrintWriter(pos, true)
@@ -77,12 +79,14 @@ object VwProcess {
     generateCache(pis, cachePath, cacheParams)
   }

-  def generateCache(vwDatasetPath: String, cachePath: String, cacheParams: String) {
+  def generateCache(vwDatasetPath: String, cachePath: String, cacheParams: String): VwResult = {
     val vwCacheProcess = VwProcess(s"-k --cache_file $cachePath -d $vwDatasetPath $cacheParams", None)
     val vwCacheResult = vwCacheProcess()

     assert(vwCacheResult.exitCode == 0, s"VW Training cache exited with non-zero exit code ${vwCacheResult.exitCode}")
+
+    vwCacheResult
   }
 }
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala
index df7000a..3e2b30c 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala
@@ -3,7 +3,7 @@ package com.eharmony.spotz.objective.vw.util
 import java.io._

 import com.eharmony.spotz.objective.vw.VwProcess
-import com.eharmony.spotz.util.{FileFunctions, FileSystemFunctions, FileUtil, SparkFileFunctions}
+import com.eharmony.spotz.util.{FileUtil, LocalFileSystemFunctions, Logging}

 import scala.collection.mutable
@@ -13,25 +13,29 @@ import scala.collection.mutable
  * If Spark is not being used, ie. parallel collections are being used, then the cache file
  * is just saved locally to the file system.
  */
-trait VwDatasetFunctions extends FileFunctions {
+trait VwDatasetFunctions extends LocalFileSystemFunctions with Logging {
   def saveAsCache(vwDatasetInputStream: InputStream, vwCacheFilename: String, vwParamsMap: Map[String, _]): String = {
     val vwCacheFile = FileUtil.tempFile(vwCacheFilename)
-    VwProcess.generateCache(vwDatasetInputStream, vwCacheFile.getAbsolutePath, cacheParams(vwParamsMap))
-    save(vwCacheFile)
+    val vwResult = VwProcess.generateCache(vwDatasetInputStream, vwCacheFile.getAbsolutePath, cacheParams(vwParamsMap))
+    info(s"VW cache generation stderr ${vwResult.stderr}")
+    saveLocally(vwCacheFile)
     vwCacheFile.getName
   }

   def saveAsCache(vwDatasetIterator: Iterator[String], vwCacheFilename: String, vwParamsMap: Map[String, _]): String = {
-    val vwCacheFile = FileUtil.tempFile(vwCacheFilename, false)
-    VwProcess.generateCache(vwDatasetIterator, vwCacheFile.getAbsolutePath, cacheParams(vwParamsMap))
-    save(vwCacheFile)
+    val vwCacheFile = FileUtil.tempFile(vwCacheFilename)
+    val vwResult = VwProcess.generateCache(vwDatasetIterator, vwCacheFile.getAbsolutePath, cacheParams(vwParamsMap))
+    info(s"VW cache generation stderr ${vwResult.stderr}")
+
+    saveLocally(vwCacheFile)
     vwCacheFile.getName
   }

   def saveAsCache(vwDatasetPath: String, vwCacheFilename: String, vwParamsMap: Map[String, _]): String = {
     val vwCacheFile = FileUtil.tempFile(vwCacheFilename)
-    VwProcess.generateCache(vwDatasetPath, vwCacheFile.getAbsolutePath, cacheParams(vwParamsMap))
-    save(vwCacheFile)
+    val vwResult = VwProcess.generateCache(vwDatasetPath, vwCacheFile.getAbsolutePath, cacheParams(vwParamsMap))
+    info(s"VW cache generation stderr ${vwResult.stderr}")
+    saveLocally(vwCacheFile)
     vwCacheFile.getName
   }

@@ -47,21 +51,7 @@ trait VwDatasetFunctions extends FileFunctions {
     }
   }

-  def getCache(name: String): File = get(name)
-}
-
-/**
- * Save VW Cache to file system.
- */
-trait FSVwDatasetFunctions extends VwDatasetFunctions with FileSystemFunctions {
-  override def getCache(name: String) = get(name)
-}
-
-/**
- * Add VW Cache to SparkContext.
- */
-trait SparkVwDatasetFunctions extends VwDatasetFunctions with SparkFileFunctions {
-  override def getCache(name: String) = get(name)
+  def getCache(name: String): File = getLocally(name)
 }

 /**
  *
@@ -69,12 +59,9 @@
  */
 trait VwCrossValidation extends VwDatasetFunctions {
   def kFold(inputPath: String, folds: Int, vwParamsMap: Map[String, _]): Map[Int, (String, String)] = {
-    val enumeratedVwInput = FileUtil.loadFile(inputPath)
-    kFold(enumeratedVwInput, folds, vwParamsMap)
-  }
-
-  def kFold(vwDataset: Iterable[String], folds: Int, vwParamsMap: Map[String, _]): Map[Int, (String, String)] = {
-    kFold(vwDataset.toIterator, folds, vwParamsMap)
+    val enumeratedVwInput = FileUtil.fileLinesIterator(inputPath)
+    println(s"kFold input path $inputPath")
+    kFold(enumeratedVwInput.toIterable, folds, vwParamsMap)
   }

   /**
@@ -99,7 +86,7 @@ trait VwCrossValidation extends VwDatasetFunctions {
    * (trainingSetFilename, testSetFilename)
    */
   //def kFold(vwDataset: Iterator[String], folds: Int, cacheBitSize: Int, cb: Option[Int]): Map[Int, (String, String)] = {
-  def kFold(vwDataset: Iterator[String], folds: Int, vwParamsMap: Map[String, _]): Map[Int, (String, String)] = {
+  def kFold(vwDataset: Iterable[String], folds: Int, vwParamsMap: Map[String, _]): Map[Int, (String, String)] = {
     val enumeratedVwDataset = vwDataset.zipWithIndex.toList

     // For every fold iteration, partition the vw input such that one fold is the test set and the