forked from scala-lms/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdslapi.scala
301 lines (277 loc) · 11.2 KB
/
dslapi.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
package scala.lms.tutorial
import scala.lms.common._
import scala.reflect.SourceContext
// should this be added to LMS?
/**
 * Staged hashing operations exposed by the DSL.
 *
 * This trait only declares the operations; the definitions that build IR nodes
 * are provided by the `Exp`-level trait mixed in together with `Dsl`.
 */
trait UtilOps extends Base { this: Dsl =>

  /** Stages a hash-code computation over the staged value `o`. */
  def infix_HashCode[T: Typ](o: Rep[T])(
    implicit pos: SourceContext
  ): Rep[Long]

  /**
   * Stages a hash-code computation over the staged string `o`, parameterised by
   * the staged length `len`.
   *
   * The `Overloaded1` evidence is presumably only there to keep this overload
   * distinct from the generic one.
   */
  def infix_HashCode(o: Rep[String], len: Rep[Int])(
    implicit v: Overloaded1, pos: SourceContext
  ): Rep[Long]
}
/**
 * IR-level implementation of the hashing operations: every operation is
 * represented by its own definition node, and both nodes are re-created by
 * `mirror` when a transformer is applied.
 */
trait UtilOpsExp extends UtilOps with BaseExp { this: DslExp =>

  /** Node for the hash code of `o`; `m` keeps the element `Typ` so `mirror` can pass it on. */
  case class ObjHashCode[T: Typ](o: Rep[T])(implicit pos: SourceContext) extends Def[Long] {
    def m = typ[T]
  }

  /** Node for the hash code of the string `o` parameterised by `len`. */
  case class StrSubHashCode(o: Rep[String], len: Rep[Int])(implicit pos: SourceContext) extends Def[Long]

  def infix_HashCode[T: Typ](o: Rep[T])(implicit pos: SourceContext): Rep[Long] =
    ObjHashCode(o)

  def infix_HashCode(o: Rep[String], len: Rep[Int])(implicit v: Overloaded1, pos: SourceContext): Rep[Long] =
    StrSubHashCode(o, len)

  /**
   * Rebuilds the two hashing nodes with their operands rewritten by `f`;
   * every other definition is handed to the parent implementation.
   */
  override def mirror[A: Typ](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = {
    val mirrored = e match {
      // The element Typ is taken from the node itself rather than from A.
      case node @ ObjHashCode(target)  => infix_HashCode(f(target))(node.m, pos)
      case StrSubHashCode(str, length) => infix_HashCode(f(str), f(length))
      case _                           => super.mirror(e, f)
    }
    mirrored.asInstanceOf[Exp[A]]
  }
}
/**
 * Scala code generation for the hashing nodes.
 *
 * Only `ObjHashCode` is rendered here (as a call to `##`); any other node,
 * including `StrSubHashCode`, falls through to the parent generator.
 */
trait ScalaGenUtilOps extends ScalaGenBase {
  val IR: UtilOpsExp
  import IR._

  override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match {
    case ObjHashCode(target) =>
      emitValDef(sym, src"$target.##")
    case _ =>
      super.emitNode(sym, rhs)
  }
}
/**
 * C code generation for the hashing nodes.
 *
 * Only `StrSubHashCode` is rendered here, as a call to a C function named
 * `hash` taking the string and the length; any other node, including
 * `ObjHashCode`, falls through to the parent generator.
 */
trait CGenUtilOps extends CGenBase {
  val IR: UtilOpsExp
  import IR._

