Skip to content

Commit

Permalink
added functions
Browse files Browse the repository at this point in the history
  • Loading branch information
useEffects committed Jul 28, 2024
1 parent 606fb47 commit 079c735
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 29 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ https://craftinginterpreters.com/
- [x] Parsing Expressions
- [x] Evaluating Expressions
- [x] Statements and State
- [x] Control Flow
- [x] Control Flow
- [x] Functions
10 changes: 10 additions & 0 deletions pkg/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,14 @@ type LogicalExpr struct {

func (l *LogicalExpr) Accept(v ExprVisitor) interface{} {
return v.VisitLogicalExpr(l)
}

type CallExpr struct {
Callee Expr
Paren scanner.Token
Arguments []Expr
}

func (c *CallExpr) Accept(v ExprVisitor) interface{} {
return v.VisitCallExpr(c)
}
143 changes: 122 additions & 21 deletions pkg/ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ import (
type parser struct {
tokens []scanner.Token
current int
fdepth int
err error
}

func NewParser(tokens []scanner.Token) *parser {
return &parser{tokens, 0, nil}
return &parser{tokens, 0, 0, nil}
}

func (p *parser) Parse() ([]Stmt, error) {
Expand All @@ -34,10 +35,14 @@ func (p *parser) declaration() Stmt {
return p.varDeclaration()
}

if p.match(scanner.FUN) {
return p.funDeclaration("function")
}

return p.statement()
}

func (p *parser) varDeclaration() Stmt {
func (p *parser) varDeclaration() *VarStmt {
if !p.match(scanner.IDENTIFIER) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected variable name"))
}
Expand All @@ -55,6 +60,51 @@ func (p *parser) varDeclaration() Stmt {
return &VarStmt{&name, initializer}
}

func (p *parser) funDeclaration(kind string) *FunStmt {
if !p.match(scanner.IDENTIFIER) {
message := fmt.Sprintf("expected %s name", kind)
panic(fault.NewFault(p.tokens[p.current].Line, message))
}
name := p.tokens[p.current-1]

if !p.match(scanner.LEFT_PAREN) {
message := fmt.Sprintf("expected '(' after %s name", kind)
panic(fault.NewFault(p.tokens[p.current].Line, message))
}

params := []*scanner.Token{}
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF {
if !p.match(scanner.IDENTIFIER) {
message := fmt.Sprintf("expected parameter name at %s", p.tokens[p.current].Lexeme)
panic(fault.NewFault(p.tokens[p.current].Line, message))
}
params = append(params, &p.tokens[p.current-1])
for p.match(scanner.COMMA) {
if !p.match(scanner.IDENTIFIER) {
message := fmt.Sprintf("expected parameter name at %s", p.tokens[p.current].Lexeme)
panic(fault.NewFault(p.tokens[p.current].Line, message))
}
params = append(params, &p.tokens[p.current-1])
if len(params) > 255 {
panic(fault.NewFault(p.tokens[p.current].Line, "cannot have more than 255 parameters"))
}
}
}

if !p.match(scanner.RIGHT_PAREN) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after parameter list"))
}

if !p.match(scanner.LEFT_BRACE) {
message := fmt.Sprintf("expected '{' before %s body", kind)
panic(fault.NewFault(p.tokens[p.current].Line, message))
}

p.fdepth++
defer func() { p.fdepth-- }()
return &FunStmt{&name, params, p.blockStatement()}
}

func (p *parser) statement() Stmt {
if p.match(scanner.PRINT) {
return p.printStatement()
Expand All @@ -73,13 +123,17 @@ func (p *parser) statement() Stmt {
}

if p.match(scanner.LEFT_BRACE) {
return &BlockStmt{p.block()}
return p.blockStatement()
}

if p.match(scanner.RETURN) {
return p.returnStatement()
}

return p.exprStatement()
}

func (p *parser) printStatement() Stmt {
func (p *parser) printStatement() *PrintStmt {
expr := p.expression()

if !p.match(scanner.SEMICOLON) {
Expand All @@ -89,7 +143,7 @@ func (p *parser) printStatement() Stmt {
return &PrintStmt{expr}
}

func (p *parser) ifStatement() Stmt {
func (p *parser) ifStatement() *IfStmt {
if !p.match(scanner.LEFT_PAREN) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after if"))
}
Expand Down Expand Up @@ -123,15 +177,15 @@ func (p *parser) forStatement() Stmt {
}

var condition Expr
if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.SEMICOLON {
if p.tokens[p.current].TokenType != scanner.SEMICOLON && p.tokens[p.current].TokenType != scanner.EOF {
condition = p.expression()
}
if !p.match(scanner.SEMICOLON) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after conditional expression"))
}

var increment Expr
if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_PAREN {
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF {
increment = p.expression()
}
if !p.match(scanner.RIGHT_PAREN) {
Expand All @@ -156,7 +210,7 @@ func (p *parser) forStatement() Stmt {
return body
}

func (p *parser) whileStatement() Stmt {
func (p *parser) whileStatement() *WhileStmt {
if !p.match(scanner.LEFT_PAREN) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after while"))
}
Expand All @@ -169,21 +223,21 @@ func (p *parser) whileStatement() Stmt {
return &WhileStmt{condition, p.statement()}
}

func (p *parser) block() []Stmt {
func (p *parser) blockStatement() *BlockStmt {
stmts := []Stmt{}

for p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_BRACE {
for p.tokens[p.current].TokenType != scanner.RIGHT_BRACE && p.tokens[p.current].TokenType != scanner.EOF {
stmts = append(stmts, p.declaration())
}

if !p.match(scanner.RIGHT_BRACE) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected '}' after block"))
}

return stmts
return &BlockStmt{stmts}
}

func (p *parser) exprStatement() Stmt {
func (p *parser) exprStatement() *ExprStmt {
expr := p.expression()

if !p.match(scanner.SEMICOLON) {
Expand All @@ -193,6 +247,24 @@ func (p *parser) exprStatement() Stmt {
return &ExprStmt{expr}
}

func (p *parser) returnStatement() *ReturnStmt {
keyword := p.tokens[p.current-1]
if p.fdepth == 0 {
panic(fault.NewFault(keyword.Line, "cannot return outside of a function"))
}

var value Expr
if p.tokens[p.current].TokenType != scanner.SEMICOLON && p.tokens[p.current].TokenType != scanner.EOF {
value = p.expression()
}

if !p.match(scanner.SEMICOLON) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after return statement"))
}

return &ReturnStmt{&keyword, value}
}

func (p *parser) expression() Expr {
return p.assignment()
}
Expand Down Expand Up @@ -293,7 +365,37 @@ func (p *parser) unary() Expr {
return &UnaryExpr{&operator, right}
}

return p.primary()
return p.call()
}

func (p *parser) call() Expr {
expr := p.primary()

for p.match(scanner.LEFT_PAREN) {
args, paren := p.arguments()
expr = &CallExpr{expr, paren, args}
}

return expr
}

func (p *parser) arguments() ([]Expr, scanner.Token) {
args := []Expr{}
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN && p.tokens[p.current].TokenType != scanner.EOF {
args = append(args, p.expression())
for p.match(scanner.COMMA) {
args = append(args, p.expression())
if len(args) > 255 {
panic(fault.NewFault(p.tokens[p.current].Line, "cannot have more than 255 arguments"))
}
}
}

if !p.match(scanner.RIGHT_PAREN) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after argument list"))
}

return args, p.tokens[p.current-1]
}

func (p *parser) primary() Expr {
Expand Down Expand Up @@ -321,26 +423,25 @@ func (p *parser) primary() Expr {

if p.match(scanner.LEFT_PAREN) {
e := p.expression()
if p.tokens[p.current].TokenType != scanner.RIGHT_PAREN {
message := fmt.Sprintf("expected ')' after \"%s\"", p.tokens[p.current].Lexeme)
if !p.match(scanner.RIGHT_PAREN) {
message := fmt.Sprintf("expected ')' after '%s'", p.tokens[p.current-1].Lexeme)
panic(fault.NewFault(p.tokens[p.current].Line, message))
}
p.current++
return &GroupingExpr{e}
}

message := fmt.Sprintf("expected expression at \"%s\"", p.tokens[p.current].Lexeme)
message := fmt.Sprintf("expected expression at '%s'", p.tokens[p.current].Lexeme)
panic(fault.NewFault(p.tokens[p.current].Line, message))
}

func (p *parser) match(types ...int) bool {
if p.tokens[p.current].TokenType == scanner.EOF {
currentType := p.tokens[p.current].TokenType
if currentType == scanner.EOF {
return false
}

actualType := p.tokens[p.current].TokenType
for _, tokenType := range types {
if actualType == tokenType {
if currentType == tokenType {
p.current++
return true
}
Expand Down Expand Up @@ -386,4 +487,4 @@ func (p *parser) synchronize() {
p.current++
}
}
}
}
19 changes: 19 additions & 0 deletions pkg/ast/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,23 @@ type WhileStmt struct {

func (w *WhileStmt) Accept(v StmtVisitor) interface{} {
return v.VisitWhileStmt(w)
}

type FunStmt struct {
Name *scanner.Token
Params []*scanner.Token
Body *BlockStmt
}

func (f *FunStmt) Accept(v StmtVisitor) interface{} {
return v.VisitFunStmt(f)
}

type ReturnStmt struct {
Keyword *scanner.Token
Value Expr
}

func (r *ReturnStmt) Accept(v StmtVisitor) interface{} {
return v.VisitReturnStmt(r)
}
3 changes: 3 additions & 0 deletions pkg/ast/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type ExprVisitor interface {
VisitVariableExpr(v *VariableExpr) interface{}
VisitAssignExpr(a *AssignExpr) interface{}
VisitLogicalExpr(l *LogicalExpr) interface{}
VisitCallExpr(c *CallExpr) interface{}
}

type StmtVisitor interface {
Expand All @@ -17,4 +18,6 @@ type StmtVisitor interface {
VisitBlockStmt(b *BlockStmt) interface{}
VisitIfStmt(i *IfStmt) interface{}
VisitWhileStmt(w *WhileStmt) interface{}
VisitFunStmt(f *FunStmt) interface{}
VisitReturnStmt(r *ReturnStmt) interface{}
}
60 changes: 60 additions & 0 deletions pkg/interpreter/callable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package interpreter

import (
"fmt"
"time"

"golox/pkg/ast"
)

type callable interface {
arity() int
call(i *interpreter, args []interface{}) interface{}
}

type clock struct{}

func (c *clock) arity() int { return 0 }

func (c *clock) call(i *interpreter, args []interface{}) interface{} {
return float64(time.Now().UnixMilli() / 1000)
}

func (c clock) String() string {
return "<native function clock>"
}

type function struct {
declaration *ast.FunStmt
}

func (f *function) arity() int { return len(f.declaration.Params) }

func (f *function) call(i *interpreter, args []interface{}) (value interface{}) {
env := &environment{i.global, make(map[string]interface{})}
for i := 0; i < f.arity(); i++ {
env.define(f.declaration.Params[i].Lexeme, args[i])
}

prev := i.env
defer func() {
i.env = prev
r := recover()
if err, ok := r.(error); ok {
panic(err)
} else {
value = r
}
}()

i.env = env
for _, stmt := range f.declaration.Body.Statements {
stmt.Accept(i)
}

return value
}

func (f function) String() string {
return fmt.Sprintf("<function %s >", f.declaration.Name.Lexeme)
}
Loading

0 comments on commit 079c735

Please sign in to comment.