/**
# Fast Fourier Transform (FFT)
<a name="sec:Afft"></a>
Outline:
<div id="tableofcontents"></div>
We consider staging a fast Fourier transform (FFT) algorithm. A staged FFT,
implemented in MetaOCaml, has been presented by Kiselyov et al.
[(*)](DBLP:conf/emsoft/KiselyovST04). Their work is a very good example of how
staging can transform a simple, unoptimized algorithm into an efficient
program generator. Achieving this in the context of MetaOCaml, however,
required restructuring the program into monadic style and adding a front-end
layer for performing symbolic rewritings. Using our approach of just adding
`Rep` types, we can go from the naive textbook algorithm to the staged version
(shown below) by changing literally two lines of code:
```scala
trait FFT { this: Arith with Trig =>
  case class Complex(re: Rep[Double], im: Rep[Double])
  ...
}
```
All that is needed is adding the self-type annotation to import arithmetic and
trigonometric operations and changing the type of the real and imaginary
components of complex numbers from `Double` to `Rep[Double]`.
```scala
trait FFT { this: Arith with Trig =>
  case class Complex(re: Rep[Double], im: Rep[Double]) {
    def +(that: Complex) = Complex(this.re + that.re, this.im + that.im)
    def *(that: Complex) = ...
  }
  def omega(k: Int, N: Int): Complex = {
    val kth = -2.0 * k * math.Pi / N
    Complex(cos(kth), sin(kth))
  }
  // Lists (not arrays): the cons patterns and ::: below only work on List.
  def fft(xs: List[Complex]): List[Complex] = xs match {
    case (x :: Nil) => xs
    case _ =>
      val N = xs.length // assume it's a power of two
      val (even0, odd0) = splitEvenOdd(xs)
      val (even1, odd1) = (fft(even0), fft(odd0))
      val (even2, odd2) = even1.zip(odd1).zipWithIndex.map {
        case ((x, y), k) =>
          val z = omega(k, N) * y
          (x + z, x - z)
      }.unzip
      even2 ::: odd2
  }
}
```
FFT code. Only the real and imaginary components of complex numbers need to be
staged.
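For comparison, the following is a minimal, unstaged sketch of the same recursive radix-2 algorithm over plain `Double`s. This is roughly what the code looks like before the two-line change. It is standalone Scala without LMS. `Cplx` and `PlainFFT` are hypothetical names chosen for illustration, and an explicit empty-list case is added so that the recursion always terminates.

```scala
// Hypothetical unstaged stand-in for the staged `Complex`: fields are plain Doubles.
final case class Cplx(re: Double, im: Double) {
  def +(that: Cplx) = Cplx(re + that.re, im + that.im)
  def -(that: Cplx) = Cplx(re - that.re, im - that.im)
  def *(that: Cplx) = Cplx(re * that.re - im * that.im, re * that.im + im * that.re)
}

object PlainFFT {
  // Twiddle factor exp(-2*pi*i*k/n), computed with ordinary math.cos/math.sin.
  def omega(k: Int, n: Int): Cplx = {
    val kth = -2.0 * k * math.Pi / n
    Cplx(math.cos(kth), math.sin(kth))
  }

  // Split a list into its even-indexed and odd-indexed elements.
  def splitEvenOdd[T](xs: List[T]): (List[T], List[T]) = xs match {
    case e :: o :: rest =>
      val (es, os) = splitEvenOdd(rest)
      (e :: es, o :: os)
    case _ => (xs, Nil) // empty or single element
  }

  // Recursive radix-2 FFT; the input length is assumed to be a power of two.
  def fft(xs: List[Cplx]): List[Cplx] = xs match {
    case Nil | (_ :: Nil) => xs
    case _ =>
      val n = xs.length
      val (even0, odd0) = splitEvenOdd(xs)
      val (even1, odd1) = (fft(even0), fft(odd0))
      val (lo, hi) = even1.zip(odd1).zipWithIndex.map { case ((x, y), k) =>
        val z = omega(k, n) * y
        (x + z, x - z)
      }.unzip
      lo ::: hi
  }
}
```

Up to floating-point rounding, `PlainFFT.fft` applied to the inputs `1, 1, 2, 2` (all with zero imaginary part) matches the `OP.fft4` result shown later in this section. The rounding arises because `math.cos(-math.Pi / 2)` is about `6.1e-17` rather than exactly `0`. That is the kind of value the full code's `TrigExpOptFFT` rewrites to an exact `0.0`.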
\begin{figure}\centering
\includegraphics[scale=0.5]{papers/cacm2012/figures/test2-fft2-x-dot.pdf}
\caption{\label{fig:fftgraph} Computation graph for size-4 FFT. Auto-generated from
staged code in Figure~\ref{fig:fftcode}.}
\end{figure}
Merely changing the types will not yet provide us with the desired
optimizations. We will see below how we can add the transformations described
by Kiselyov et al. to generate the same fixed-size FFT code, corresponding to
the famous FFT butterfly networks (see Figure~\ref{fig:fftgraph}). Despite the
seemingly naive algorithm, the generated code is free of branches, intermediate
data structures and redundant computations. The important point here is that
we can add these transformations without any further changes to the code in
Figure~\ref{fig:fftcode}, just by mixing trait `FFT` with a few other traits.
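To make the idea concrete without LMS, here is a hypothetical, heavily simplified stand-in for `Rep[Double]`. It is a tiny expression tree (`Exp`, `Const`, `Sym`, `Plus`, `Minus`, `Times`, all illustrative names) built through smart constructors. Those constructors fold constants and apply the identities `x+0`, `x-0`, `x*0` and `x*1`. Running the unchanged recursive algorithm on symbolic inputs leaves no lists or branches in the result, only straight-line arithmetic expressions. This sketch does no common-subexpression elimination and none of the FFT-specific rewrites discussed below. The `cosSnap` helper imitates the idea behind the full code's `TrigExpOptFFT`.

```scala
object SymFFT {
  // Hypothetical stand-in for Rep[Double]: a tiny expression tree.
  sealed trait Exp
  final case class Const(d: Double) extends Exp
  final case class Sym(name: String) extends Exp
  final case class Plus(a: Exp, b: Exp) extends Exp
  final case class Minus(a: Exp, b: Exp) extends Exp
  final case class Times(a: Exp, b: Exp) extends Exp

  // Smart constructors: fold constants and apply x+0, x-0, x*0, x*1 identities.
  def plus(a: Exp, b: Exp): Exp = (a, b) match {
    case (Const(x), Const(y)) => Const(x + y)
    case (Const(0.0), e)      => e
    case (e, Const(0.0))      => e
    case _                    => Plus(a, b)
  }
  def minus(a: Exp, b: Exp): Exp = (a, b) match {
    case (Const(x), Const(y)) => Const(x - y)
    case (e, Const(0.0))      => e
    case _                    => Minus(a, b)
  }
  def times(a: Exp, b: Exp): Exp = (a, b) match {
    case (Const(x), Const(y))              => Const(x * y)
    case (Const(0.0), _) | (_, Const(0.0)) => Const(0.0)
    case (Const(1.0), e)                   => e
    case (e, Const(1.0))                   => e
    case _                                 => Times(a, b)
  }

  final case class Complex(re: Exp, im: Exp) {
    def +(t: Complex) = Complex(plus(re, t.re), plus(im, t.im))
    def -(t: Complex) = Complex(minus(re, t.re), minus(im, t.im))
    def *(t: Complex) = Complex(minus(times(re, t.re), times(im, t.im)),
                                plus(times(re, t.im), times(im, t.re)))
  }

  // cos of an odd multiple of pi/2 becomes exactly 0 instead of ~6e-17.
  def cosSnap(x: Double): Double = {
    val z = x / (math.Pi / 2)
    if (z == math.rint(z) && z.toLong % 2 != 0) 0.0 else math.cos(x)
  }

  // Twiddle factors are computed at generation time and enter as constants.
  def omega(k: Int, n: Int): Complex = {
    val kth = -2.0 * k * math.Pi / n
    Complex(Const(cosSnap(kth)), Const(math.sin(kth)))
  }

  def splitEvenOdd[T](xs: List[T]): (List[T], List[T]) = xs match {
    case e :: o :: rest =>
      val (es, os) = splitEvenOdd(rest)
      (e :: es, o :: os)
    case _ => (xs, Nil)
  }

  // Same recursion as before; only the element type is symbolic.
  def fft(xs: List[Complex]): List[Complex] = xs match {
    case Nil | (_ :: Nil) => xs
    case _ =>
      val n = xs.length // assumed to be a power of two
      val (even0, odd0) = splitEvenOdd(xs)
      val (even1, odd1) = (fft(even0), fft(odd0))
      val (lo, hi) = even1.zip(odd1).zipWithIndex.map { case ((x, y), k) =>
        val z = omega(k, n) * y
        (x + z, x - z)
      }.unzip
      lo ::: hi
  }

  // Render an expression as fully parenthesized infix text.
  def show(e: Exp): String = e match {
    case Const(d)    => d.toString
    case Sym(s)      => s
    case Plus(a, b)  => s"(${show(a)} + ${show(b)})"
    case Minus(a, b) => s"(${show(a)} - ${show(b)})"
    case Times(a, b) => s"(${show(a)} * ${show(b)})"
  }
}
```

Suppose the four symbolic inputs are named `re0`, `im0` through `re3`, `im3`. Then the real part of the first output is `((re0 + re2) + (re1 + re3))`. The real part of the second output still contains `(0.0 - (-1.0 * (im1 - im3)))`. Multiplications by `-1.0` of this kind are what FFT-specific rewrites such as those in `ArithExpOptFFT` target.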
```scala
trait ArithExpOptFFT extends ArithExp {
  override def infix_*(x: Exp[Double], y: Exp[Double])(implicit pos: SourceContext): Exp[Double] = (x, y) match {
    case (Const(k), Def(Times(Const(l), y))) => Const(k * l) * y
    case (x, Def(Times(Const(k), y))) => Const(k) * (x * y)
    case (Def(Times(Const(k), x)), y) => Const(k) * (x * y)
    ...
    case (x, Const(y)) => Times(Const(y), x)
    case _ => super.infix_*(x, y)
  }
}
```
Extending the generic implementation from [here](#sec:308addOpts) with
FFT-specific optimizations.
## Implementing Optimizations
As already discussed [here](#sec:308addOpts), some profitable optimizations
are very generic (CSE, DCE, etc.), whereas others are specific to the actual
program. In the FFT case, Kiselyov et al.
[(*)](DBLP:conf/emsoft/KiselyovST04) describe a number of rewritings that are
particularly effective for the patterns of code generated by the FFT algorithm
but less so for other programs.
Once again, we want modularity, so that optimizations can be combined in
whatever way is most useful for a given task. This can be achieved by
overriding smart constructors, as shown by trait `ArithExpOptFFT` (see
Figure~\ref{fig:expOpt}). Note that the use of `x*y` within the body of
`infix_*` will apply the optimization recursively.
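The following plain-Scala sketch shows the layering idea without LMS. A generic trait defines smart constructors, and an FFT-specific trait overrides them, delegating to `super` when no rule applies. The rewrite rules are modeled on the ones in the full `ArithExpOptFFT` listing at the end of this section. The node classes and the names `Rewrites`, `Arith`, `ArithOptFFT`, `Base` and `Opt` are hypothetical. The identity rules in the generic layer are an assumption made for the example, not a description of LMS's `ArithExpOpt`. The overrides call `plus`, `minus` and `times` rather than building nodes directly, so each rewrite result is itself simplified again.

```scala
object Rewrites {
  sealed trait Exp
  final case class Const(d: Double) extends Exp
  final case class Sym(name: String) extends Exp
  final case class Plus(a: Exp, b: Exp) extends Exp
  final case class Minus(a: Exp, b: Exp) extends Exp
  final case class Times(a: Exp, b: Exp) extends Exp

  // Generic layer: constant folding and simple identities (assumed for illustration).
  trait Arith {
    def plus(x: Exp, y: Exp): Exp = (x, y) match {
      case (Const(a), Const(b)) => Const(a + b)
      case (Const(0.0), _)      => y
      case (_, Const(0.0))      => x
      case _                    => Plus(x, y)
    }
    def minus(x: Exp, y: Exp): Exp = (x, y) match {
      case (Const(a), Const(b)) => Const(a - b)
      case (_, Const(0.0))      => x
      case _                    => Minus(x, y)
    }
    def times(x: Exp, y: Exp): Exp = (x, y) match {
      case (Const(a), Const(b)) => Const(a * b)
      case (Const(1.0), _)      => y
      case (_, Const(1.0))      => x
      case _                    => Times(x, y)
    }
  }

  // FFT-specific layer: overrides the smart constructors, falls back to super.
  trait ArithOptFFT extends Arith {
    override def plus(x: Exp, y: Exp): Exp = y match {
      case Minus(Const(0.0), z) => minus(x, z) // x + (0 - z)  ==>  x - z
      case _                    => super.plus(x, y)
    }
    override def minus(x: Exp, y: Exp): Exp = y match {
      case Minus(Const(0.0), z) => plus(x, z)  // x - (0 - z)  ==>  x + z
      case _                    => super.minus(x, y)
    }
    override def times(x: Exp, y: Exp): Exp = (x, y) match {
      case (_, Const(-1.0)) => minus(Const(0.0), x) // x * -1  ==>  0 - x
      case (Const(-1.0), _) => minus(Const(0.0), y) // -1 * y  ==>  0 - y
      case _                => super.times(x, y)
    }
  }

  object Base extends Arith       // generic rules only
  object Opt extends ArithOptFFT  // FFT rules layered on top
}
```

For example, `Opt.plus(a, Opt.times(b, Const(-1.0)))` collapses to `Minus(a, b)`, while `Base` leaves the `Times(b, Const(-1.0))` node in place.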
## Running the Generated Code
Extending the FFT component from Figure~\ref{fig:fftcode} with explicit
compilation.
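```scala
// The self-type must include FFT's own self-type (Arith with Trig).
trait FFTC extends FFT { this: Arith with Trig with Arrays with Compile =>
  def fftc(size: Int) = compile { input: Rep[Array[Double]] =>
    // happens at staging time: size is an ordinary Int here
    assert(size > 0 && (size & (size - 1)) == 0, "size must be a power of 2")
    val arg = List.tabulate(size) { i =>
      Complex(input(2*i), input(2*i+1))
    }
    val res = fft(arg)
    updateArray(input, res.flatMap {
      case Complex(re, im) => List(re, im)
    })
  }
}
```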
Using the staged FFT implementation as part of some larger Scala program is
straightforward but requires us to interface the generic algorithm with a
concrete data representation. The algorithm in Figure~\ref{fig:fftcode}
expects a sequence of `Complex` objects as input, each of which contains
fields of type `Rep[Double]`. The algorithm itself has no notion of staged
arrays; it uses collections only in the generator stage, which means that it
is agnostic to how data is stored. The enclosing program, however, will store
arrays of complex numbers in some native format which we will need to feed
into the algorithm. A simple choice of representation is to use
`Array[Double]` with the complex numbers flattened into adjacent slots. When
applying `compile`, we will thus receive input of type `Rep[Array[Double]]`.
Figure~\ref{fig:fftc} shows how we can extend trait `FFT` to `FFTC` to obtain
compiled FFT implementations that realize the necessary data interface for a
fixed input size.
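The interleaved layout can be illustrated with ordinary, unstaged Scala. The sketch below uses hypothetical object and method names. It applies the same `2*i` / `2*i+1` indexing as `fftc` to turn a flattened `Array[Double]` into `(re, im)` pairs and back. It also includes a bit-trick power-of-two test of the kind `fftc` could use for its staging-time check. In the staged version, this index arithmetic would run in the generator, leaving only array accesses in the generated code.

```scala
object Interleaved {
  // True iff n is a positive power of two (1, 2, 4, 8, ...).
  def isPowerOfTwo(n: Int): Boolean = n > 0 && (n & (n - 1)) == 0

  // Read a flattened [re0, im0, re1, im1, ...] array as (re, im) pairs,
  // using the same 2*i / 2*i+1 indexing as fftc.
  def toPairs(a: Array[Double]): List[(Double, Double)] = {
    require(a.length % 2 == 0, "interleaved array must have even length")
    List.tabulate(a.length / 2)(i => (a(2 * i), a(2 * i + 1)))
  }

  // Flatten (re, im) pairs back into the interleaved layout.
  def fromPairs(ps: List[(Double, Double)]): Array[Double] =
    ps.flatMap { case (re, im) => List(re, im) }.toArray
}
```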
We can then define code that creates and uses compiled FFT "codelets" by
extending `FFTC`:
```scala
trait TestFFTC extends FFTC { this: Arith with Trig with Arrays with Compile =>
  val fft4: Array[Double] => Array[Double] = fftc(4)
  val fft8: Array[Double] => Array[Double] = fftc(8)
  // embedded code using fft4, fft8, ...
}
```
Constructing an instance of this subtrait (mixed in with the appropriate LMS
traits) will execute the embedded code:
```scala
val OP: TestFFTC = new TestFFTC with CompileScala
  with ArithExpOpt with ArithExpOptFFT with ScalaGenArith
  with TrigExpOpt with ScalaGenTrig
  with ArraysExpOpt with ScalaGenArrays
```
We can also use the compiled methods from outside the object:
```scala
OP.fft4(Array(1.0,0.0, 1.0,0.0, 2.0,0.0, 2.0,0.0))
// ↪ Array(6.0,0.0,-1.0,1.0,0.0,0.0,-1.0,-1.0)
```
Providing an explicit type in the definition `val OP: TestFFTC = ...` ensures
that only the members defined by `TestFFTC` are accessible from the outside,
not the internal representation.
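The effect of the type ascription can be shown in plain Scala. In this hypothetical sketch, `Codelets` plays the role of `TestFFTC`'s public members, and `CodeletsImpl` stands for the object that also carries internal machinery. The cloning function is only a placeholder for a compiled codelet, not an FFT. The runtime object is unchanged, but the static type `Codelets` hides every other member.

```scala
// Hypothetical public interface: only what callers should use.
trait Codelets {
  def fft4: Array[Double] => Array[Double]
}

// Hypothetical implementation with an extra internal member.
class CodeletsImpl extends Codelets {
  val generatedSource: String = "/* internal: generated code */"
  // Placeholder standing in for a compiled codelet; it is NOT an FFT.
  val fft4: Array[Double] => Array[Double] = a => a.clone()
}

object Demo {
  // The ascription fixes the static type to Codelets.
  val op: Codelets = new CodeletsImpl
  // Demo.op.generatedSource would not compile: it is not a member of Codelets.
}
```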
The full code is below:
```scala
package scala.lms
package epfl
package test2

import common._
import test1._
import reflect.SourceContext

import java.io.PrintWriter

import org.scalatest._

trait FFT { this: Arith with Trig =>

  def omega(k: Int, N: Int): Complex = {
    val kth = -2.0 * k * math.Pi / N
    Complex(cos(kth), sin(kth))
  }

  case class Complex(re: Rep[Double], im: Rep[Double]) {
    def +(that: Complex) = Complex(this.re + that.re, this.im + that.im)
    def -(that: Complex) = Complex(this.re - that.re, this.im - that.im)
    def *(that: Complex) = Complex(this.re * that.re - this.im * that.im, this.re * that.im + this.im * that.re)
  }

  def splitEvenOdd[T](xs: List[T]): (List[T], List[T]) = (xs: @unchecked) match {
    case e :: o :: xt =>
      val (es, os) = splitEvenOdd(xt)
      ((e :: es), (o :: os))
    case Nil => (Nil, Nil)
    // cases?
  }

  def mergeEvenOdd[T](even: List[T], odd: List[T]): List[T] = ((even, odd): @unchecked) match {
    case (Nil, Nil) =>
      Nil
    case ((e :: es), (o :: os)) =>
      e :: (o :: mergeEvenOdd(es, os))
    // cases?
  }

  def fft(xs: List[Complex]): List[Complex] = xs match {
    case (x :: Nil) => xs
    case _ =>
      val N = xs.length // assume it's a power of two
      val (even0, odd0) = splitEvenOdd(xs)
      val (even1, odd1) = (fft(even0), fft(odd0))
      val (even2, odd2) = even1.zip(odd1).zipWithIndex.map {
        case ((x, y), k) =>
          val z = omega(k, N) * y
          (x + z, x - z)
      }.unzip
      even2 ::: odd2
  }
}

trait ArithExpOptFFT extends ArithExpOpt {
  // x + (0 - y)  ==>  x - y
  override def infix_+(x: Exp[Double], y: Exp[Double])(implicit pos: SourceContext) = (x, y) match {
    case (x, Def(Minus(Const(0.0) | Const(-0.0), y))) => infix_-(x, y)
    case _ => super.infix_+(x, y)
  }
  // x - (0 - y)  ==>  x + y
  override def infix_-(x: Exp[Double], y: Exp[Double])(implicit pos: SourceContext) = (x, y) match {
    case (x, Def(Minus(Const(0.0) | Const(-0.0), y))) => infix_+(x, y)
    case _ => super.infix_-(x, y)
  }
  // x * -1  ==>  0 - x  (in either argument position)
  override def infix_*(x: Exp[Double], y: Exp[Double])(implicit pos: SourceContext) = (x, y) match {
    case (x, Const(-1.0)) => infix_-(0.0, x)
    case (Const(-1.0), y) => infix_-(0.0, y)
    case _ => super.infix_*(x, y)
  }
}

trait TrigExpOptFFT extends TrigExpOpt {
  // cos(k * pi/2) is exactly 0 for odd k; math.cos would yield a tiny nonzero value
  override def cos(x: Exp[Double]) = x match {
    case Const(x) if { val z = x / math.Pi / 0.5; z == z.toInt && z.toInt % 2 != 0 } => Const(0.0)
    case _ => super.cos(x)
  }
}

trait FlatResult extends BaseExp { // just to make dot output nicer
  case class Result(x: Any) extends Def[Any]
  def result(x: Any): Exp[Any] = toAtom(Result(x))
}

trait ScalaGenFlat extends ScalaGenBase {
  import IR._
  type Block[+T] = Exp[T]
  def getBlockResultFull[T](x: Block[T]): Exp[T] = x
  def reifyBlock[T:Typ](x: =>Exp[T]): Block[T] = x
  def traverseBlock[A](block: Block[A]): Unit = {
    buildScheduleForResult(block) foreach traverseStm
  }
}

class TestFFT extends FileDiffSuite {

  val prefix = home + "test-out/epfl/test2-"

  def testFFT1 = {
    withOutFile(prefix+"fft1") {
      val o = new FFT with ArithExp with TrigExpOpt with FlatResult with DisableCSE //with DisableDCE
      import o._

      val r = fft(List.tabulate(4)(_ => Complex(fresh, fresh)))
      println(globalDefs.mkString("\n"))
      println(r)

      val p = new ExportGraph with DisableDCE { val IR: o.type = o }
      p.emitDepGraph(result(r), prefix+"fft1-dot", true)
    }
    assertFileEqualsCheck(prefix+"fft1")
    assertFileEqualsCheck(prefix+"fft1-dot")
  }

  def testFFT2 = {
    withOutFile(prefix+"fft2") {
      val o = new FFT with ArithExpOptFFT with TrigExpOptFFT with FlatResult
      import o._

      case class Result(x: Any) extends Exp[Any]

      val r = fft(List.tabulate(4)(_ => Complex(fresh, fresh)))
      println(globalDefs.mkString("\n"))
      println(r)

      val p = new ExportGraph { val IR: o.type = o }
      p.emitDepGraph(result(r), prefix+"fft2-dot", true)
    }
    assertFileEqualsCheck(prefix+"fft2")
    assertFileEqualsCheck(prefix+"fft2-dot")
  }

  def testFFT3 = {
    withOutFile(prefix+"fft3") {
      class FooBar extends FFT
        with ArithExpOptFFT with TrigExpOptFFT with ArraysExp
        with CompileScala {

        def ffts(input: Rep[Array[Double]], size: Int) = {
          val list = List.tabulate(size)(i => Complex(input(2*i), input(2*i+1)))
          val r = fft(list)
          // make a new array for now - doing in-place update would be better
          makeArray(r.flatMap { case Complex(re,im) => List(re,im) })
        }

        val codegen = new ScalaGenFlat with ScalaGenArith with ScalaGenArrays {
          val IR: FooBar.this.type = FooBar.this
        } // TODO: find a better way...
      }
      val o = new FooBar
      import o._

      val fft4 = (input: Rep[Array[Double]]) => ffts(input, 4)
      codegen.emitSource(fft4, "FFT4", new PrintWriter(System.out))
      val fft4c = compile(fft4)
      println(fft4c(Array(1.0,0.0, 1.0,0.0, 2.0,0.0, 2.0,0.0, 1.0,0.0, 1.0,0.0, 0.0,0.0, 0.0,0.0)).mkString(","))
    }
    assertFileEqualsCheck(prefix+"fft3")
  }
}
```
*/