  override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match {
    case StrSubHashCode(str, length) =>
      emitValDef(sym, src"hash($str,$length)")
    case _ =>
      super.emitNode(sym, rhs)
  }
}
/**
 * The user-facing DSL interface: the stack of LMS operation traits a snippet
 * may use, plus a few tutorial-specific additions (string-as-sequence view,
 * a short-circuiting `&&`, and comment markers for the generated code).
 */
trait Dsl
  extends PrimitiveOps
  with NumericOps
  with BooleanOps
  with LiftString
  with LiftPrimitives
  with LiftNumeric
  with LiftBoolean
  with IfThenElse
  with Equal
  with RangeOps
  with OrderingOps
  with MiscOps
  with ArrayOps
  with StringOps
  with SeqOps
  with Functions
  with While
  with StaticData
  with Variables
  with LiftVariables
  with ObjectOps
  with UtilOps {

  /** Lets a staged string be used with the staged `Seq[Char]` operations. */
  implicit def repStrToSeqOps(a: Rep[String]) = {
    val asCharSeq = a.asInstanceOf[Rep[Seq[Char]]]
    new SeqOpsCls(asCharSeq)
  }

  /**
   * Staged `&&` expressed as a conditional: `rhs` is by-name and is only placed
   * in the then-branch, with `false` as the else-branch.
   */
  override def infix_&&(lhs: Rep[Boolean], rhs: => Rep[Boolean])(implicit pos: SourceContext): Rep[Boolean] =
    __ifThenElse(lhs, rhs, unit(false))

  /** Emits a single comment line with text `l` into the generated code. */
  def generate_comment(l: String): Rep[Unit]

  /**
   * Wraps the staged block `b` in comment markers labelled `l` in the generated
   * code; `verbose` selects whether the label is repeated in a descriptive line.
   */
  def comment[A: Typ](l: String, verbose: Boolean = true)(b: => Rep[A]): Rep[A]
}
/**
 * IR-level implementation of [[Dsl]]: mixes in the LMS `Exp` traits and adds
 * a few local simplifications plus the comment nodes.
 */
trait DslExp
  extends Dsl
  with PrimitiveOpsExpOpt
  with NumericOpsExpOpt
  with BooleanOpsExp
  with IfThenElseExpOpt
  with EqualExpBridgeOpt
  with RangeOpsExp
  with OrderingOpsExp
  with MiscOpsExp
  with EffectExp
  with ArrayOpsExpOpt
  with StringOpsExp
  with SeqOpsExp
  with FunctionsRecursiveExp
  with WhileExp
  with StaticDataExp
  with VariablesExpOpt
  with ObjectOpsExpOpt
  with UtilOpsExp {

  /** `false || rhs` simplifies to `rhs`; anything else goes to the parent. */
  override def boolean_or(lhs: Exp[Boolean], rhs: Exp[Boolean])(implicit pos: SourceContext): Exp[Boolean] =
    lhs match {
      case Const(false) => rhs
      case _            => super.boolean_or(lhs, rhs)
    }

  /** `true && rhs` simplifies to `rhs`; anything else goes to the parent. */
  override def boolean_and(lhs: Exp[Boolean], rhs: Exp[Boolean])(implicit pos: SourceContext): Exp[Boolean] =
    lhs match {
      case Const(true) => rhs
      case _           => super.boolean_and(lhs, rhs)
    }

  /** Node for a single-line comment in the generated code. */
  case class GenerateComment(l: String) extends Def[Unit]

  def generate_comment(l: String): Rep[Unit] =
    reflectEffect(GenerateComment(l))

  /** Node for a block `b` surrounded by comment markers labelled `l`. */
  case class Comment[A: Typ](l: String, verbose: Boolean, b: Block[A]) extends Def[A]

  /**
   * Reifies `b` into a block and reflects a `Comment` node carrying the
   * summarized effects of that block.
   */
  def comment[A: Typ](l: String, verbose: Boolean)(b: => Rep[A]): Rep[A] = {
    val body = reifyEffects(b)
    reflectEffect[A](Comment(l, verbose, body), summarizeEffects(body))
  }

  /** Symbols bound inside a `Comment` are the effect symbols of its block. */
  override def boundSyms(e: Any): List[Sym[Any]] = e match {
    case Comment(_, _, body) => effectSyms(body)
    case _                   => super.boundSyms(e)
  }

  /**
   * Indexing static array data with a constant index is resolved at staging
   * time: an `Int` element is lifted with `unit`, any other element becomes
   * static data again. All other cases go to the parent.
   */
  override def array_apply[T: Typ](x: Exp[Array[T]], n: Exp[Int])(implicit pos: SourceContext): Exp[T] =
    (x, n) match {
      case (Def(StaticData(data: Array[T])), Const(idx)) =>
        val elem = data(idx)
        if (elem.isInstanceOf[Int]) unit(elem) else staticData(elem)
      case _ =>
        super.array_apply(x, n)
    }

  // TODO: should this be in LMS?
  // NOTE(review): this compares a Typ with a Manifest; whether that can ever be
  // equal depends on how LMS defines Typ -- confirm, since the code generators
  // in this file compare against typ[...] instead.
  override def isPrimitiveType[T](m: Typ[T]) =
    (m == manifest[String]) || super.isPrimitiveType(m)
}
/**
 * Scala code generator for the DSL: the stack of LMS Scala generators plus
 * rendering for the tutorial-specific nodes (comments, `printf`, boolean
 * conditionals) and for a few character constants.
 */
trait DslGen extends ScalaGenNumericOps
  with ScalaGenPrimitiveOps with ScalaGenBooleanOps with ScalaGenIfThenElse
  with ScalaGenEqual with ScalaGenRangeOps with ScalaGenOrderingOps
  with ScalaGenMiscOps with ScalaGenArrayOps with ScalaGenStringOps
  with ScalaGenSeqOps with ScalaGenFunctions with ScalaGenWhile
  with ScalaGenStaticData with ScalaGenVariables
  with ScalaGenObjectOps
  with ScalaGenUtilOps {
  val IR: DslExp
  import IR._

  /**
   * Renders newline, tab and NUL character constants as escaped character
   * literals; everything else is quoted by the parent generator.
   */
  override def quote(x: Exp[Any]) = {
    val isChar = x.tp == typ[Char]
    x match {
      case Const('\n') if isChar => "'\\n'"
      case Const('\t') if isChar => "'\\t'"
      case Const(0) if isChar    => "'\\0'"
      case _                     => super.quote(x)
    }
  }

  override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match {
    // `if (c) true else false` is just `c`.
    case IfThenElse(cond, Block(Const(true)), Block(Const(false))) =>
      emitValDef(sym, quote(cond))

    case PrintF(fmt: String, args) =>
      emitValDef(sym, src"printf(${Const(fmt) :: args})")

    case GenerateComment(text) =>
      stream.println(s"// $text")

    // The block is emitted as the initializer of `sym`, bracketed by `//#label`
    // marker lines.
    case Comment(label, verbose, body) =>
      val description =
        if (verbose) "// generated code for " + label.replace('_', ' ')
        else "// generated code"
      stream.println(s"val ${quote(sym)} = {")
      stream.println(s"//#$label")
      stream.println(description)
      emitBlock(body)
      stream.println(quote(getBlockResult(body)))
      stream.println(s"//#$label")
      stream.println("}")

    case _ =>
      super.emitNode(sym, rhs)
  }
}
/** A [[DslExp]] bundled with a Scala code generator whose IR is this very instance. */
trait DslImpl extends DslExp { dsl =>
  val codegen = new DslGen {
    // Tie the generator's IR to the enclosing DSL instance (singleton type).
    val IR: dsl.type = dsl
  }
}
// TODO: currently part of this is specific to the query tests. generalize? move?
/**
 * C code generator for the DSL: the stack of LMS C generators plus the
 * type mapping, constant quoting, node rendering and program prelude used by
 * the tutorial.
 */
trait DslGenC extends CGenNumericOps
  with CGenPrimitiveOps with CGenBooleanOps with CGenIfThenElse
  with CGenEqual with CGenRangeOps with CGenOrderingOps
  with CGenMiscOps with CGenArrayOps with CGenStringOps
  with CGenSeqOps with CGenFunctions with CGenWhile
  with CGenStaticData with CGenVariables
  with CGenObjectOps
  with CGenUtilOps {
  val IR: DslExp
  import IR._

  /**
   * Builds the right-hand side of an array allocation statement.
   *
   * @param count   C expression for the number of elements
   * @param memType C element type name
   * @return `(memType*)malloc(count * sizeof(memType));` (including the trailing `;`)
   */
  def getMemoryAllocString(count: String, memType: String): String = {
    "(" + memType + "*)malloc(" + count + " * sizeof(" + memType + "));"
  }

  /**
   * Maps strings and character arrays to `char*` and `Char` to `char`,
   * based on the textual form of the type; everything else is remapped by
   * the parent generator.
   */
  override def remap[A](m: Typ[A]): String = m.toString match {
    case "java.lang.String" => "char*"
    case "Array[Char]" => "char*"
    case "Char" => "char"
    case _ => super.remap(m)
  }

  /**
   * Chooses the `printf` conversion specifier for `s` from its remapped C type.
   *
   * @throws scala.lms.internal.GenerationFailedException if the type has no specifier here
   */
  override def format(s: Exp[Any]): String = {
    remap(s.tp) match {
      case "uint16_t" => "%c"
      case "bool" | "int8_t" | "int16_t" | "int32_t" => "%d"
      case "int64_t" => "%ld"
      case "float" | "double" => "%f"
      case "string" => "%s"
      case "char*" => "%s"
      case "char" => "%c"
      case "void" => "%c"
      case _ =>
        import scala.lms.internal.GenerationFailedException
        throw new GenerationFailedException("CGenMiscOps: cannot print type " + remap(s.tp))
    }
  }

  /** Quotes `s` for use as a `printf` argument; a `string`-typed value gets `.c_str()` appended. */
  override def quoteRawString(s: Exp[Any]): String = {
    remap(s.tp) match {
      case "string" => quote(s) + ".c_str()"
      case _ => quote(s)
    }
  }

  // we treat string as a primitive type to prevent memory management on strings
  // strings are always stack allocated and freed automatically at the scope exit
  override def isPrimitiveType(tpe: String) : Boolean = {
    tpe match {
      case "char*" => true
      case "char" => true
      case _ => super.isPrimitiveType(tpe)
    }
  }

  /**
   * Renders constants as C literals.
   *
   * String constants are wrapped in double quotes with `"`, newline, carriage
   * return and tab escaped, so that a constant containing those characters still
   * yields a single-line C string literal (matching how the `Char` cases below
   * render `'\n'` and `'\t'`). Backslashes are passed through untouched, so a
   * constant that already contains a C escape sequence keeps its meaning.
   * Newline, tab and NUL character constants become escaped character literals;
   * everything else is quoted by the parent generator.
   */
  override def quote(x: Exp[Any]) = x match {
    case Const(s: String) =>
      // None of the replacements produces a character a later one matches,
      // so the order of the chain does not matter.
      val escaped = s
        .replace("\"", "\\\"")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
      "\"" + escaped + "\"" // TODO: more escapes?
    case Const('\n') if x.tp == typ[Char] => "'\\n'"
    case Const('\t') if x.tp == typ[Char] => "'\\t'"
    case Const(0) if x.tp == typ[Char] => "'\\0'"
    case _ => super.quote(x)
  }

  /**
   * Renders array allocation/access/update, `println`, string indexing and
   * comment blocks as C statements; all other nodes go to the parent generator.
   */
  override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match {
    case a@ArrayNew(n) =>
      val arrType = remap(a.m)
      stream.println(arrType + "* " + quote(sym) + " = " + getMemoryAllocString(quote(n), arrType))
    case ArrayApply(x,n) => emitValDef(sym, quote(x) + "[" + quote(n) + "]")
    case ArrayUpdate(x,n,y) => stream.println(quote(x) + "[" + quote(n) + "] = " + quote(y) + ";")
    case PrintLn(s) => stream.println("printf(\"" + format(s) + "\\n\"," + quoteRawString(s) + ");")
    case StringCharAt(s,i) => emitValDef(sym, "%s[%s]".format(quote(s), quote(i)))
    case Comment(s, verbose, b) =>
      // The block's statements are emitted between two `//#label` marker lines;
      // `sym` is bound to the block result just before the closing marker.
      stream.println("//#" + s)
      if (verbose) {
        stream.println("// generated code for " + s.replace('_', ' '))
      } else {
        stream.println("// generated code")
      }
      emitBlock(b)
      emitValDef(sym, quote(getBlockResult(b)))
      stream.println("//#" + s)
    case _ => super.emitNode(sym,rhs)
  }

  /**
   * Writes a fixed C prelude (includes, the `fsize`, `printll` and `hash`
   * helpers, and a `main` that calls `Snippet(argv[1])`) to `out`, then lets
   * the parent generator emit the staged function itself.
   */
  override def emitSource[A:Typ](args: List[Sym[_]], body: Block[A], functionName: String, out: java.io.PrintWriter) = {
    withStream(out) {
      stream.println("""
#include <fcntl.h>
#include <errno.h>
#include <err.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <stdio.h>
#include <stdint.h>
#include <unistd.h>
#ifndef MAP_FILE
#define MAP_FILE MAP_SHARED
#endif
int fsize(int fd) {
struct stat stat;
int res = fstat(fd,&stat);
return stat.st_size;
}
int printll(char* s) {
while (*s != '\n' && *s != ',' && *s != '\t') {
putchar(*s++);
}
return 0;
}
long hash(char *str0, int len)
{
unsigned char* str = (unsigned char*)str0;
unsigned long hash = 5381;
int c;
while ((c = *str++) && len--)
hash = ((hash << 5) + hash) + c; /* hash * 33 + c */
return hash;
}
void Snippet(char*);
int main(int argc, char *argv[])
{
if (argc != 2) {
printf("usage: query <filename>\n");
return 0;
}
Snippet(argv[1]);
return 0;
}
""")
    }
    super.emitSource[A](args, body, functionName, out)
  }
}
/**
 * A DSL program: a single staged function from a staged `A` to a staged `B`.
 * Concrete drivers decide how the snippet is generated and run.
 */
abstract class DslSnippet[A: Manifest, B: Manifest] extends Dsl {

