Skip to content

Commit

Permalink
Combinations #23 Subsets #25
Browse files Browse the repository at this point in the history
  • Loading branch information
vsuthichai committed Aug 18, 2016
1 parent 24247ec commit 4ea92e7
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,62 +1,62 @@
package com.eharmony.spotz.optimizer.hyperparam

import scala.collection.mutable
import scala.util.Random


/** A `RandomSampler` whose sampled value is a collection of collections of `T`. `RandomSampler` is declared outside this file. */
trait CombinatoricRandomSampler[T] extends RandomSampler[Iterable[Iterable[T]]]
/** A `RandomSampler` whose sampled value is a single collection of `T`. */
trait IterableRandomSampler[T] extends RandomSampler[Iterable[T]]

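/**
* Base class for samplers that draw combinations of `k` elements from `iterable`.
*
* @param iterable source elements to sample from
* @param k number of elements in each combination; asserted to be greater than 0 and at most the number of elements in `iterable`
* @param x number of combinations to draw (read by the `combos` method)
* @param replacement when true, `sample` uses `sampleWithReplacement`; otherwise it uses `sampleNoReplacement`
* @tparam T element type
*/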
abstract class AbstractCombinations[T](
iterable: Iterable[T],
k: Int,
x: Int = 1,
replacement: Boolean = false) extends Serializable {

import org.paukov.combinatorics3.Generator

import scala.collection.JavaConverters._
replacement: Boolean = false)(implicit ord: Ordering[T]) extends Serializable {

private val values = iterable.toSeq
protected val values = iterable.toSeq

assert(k > 0, "k must be greater than 0")
assert(k <= values.length, s"k must be less than or equal to length of the iterable, ${values.length}")

// TODO: This is hideous! Rewrite this to be more memory efficient by unranking combinations. For now, use a Java lib.
val combinations = Generator.combination(iterable.asJavaCollection).simple(k).asScala.toIndexedSeq.map(l => l.asScala.toIndexedSeq)
/**
 * Draws one combination, choosing the strategy from the `replacement` constructor flag.
 *
 * @param rng source of randomness handed to the chosen strategy
 * @return the result of `sampleWithReplacement` when `replacement` is set,
 *         otherwise the result of `sampleNoReplacement`
 */
def sample(rng: Random): Iterable[T] = replacement match {
  case true  => sampleWithReplacement(rng)
  case false => sampleNoReplacement(rng)
}

/**
*
* @param rng
* @return
*/
def combos(rng: Random): Iterable[Iterable[T]] = {
if (replacement) {
Seq.fill(x)(combinations(rng.nextInt(combinations.size)))
} else {
val indices = collection.mutable.Set[Int]()
val numElements = scala.math.min(x, combinations.size)
val ret = new collection.mutable.ArrayBuffer[Iterable[T]](numElements)
while (indices.size < numElements) {
val index = rng.nextInt(combinations.size)
if (!indices.contains(index)) {
indices.add(index)
ret += combinations(index)
}
/**
 * Picks elements of `values` uniformly at random until `k` different ones have been gathered.
 *
 * Picks accumulate in a sorted set, so a repeated pick collapses into the existing entry and
 * the result is returned in `Ordering[T]` order.
 *
 * NOTE(review): the loop ends only once the set holds `k` elements. If `values` contains
 * fewer than `k` elements that are distinct under the ordering, it will not terminate —
 * confirm that callers never pass such input.
 *
 * @param rng source of randomness; exactly one `nextInt` call is made per pick
 * @return the `k` gathered elements in ascending order
 */
def sampleWithReplacement(rng: Random) = {
  val combo = mutable.SortedSet[T]()

  while (combo.size < k) {
    // A single draw per pick: the drawn index is the one actually used for the lookup.
    val index = rng.nextInt(values.length)
    combo.add(values(index))
  }

  combo.toSeq
}

/**
 * Picks elements of `values` at random without ever reusing a position, until `k` different
 * elements have been gathered.
 *
 * Every drawn index is recorded, and the element added is the one at that same index, so a
 * position contributes at most once. The result is accumulated in a sorted set and therefore
 * comes back in `Ordering[T]` order.
 *
 * @param rng source of randomness; one `nextInt` call is made per attempt
 * @return the gathered elements in ascending order. This holds fewer than `k` elements only
 *         when every position has been used and `values` has fewer than `k` elements that
 *         are distinct under the ordering.
 */
def sampleNoReplacement(rng: Random) = {
  val combo = mutable.SortedSet[T]()
  val indices = mutable.Set[Int]()

  // The second condition stops the loop once every position has been consumed; without it
  // the loop could never finish when duplicates leave fewer than k distinct elements.
  while (combo.size < k && indices.size < values.length) {
    val index = rng.nextInt(values.length)
    // `add` reports true only for an index that was not recorded before.
    if (indices.add(index)) {
      combo.add(values(index))
    }
  }

  combo.toSeq
}
}


/**
* Sample a single combination of K unordered items from the iterable of length N.
*
Expand All @@ -66,12 +66,15 @@ abstract class AbstractCombinations[T](
* @tparam T
*/
case class Combination[T](
iterable: Iterable[T],
k: Int,
replacement: Boolean = false)
extends AbstractCombinations[T](iterable, k, 1, replacement) with IterableRandomSampler[T] {
iterable: Iterable[T],
k: Int,
replacement: Boolean = false)(implicit ord: Ordering[T])
extends AbstractCombinations[T](iterable, k, replacement)(ord) with IterableRandomSampler[T] {

assert(k > 0, "k must be greater than 0")
assert(k <= values.length, s"k must be less than or equal to length of the iterable, ${values.length}")

override def apply(rng: Random): Iterable[T] = combos(rng).head
/** Draws a single combination by delegating to the inherited `sample`. */
override def apply(rng: Random): Iterable[T] = sample(rng)
}


Expand All @@ -89,8 +92,22 @@ case class Combinations[T](
iterable: Iterable[T],
k: Int,
x: Int = 1,
replacement: Boolean = false)
extends AbstractCombinations[T](iterable, k, x, replacement) with CombinatoricRandomSampler[T] {
replacement: Boolean = false)(implicit ord: Ordering[T])
extends AbstractCombinations[T](iterable, k, replacement)(ord) with CombinatoricRandomSampler[T] {

override def apply(rng: Random): Iterable[Iterable[T]] = combos(rng)
assert(k > 0, "k must be greater than 0")
assert(k <= values.length, s"k must be less than or equal to length of the iterable, ${values.length}")

/**
 * Draws several combinations in one call.
 *
 * With `replacement` set, `x` combinations are drawn independently, so the same combination
 * may appear more than once. Otherwise combinations are collected into a set until the
 * requested number of different ones has been reached.
 *
 * @param rng source of randomness passed to every `sample` call
 * @return `x` combinations when `replacement` is set; otherwise `min(x, C(n, k))` different
 *         combinations, where `n` is the number of elements in `values`
 */
override def apply(rng: Random): Iterable[Iterable[T]] = {
  if (replacement) {
    Seq.fill(x)(sample(rng))
  } else {
    // Number of k-element combinations of n elements. Each partial product is itself a
    // binomial coefficient, so the division is exact; BigInt avoids overflow.
    val n = values.length
    val distinctCombos = (1 to k).foldLeft(BigInt(1)) { (acc, i) => acc * (n - k + i) / i }
    // Never ask for more different combinations than exist, or the loop below cannot end.
    // NOTE(review): this assumes the elements of `values` are pairwise distinct under the
    // ordering; duplicates reduce the real number of different combinations further.
    val numElements = if (distinctCombos < BigInt(x)) distinctCombos.toInt else x
    val ret = collection.mutable.Set[Iterable[T]]()
    while (ret.size < numElements) {
      ret += sample(rng)
    }
    ret.toIndexedSeq
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ package com.eharmony.spotz.optimizer.hyperparam
import scala.collection.mutable
import scala.util.Random

case class Subset[T](iterable: Iterable[T], k: Int)


case class Subsets[T](iterable: Iterable[T], k: Int, x: Int, replacement: Boolean = false)(implicit ord: Ordering[T]) extends CombinatoricRandomSampler[T] {
private val values = iterable.toIndexedSeq
abstract class AbstractSubset[T](iterable: Iterable[T], k: Int, replacement: Boolean = false)(implicit ord: Ordering[T]) extends Serializable {
protected val values = iterable.toIndexedSeq

def sample(rng: Random): Iterable[T] = {
val sampleSize = rng.nextInt(k) + 1
Expand All @@ -16,16 +13,59 @@ case class Subsets[T](iterable: Iterable[T], k: Int, x: Int, replacement: Boolea

while (subset.size < sampleSize) {
val index = rng.nextInt(values.size)
val element = values(index)

if (replacement) {
subset.add(values(index))
subset.add(element)
} else if (!indices.contains(index)) {
indices.add(index)
subset.add(values(index))
subset.add(element)
}
}

subset.toIndexedSeq
}
}

/**
 * Sampler that yields one subset of `iterable` per invocation.
 *
 * @param iterable elements to draw the subset from
 * @param k upper bound on the subset size; asserted to lie in (0, N], N being the number of elements
 * @param replacement forwarded unchanged to `AbstractSubset`
 * @param ord ordering on `T`, forwarded to `AbstractSubset`
 * @tparam T element type
 */
case class Subset[T](iterable: Iterable[T], k: Int, replacement: Boolean = false)(implicit ord: Ordering[T])
  extends AbstractSubset[T](iterable, k, replacement)(ord)
  with IterableRandomSampler[T] {

  assert(0 < k && values.size >= k, "K must be in the interval (0, N]")

  /** Returns whatever the inherited `sample` produces for `rng`. */
  def apply(rng: Random): Iterable[T] = {
    sample(rng)
  }
}

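/**
* Sampler that produces a collection of subsets of `iterable` per invocation.
*
* @param iterable elements to draw the subsets from
* @param k upper bound on the size of each subset; asserted to lie in (0, N], N being the number of elements
* @param x upper bound on the number of subsets per call (`apply` draws `rng.nextInt(x) + 1`); asserted to be greater than 0
* @param replacement forwarded to `AbstractSubset`, where it decides whether an already drawn index may be used again
* @param ord ordering on `T`, required by `AbstractSubset`
* @tparam T element type
*/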
case class Subsets[T](
iterable: Iterable[T],
k: Int,
x: Int,
replacement: Boolean = false)
(implicit ord: Ordering[T])
extends AbstractSubset[T](iterable, k, replacement)
with CombinatoricRandomSampler[T] {

assert(k > 0 && k <= values.size, "K must be in the interval (0, N]")
assert(x > 0, "X must be greater than 0")

/**
*
* @param rng
* @return
*/
def apply(rng: Random): Iterable[Iterable[T]] = {
val numSubsets = rng.nextInt(x) + 1

Expand Down

0 comments on commit 4ea92e7

Please sign in to comment.