Skip to content

Commit

Permalink
minor refactor (show cluster allocations)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulk-asert committed May 28, 2024
1 parent a9255b7 commit 3d4d785
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions subprojects/WhiskeyWayang/src/main/groovy/WhiskeyWayang.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
*/

// inspired by the Apache Wayang Scala example here:
// https://github.com/apache/incubator-wayang/blob/main/README.md#k-means
// https://github.com/apache/incubator-wayang/blob/011859f8bd2b8f05a4ffc2d71c73edf1dd893fac/README.md#k-means
// See also a way to resurrect "lost" (no points selected) centroids:
// https://github.com/apache/incubator-wayang/blob/9071f6c3d47611657407a9287923294cb2b07e8c/wayang-benchmark/src/main/scala/org/apache/wayang/apps/kmeans/Kmeans.scala#L79

import org.apache.wayang.api.JavaPlanBuilder
import org.apache.wayang.core.api.WayangContext
Expand All @@ -27,10 +29,7 @@ import org.apache.wayang.java.Java
import org.apache.wayang.spark.Spark
import static java.lang.Math.sqrt

record Point(double[] pts) implements Serializable {
static Point fromLine(String line) {
new Point(line.split(',')[2..-1] as double[]) }
}
record Point(double[] pts) implements Serializable { }

record PointGrouping(double[] pts, int cluster, long count) implements Serializable {
PointGrouping(List<Double> pts, int cluster, long count) {
Expand All @@ -43,7 +42,7 @@ record PointGrouping(double[] pts, int cluster, long count) implements Serializa
}

PointGrouping average() {
new PointGrouping(pts.collect{ double d -> d/count }, cluster, count)
new PointGrouping(pts.collect{ double d -> d/count }, cluster, 1)
}
}

Expand Down Expand Up @@ -80,7 +79,9 @@ int iterations = 10

// read in data from our file
var url = WhiskeyWayang.classLoader.getResource('whiskey.csv').file
var pointsData = new File(url).readLines()[1..-1].collect{ Point.fromLine(it) }
def rows = new File(url).readLines()[1..-1]*.split(',')
var distilleries = rows*.getAt(1)
var pointsData = rows.collect{ new Point(it[2..-1] as double[]) }
var dims = pointsData[0].pts.size()

// create some random points as initial centroids
Expand All @@ -106,17 +107,30 @@ var finalCentroids = initialCentroids.repeat(iterations, currentCentroids ->
.reduceByKey(cluster, plus).withName('Aggregate points')
.map(average).withName('Average points')
.withOutputClass(PointGrouping)
).withName('Loop')
).withName('Loop').collect()

println 'Centroids:'
finalCentroids.forEach { c ->
var pts = c.pts.collect { sprintf '%.2f', it }.join(', ')
println "Cluster$c.cluster ($c.count points): $pts"
finalCentroids.each { c ->
println "Cluster $c.cluster: ${c.pts.collect { sprintf '%.2f', it }.join(', ')}"
}

println()
var allocator = new SelectNearestCentroid(centroids: finalCentroids)
var allocations = pointsData.withIndex()
.collect{ pt, idx -> [allocator.apply(pt).cluster, distilleries[idx]] }
.groupBy{ cluster, ds -> "Cluster $cluster" }
.collectValues{ v -> v.collect{ it[1] } }
.sort{ e1, e2 -> e1.key <=> e2.key }
allocations.each{ c, ds -> println "$c (${ds.size()} members): ${ds.join(', ')}" }
/*
Centroids:
Cluster0 (20 points): 2.00, 2.50, 1.55, 0.35, 0.20, 1.15, 1.55, 0.95, 0.90, 1.80, 1.35, 1.35
Cluster2 (21 points): 2.81, 2.43, 1.52, 0.05, 0.00, 1.90, 1.67, 2.05, 2.10, 2.10, 2.19, 1.76
Cluster3 (34 points): 1.38, 2.32, 1.09, 0.26, 0.03, 1.15, 1.09, 0.47, 1.38, 1.74, 2.03, 2.24
Cluster4 (11 points): 2.91, 1.55, 2.91, 2.73, 0.45, 0.45, 1.45, 0.55, 1.55, 1.45, 1.18, 0.55
Cluster 0: 2.53, 1.65, 2.76, 2.12, 0.29, 0.65, 1.65, 0.59, 1.35, 1.41, 1.35, 0.94
Cluster 2: 3.33, 2.56, 1.67, 0.11, 0.00, 1.89, 1.89, 2.78, 2.00, 1.89, 2.33, 1.33
Cluster 3: 1.42, 2.47, 1.03, 0.22, 0.06, 1.00, 1.03, 0.47, 1.19, 1.72, 1.92, 2.08
Cluster 4: 2.25, 2.38, 1.38, 0.08, 0.13, 1.79, 1.54, 1.33, 1.75, 2.17, 1.75, 1.79
Cluster 0 (17 members): Ardbeg, Balblair, Bowmore, Bruichladdich, Caol Ila, Clynelish, GlenGarioch, GlenScotia, Highland Park, Isle of Jura, Lagavulin, Laphroig, Oban, OldPulteney, Springbank, Talisker, Teaninich
Cluster 2 (9 members): Aberlour, Balmenach, Dailuaine, Dalmore, Glendronach, Glenfarclas, Macallan, Mortlach, RoyalLochnagar
Cluster 3 (36 members): AnCnoc, ArranIsleOf, Auchentoshan, Aultmore, Benriach, Bladnoch, Bunnahabhain, Cardhu, Craigganmore, Dalwhinnie, Dufftown, GlenElgin, GlenGrant, GlenMoray, GlenSpey, Glenallachie, Glenfiddich, Glengoyne, Glenkinchie, Glenlossie, Glenmorangie, Inchgower, Linkwood, Loch Lomond, Mannochmore, Miltonduff, RoyalBrackla, Speyburn, Speyside, Strathmill, Tamdhu, Tamnavulin, Tobermory, Tomintoul, Tomore, Tullibardine
Cluster 4 (24 members): Aberfeldy, Ardmore, Auchroisk, Belvenie, BenNevis, Benrinnes, Benromach, BlairAthol, Craigallechie, Deanston, Edradour, GlenDeveronMacduff, GlenKeith, GlenOrd, Glendullan, Glenlivet, Glenrothes, Glenturret, Knochando, Longmorn, OldFettercairn, Scapa, Strathisla, Tomatin
*/

0 comments on commit 3d4d785

Please sign in to comment.