  /** The staged program body. */
  def snippet(x: Rep[A]): Rep[B]
}
/**
 * Driver that turns a snippet into Scala: `f` is the compiled function,
 * `code` is the generated source text.
 */
abstract class DslDriver[A: Manifest, B: Manifest]
  extends DslSnippet[A, B] with DslImpl with CompileScala {

  /** The compiled snippet; compilation happens on first access. */
  lazy val f = compile(snippet)(manifestTyp[A], manifestTyp[B])

  /** Forces compilation of the snippet. */
  def precompile: Unit = f

  /** Forces compilation of the snippet inside `utils.devnull`. */
  def precompileSilently: Unit = utils.devnull(f)

  /** Runs the compiled snippet on `x`. */
  def eval(x: A): B = f(x)

  /** Generated Scala source of the snippet, emitted under the name `Snippet`. */
  lazy val code: String = {
    val buffer = new java.io.StringWriter()
    val writer = new java.io.PrintWriter(buffer)
    codegen.emitSource(snippet, "Snippet", writer)(manifestTyp[A], manifestTyp[B])
    buffer.toString
  }
}
/**
 * Driver that turns a snippet into C: `code` is the generated C source and
 * `eval` compiles and runs it through the system C compiler.
 */
abstract class DslDriverC[A:Manifest,B:Manifest] extends DslSnippet[A,B] with DslExp { q =>
  val codegen = new DslGenC {
    // Tie the generator's IR to this driver instance (singleton type).
    val IR: q.type = q
  }

