diff --git a/README.md b/README.md index a26d072..3c98032 100644 --- a/README.md +++ b/README.md @@ -9,4 +9,5 @@ https://craftinginterpreters.com/ - [x] Parsing Expressions - [x] Evaluating Expressions - [x] Statements and State -- [x] Control Flow \ No newline at end of file +- [x] Control Flow +- [x] Functions \ No newline at end of file diff --git a/pkg/ast/expr.go b/pkg/ast/expr.go index 229cb2f..9ea0340 100644 --- a/pkg/ast/expr.go +++ b/pkg/ast/expr.go @@ -67,4 +67,14 @@ type LogicalExpr struct { func (l *LogicalExpr) Accept(v ExprVisitor) interface{} { return v.VisitLogicalExpr(l) +} + +type CallExpr struct { + Callee Expr + Paren scanner.Token + Arguments []Expr +} + +func (c *CallExpr) Accept(v ExprVisitor) interface{} { + return v.VisitCallExpr(c) } \ No newline at end of file diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go index a122c79..d12dffc 100644 --- a/pkg/ast/parser.go +++ b/pkg/ast/parser.go @@ -10,11 +10,12 @@ import ( type parser struct { tokens []scanner.Token current int + fdepth int err error } func NewParser(tokens []scanner.Token) *parser { - return &parser{tokens, 0, nil} + return &parser{tokens, 0, 0, nil} } func (p *parser) Parse() ([]Stmt, error) { @@ -34,10 +35,14 @@ func (p *parser) declaration() Stmt { return p.varDeclaration() } + if p.match(scanner.FUN) { + return p.funDeclaration("function") + } + return p.statement() } -func (p *parser) varDeclaration() Stmt { +func (p *parser) varDeclaration() *VarStmt { if !p.match(scanner.IDENTIFIER) { panic(fault.NewFault(p.tokens[p.current].Line, "expected variable name")) } @@ -55,6 +60,51 @@ func (p *parser) varDeclaration() Stmt { return &VarStmt{&name, initializer} } +func (p *parser) funDeclaration(kind string) *FunStmt { + if !p.match(scanner.IDENTIFIER) { + message := fmt.Sprintf("expected %s name", kind) + panic(fault.NewFault(p.tokens[p.current].Line, message)) + } + name := p.tokens[p.current-1] + + if !p.match(scanner.LEFT_PAREN) { + message := fmt.Sprintf("expected '(' after %s name", kind) + panic(fault.NewFault(p.tokens[p.current].Line, message)) + } + + params := []*scanner.Token{} + if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF { + if !p.match(scanner.IDENTIFIER) { + message := fmt.Sprintf("expected parameter name at %s", p.tokens[p.current].Lexeme) + panic(fault.NewFault(p.tokens[p.current].Line, message)) + } + params = append(params, &p.tokens[p.current-1]) + for p.match(scanner.COMMA) { + if !p.match(scanner.IDENTIFIER) { + message := fmt.Sprintf("expected parameter name at %s", p.tokens[p.current].Lexeme) + panic(fault.NewFault(p.tokens[p.current].Line, message)) + } + params = append(params, &p.tokens[p.current-1]) + if len(params) > 255 { + panic(fault.NewFault(p.tokens[p.current].Line, "cannot have more than 255 parameters")) + } + } + } + + if !p.match(scanner.RIGHT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after parameter list")) + } + + if !p.match(scanner.LEFT_BRACE) { + message := fmt.Sprintf("expected '{' before %s body", kind) + panic(fault.NewFault(p.tokens[p.current].Line, message)) + } + + p.fdepth++ + defer func() { p.fdepth-- }() + return &FunStmt{&name, params, p.blockStatement()} +} + func (p *parser) statement() Stmt { if p.match(scanner.PRINT) { return p.printStatement() @@ -73,13 +123,17 @@ func (p *parser) statement() Stmt { } if p.match(scanner.LEFT_BRACE) { - return &BlockStmt{p.block()} + return p.blockStatement() + } + + if p.match(scanner.RETURN) { + return p.returnStatement() } return p.exprStatement() } -func (p *parser) printStatement() Stmt { +func (p *parser) printStatement() *PrintStmt { expr := p.expression() if !p.match(scanner.SEMICOLON) { @@ -89,7 +143,7 @@ func (p *parser) printStatement() Stmt { return &PrintStmt{expr} } -func (p *parser) ifStatement() Stmt { +func (p *parser) ifStatement() *IfStmt { if !p.match(scanner.LEFT_PAREN) { panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after if")) } @@ -123,7 +177,7 @@ func (p *parser) forStatement() Stmt { } var condition Expr - if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.SEMICOLON { + if p.tokens[p.current].TokenType != scanner.SEMICOLON && p.tokens[p.current].TokenType != scanner.EOF { condition = p.expression() } if !p.match(scanner.SEMICOLON) { @@ -131,7 +185,7 @@ func (p *parser) forStatement() Stmt { } var increment Expr - if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_PAREN { + if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF { increment = p.expression() } if !p.match(scanner.RIGHT_PAREN) { @@ -156,7 +210,7 @@ func (p *parser) forStatement() Stmt { return body } -func (p *parser) whileStatement() Stmt { +func (p *parser) whileStatement() *WhileStmt { if !p.match(scanner.LEFT_PAREN) { panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after while")) } @@ -169,10 +223,10 @@ func (p *parser) whileStatement() Stmt { return &WhileStmt{condition, p.statement()} } -func (p *parser) block() []Stmt { +func (p *parser) blockStatement() *BlockStmt { stmts := []Stmt{} - for p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_BRACE { + for p.tokens[p.current].TokenType != scanner.RIGHT_BRACE && p.tokens[p.current].TokenType != scanner.EOF { stmts = append(stmts, p.declaration()) } @@ -180,10 +234,10 @@ func (p *parser) block() []Stmt { panic(fault.NewFault(p.tokens[p.current].Line, "expected '}' after block")) } - return stmts + return &BlockStmt{stmts} } -func (p *parser) exprStatement() Stmt { +func (p *parser) exprStatement() *ExprStmt { expr := p.expression() if !p.match(scanner.SEMICOLON) { @@ -193,6 +247,24 @@ func (p *parser) exprStatement() Stmt { return &ExprStmt{expr} } +func (p *parser) returnStatement() *ReturnStmt { + keyword := p.tokens[p.current-1] + if p.fdepth == 0 { + panic(fault.NewFault(keyword.Line, "cannot return outside of a function")) + } + + var value Expr + if p.tokens[p.current].TokenType != scanner.SEMICOLON && p.tokens[p.current].TokenType != scanner.EOF { + value = p.expression() + } + + if !p.match(scanner.SEMICOLON) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after return statement")) + } + + return &ReturnStmt{&keyword, value} +} + func (p *parser) expression() Expr { return p.assignment() } @@ -293,7 +365,37 @@ func (p *parser) unary() Expr { return &UnaryExpr{&operator, right} } - return p.primary() + return p.call() +} + +func (p *parser) call() Expr { + expr := p.primary() + + for p.match(scanner.LEFT_PAREN) { + args, paren := p.arguments() + expr = &CallExpr{expr, paren, args} + } + + return expr +} + +func (p *parser) arguments() ([]Expr, scanner.Token) { + args := []Expr{} + if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF { + args = append(args, p.expression()) + for p.match(scanner.COMMA) { + args = append(args, p.expression()) + if len(args) > 255 { + panic(fault.NewFault(p.tokens[p.current].Line, "cannot have more than 255 arguments")) + } + } + } + + if !p.match(scanner.RIGHT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after argument list")) + } + + return args, p.tokens[p.current-1] } func (p *parser) primary() Expr { @@ -321,26 +423,25 @@ func (p *parser) primary() Expr { if p.match(scanner.LEFT_PAREN) { e := p.expression() - if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN { - message := fmt.Sprintf("expected ')' after \"%s\"", p.tokens[p.current].Lexeme) + if !p.match(scanner.RIGHT_PAREN) { + message := fmt.Sprintf("expected ')' after '%s'", p.tokens[p.current-1].Lexeme) panic(fault.NewFault(p.tokens[p.current].Line, message)) } - p.current++ return &GroupingExpr{e} } - message := fmt.Sprintf("expected expression at \"%s\"", p.tokens[p.current].Lexeme) + message := fmt.Sprintf("expected expression at '%s'", p.tokens[p.current].Lexeme) panic(fault.NewFault(p.tokens[p.current].Line, message)) } func (p *parser) match(types ...int) bool { - if p.tokens[p.current].TokenType == scanner.EOF { + currentType := p.tokens[p.current].TokenType + if currentType == scanner.EOF { return false } - actualType := p.tokens[p.current].TokenType for _, tokenType := range types { - if actualType == tokenType { + if currentType == tokenType { p.current++ return true } @@ -386,4 +487,4 @@ func (p *parser) synchronize() { p.current++ } } -} \ No newline at end of file +} diff --git a/pkg/ast/stmt.go b/pkg/ast/stmt.go index 34e7ad6..a658f71 100644 --- a/pkg/ast/stmt.go +++ b/pkg/ast/stmt.go @@ -56,4 +56,23 @@ type WhileStmt struct { func (w *WhileStmt) Accept(v StmtVisitor) interface{} { return v.VisitWhileStmt(w) +} + +type FunStmt struct { + Name *scanner.Token + Params []*scanner.Token + Body *BlockStmt +} + +func (f *FunStmt) Accept(v StmtVisitor) interface{} { + return v.VisitFunStmt(f) +} + +type ReturnStmt struct { + Keyword *scanner.Token + Value Expr +} + +func (r *ReturnStmt) Accept(v StmtVisitor) interface{} { + return v.VisitReturnStmt(r) } \ No newline at end of file diff --git a/pkg/ast/visitor.go b/pkg/ast/visitor.go index fee47c2..d127d01 100644 --- a/pkg/ast/visitor.go +++ b/pkg/ast/visitor.go @@ -8,6 +8,7 @@ type ExprVisitor interface { VisitVariableExpr(v *VariableExpr) interface{} VisitAssignExpr(a *AssignExpr) interface{} VisitLogicalExpr(l *LogicalExpr) interface{} + VisitCallExpr(c *CallExpr) interface{} } type StmtVisitor interface { @@ -17,4 +18,6 @@ type StmtVisitor interface { VisitBlockStmt(b *BlockStmt) interface{} VisitIfStmt(i *IfStmt) interface{} VisitWhileStmt(w *WhileStmt) interface{} + VisitFunStmt(f *FunStmt) interface{} + VisitReturnStmt(r *ReturnStmt) interface{} } \ No newline at end of file diff --git a/pkg/interpreter/callable.go b/pkg/interpreter/callable.go new file mode 100644 index 0000000..b533b0a --- /dev/null +++ b/pkg/interpreter/callable.go @@ -0,0 +1,60 @@ +package interpreter + +import ( + "fmt" + "time" + + "golox/pkg/ast" +) + +type callable interface { + arity() int + call(i *interpreter, args []interface{}) interface{} +} + +type clock struct{} + +func (c *clock) arity() int { return 0 } + +func (c *clock) call(i *interpreter, args []interface{}) interface{} { + return float64(time.Now().UnixMilli() / 1000) +} + +func (c clock) String() string { + return "" +} + +type function struct { + declaration *ast.FunStmt +} + +func (f *function) arity() int { return len(f.declaration.Params) } + +func (f *function) call(i *interpreter, args []interface{}) (value interface{}) { + env := &environment{i.global, make(map[string]interface{})} + for i := 0; i < f.arity(); i++ { + env.define(f.declaration.Params[i].Lexeme, args[i]) + } + + prev := i.env + defer func() { + i.env = prev + r := recover() + if err, ok := r.(error); ok { + panic(err) + } else { + value = r + } + }() + + i.env = env + for _, stmt := range f.declaration.Body.Statements { + stmt.Accept(i) + } + + return value +} + +func (f function) String() string { + return fmt.Sprintf("", f.declaration.Name.Lexeme) +} diff --git a/pkg/interpreter/interpreter.go b/pkg/interpreter/interpreter.go index 30b4755..eb211fd 100644 --- a/pkg/interpreter/interpreter.go +++ b/pkg/interpreter/interpreter.go @@ -11,12 +11,14 @@ import ( type interpreter struct { - env *environment + global *environment + env *environment } func NewInterpreter() *interpreter { - env := &environment{nil, map[string]interface{}{}} - return &interpreter{env} + global := &environment{nil, make(map[string]interface{})} + global.define("clock", &clock{}) + return &interpreter{global, global} } func (i *interpreter) Interpret(stmts []ast.Stmt) (err error) { @@ -45,8 +47,6 @@ func (i *interpreter) VisitPrintStmt(p *ast.PrintStmt) interface{} { fmt.Println(strconv.FormatFloat(v, 'f', -1, 64)) case bool: fmt.Println(strconv.FormatBool(v)) - case nil: - fmt.Println("nil") default: fmt.Println(v) } @@ -66,12 +66,11 @@ func (i *interpreter) VisitVarStmt(v *ast.VarStmt) interface{} { func (i *interpreter) VisitBlockStmt(b *ast.BlockStmt) interface{} { prev := i.env - defer func() { i.env = prev }() - i.env = &environment{prev, map[string]interface{}{}} + i.env = &environment{prev, make(map[string]interface{})} for _, stmt := range b.Statements { stmt.Accept(i) } @@ -98,6 +97,20 @@ func (i *interpreter) VisitWhileStmt(w *ast.WhileStmt) interface{} { return nil } +func (i *interpreter) VisitFunStmt(f *ast.FunStmt) interface{} { + i.env.define(f.Name.Lexeme, &function{f}) + return nil +} + +func (i *interpreter) VisitReturnStmt(v *ast.ReturnStmt) interface{} { + var value interface{} + if v.Value != nil { + value = v.Value.Accept(i) + } + + panic(value) +} + func (i *interpreter) VisitBinaryExpr(b *ast.BinaryExpr) interface{} { left := b.Left.Accept(i) right := b.Right.Accept(i) @@ -200,6 +213,25 @@ func (i *interpreter) VisitLogicalExpr(l *ast.LogicalExpr) interface{} { return l.Right.Accept(i) } +func (i *interpreter) VisitCallExpr(c *ast.CallExpr) interface{} { + callee := c.Callee.Accept(i) + args := []interface{}{} + for _, arg := range c.Arguments { + args = append(args, arg.Accept(i)) + } + + if f, ok := callee.(callable); ok { + if len(args) != f.arity() { + message := fmt.Sprintf("expected %d arguments but got %d.", f.arity(), len(args)) + panic(fault.NewFault(c.Paren.Line, message)) + } + + return f.call(i, args) + } + + panic(fault.NewFault(c.Paren.Line, "can only call functions and classes")) +} + func (i *interpreter) checkNumberOperands(operator *scanner.Token, left interface{}, right interface{}) (float64, float64) { if leftValue, leftOk := left.(float64); leftOk { if rightValue, rightOk := right.(float64); rightOk {