package dataframe

import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.sql.types._

//
// This example demonstrates how to define a basic user defined type (UDT) and
// how to use it in a query. The attributes of the underlying class are not
// directly accessible in the query, but you can access them by defining
// a user defined function (UDF) to be applied to instances of the UDT.
//
// NOTE: there is a more comprehensive example, involving layering one UDT
// on top of another, in sql/UDT.scala.
//
//
// Underlying case class defining 3D points. The annotation connects it with
// the UDT definition below.
//
/** *** SPECIAL NOTE ***
 * This feature has been removed in Spark 2.0.0 -- please see
 * https://issues.apache.org/jira/browse/SPARK-14155
 * The rest of the example is kept below, commented out, for reference.
 * A Spark 2.x alternative that avoids the UDT API is sketched at the
 * end of this file.
@SQLUserDefinedType(udt = classOf[MyPoint3DUDT])
private case class MyPoint3D(x: Double, y: Double, z: Double)
//
// The UDT definition for 3D points: basically how to serialize and deserialize.
//
private class MyPoint3DUDT extends UserDefinedType[MyPoint3D] {

  // A point is stored as a fixed-length array of three doubles.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Convert a MyPoint3D into its catalyst array representation.
  override def serialize(obj: Any): ArrayData = {
    obj match {
      case features: MyPoint3D =>
        new GenericArrayData(Array(features.x, features.y, features.z))
    }
  }

  // Convert the catalyst array representation back into a MyPoint3D.
  override def deserialize(datum: Any): MyPoint3D = {
    datum match {
      case data: ArrayData if data.numElements() == 3 => {
        val arr = data.toDoubleArray()
        new MyPoint3D(arr(0), arr(1), arr(2))
      }
    }
  }

  override def userClass: Class[MyPoint3D] = classOf[MyPoint3D]

  override def asNullable: MyPoint3DUDT = this
}
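
//
// A quick round-trip through the UDT -- a hypothetical usage sketch, not
// part of the original example:
//
//   val udt = new MyPoint3DUDT
//   val roundTripped = udt.deserialize(udt.serialize(MyPoint3D(1.0, 2.0, 3.0)))
//   // roundTripped == MyPoint3D(1.0, 2.0, 3.0)
//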
object UDT {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("DataFrame-UDT").setMaster("local[4]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    //
    // First define some points, store them in a table, and filter them
    // based on magnitude -- i.e., distance from the origin.
    //
    val p1 = new MyPoint3D(1.0, 2.0, 3.0)
    val p2 = new MyPoint3D(1.0, 0.0, 2.0)
    val p3 = new MyPoint3D(10.0, 20.0, 30.0)
    val p4 = new MyPoint3D(11.0, 22.0, 33.0)

    val points = Seq(
      ("P1", p1),
      ("P2", p2),
      ("P3", p3),
      ("P4", p4)
    ).toDF("label", "point")

    println("*** All the points as a dataframe")
    points.printSchema()
    points.show()

    // Define a UDF to get access to the attributes of a point in a query.
    val myMagnitude =
      udf { p: MyPoint3D =>
        math.sqrt(math.pow(p.x, 2) + math.pow(p.y, 2) + math.pow(p.z, 2))
      }

    val nearPoints =
      points.filter(myMagnitude($"point").lt(10))
        .select($"label", myMagnitude($"point").as("magnitude"))

    println("*** The points close to the origin, selected from the table")
    nearPoints.printSchema()
    nearPoints.show()
  }
}
**/
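
//
// A minimal sketch of one way to get a similar result in Spark 2.x without
// the (now private) UDT API: case classes are encoded as struct columns
// automatically, so their fields can be queried directly and no UDF is
// needed. The names Point3D and UDTWorkaround are illustrative and not part
// of the original example.
//
private case class Point3D(x: Double, y: Double, z: Double)

object UDTWorkaround {
  def main(args: Array[String]) {
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("DataFrame-UDT-Workaround")
      .master("local[4]")
      .getOrCreate()
    import spark.implicits._

    // Case classes inside tuples get a product encoder, so this yields a
    // DataFrame with a string column and a struct column.
    val points = Seq(
      ("P1", Point3D(1.0, 2.0, 3.0)),
      ("P2", Point3D(1.0, 0.0, 2.0)),
      ("P3", Point3D(10.0, 20.0, 30.0)),
      ("P4", Point3D(11.0, 22.0, 33.0))
    ).toDF("label", "point")

    println("*** All the points as a dataframe")
    points.printSchema()
    points.show()

    // The struct's fields are addressable as point.x, point.y, point.z,
    // so the magnitude can be computed with built-in column functions.
    val nearPoints = points
      .withColumn("magnitude",
        sqrt(pow($"point.x", 2) + pow($"point.y", 2) + pow($"point.z", 2)))
      .filter($"magnitude" < 10)
      .select($"label", $"magnitude")

    println("*** The points close to the origin, selected from the table")
    nearPoints.printSchema()
    nearPoints.show()
  }
}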