  /** Generated C source of the snippet, emitted under the name `Snippet`. */
  lazy val code: String = {
    implicit val mA = manifestTyp[A]
    implicit val mB = manifestTyp[B]
    val source = new java.io.StringWriter()
    codegen.emitSource(snippet, "Snippet", new java.io.PrintWriter(source))
    source.toString
  }

  /**
   * Writes the generated C to `/tmp/snippet.c`, compiles it with `cc` into
   * `/tmp/snippet`, runs the binary with `a` as its single command-line
   * argument, and echoes the output lines of both processes to the console.
   *
   * @param a value passed (via its string form) as the program's only argument
   */
  def eval(a:A): Unit = { // TBD: should read result of type B?
    import scala.sys.process._
    val out = new java.io.PrintWriter("/tmp/snippet.c")
    // Close the file even when rendering `code` fails, so the handle is not leaked.
    try out.println(code) finally out.close()
    //TODO: use precompile
    (new java.io.File("/tmp/snippet")).delete()
    ("cc -std=c99 -O3 /tmp/snippet.c -o /tmp/snippet":ProcessBuilder).lines.foreach(Console.println _)
    // Pass an explicit argument vector: a single command string would be split on
    // whitespace, breaking arguments (e.g. file names) that contain spaces.
    Process(Seq("/tmp/snippet", s"$a")).lines.foreach(Console.println _)
  }
}