Skip to content

Commit

Permalink
add support for Inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
useEffects committed Jul 29, 2024
1 parent ff4717f commit dabd695
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 37 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ https://craftinginterpreters.com/
- [x] Control Flow
- [x] Functions
- [x] Resolving and Binding
- [x] Classes
- [x] Classes
- [x] Inheritance
5 changes: 5 additions & 0 deletions pkg/interpreter/callable.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func (f function) String() string {

type class struct {
name string
super *class
methods map[string]*function
}

Expand Down Expand Up @@ -104,6 +105,10 @@ func (c *class) findMethod(name string) *function {
return fn
}

if c.super != nil {
return c.super.findMethod(name)
}

return nil
}

Expand Down
46 changes: 37 additions & 9 deletions pkg/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"golox/pkg/scanner"
)


type Interpreter struct {
global *environment
current *environment
Expand Down Expand Up @@ -70,9 +71,7 @@ func (i *Interpreter) VisitVarStmt(v *parser.VarStmt) interface{} {

func (i *Interpreter) VisitBlockStmt(b *parser.BlockStmt) interface{} {
prev := i.current
defer func() {
i.current = prev
}()
defer func() { i.current = prev }()

i.current = &environment{prev, make(map[string]interface{})}
for _, stmt := range b.Statements {
Expand Down Expand Up @@ -117,7 +116,22 @@ func (i *Interpreter) VisitReturnStmt(v *parser.ReturnStmt) interface{} {
}

func (i *Interpreter) VisitClassStmt(c *parser.ClassStmt) interface{} {
var super *class
if c.Super != nil {
if value, ok := c.Super.Accept(i).(*class); ok {
super = value
} else {
message := fmt.Sprintf("%s is a not a class", c.Super.Name.Lexeme)
panic(fault.NewFault(c.Super.Name.Line, message))
}
}

i.current.define(c.Name.Lexeme, nil)
if c.Super != nil {
i.current = &environment{i.current, make(map[string]interface{})}
i.current.define("super", super)
}

methods := make(map[string]*function)
for _, method := range c.Methods {
if method.Name.Lexeme == "init" {
Expand All @@ -127,14 +141,18 @@ func (i *Interpreter) VisitClassStmt(c *parser.ClassStmt) interface{} {
}
}

i.current.assign(c.Name, &class{c.Name.Lexeme, methods})
c_ := &class{c.Name.Lexeme, super, methods}
if c.Super != nil {
i.current = i.current.enclosing
}

i.current.assign(c.Name, c_)
return nil
}

func (i *Interpreter) VisitBinaryExpr(b *parser.BinaryExpr) interface{} {
left := b.Left.Accept(i)
right := b.Right.Accept(i)

switch b.Operator.TokenType {
case scanner.BANG_EQUAL:
return left != right
Expand Down Expand Up @@ -190,7 +208,6 @@ func (i *Interpreter) VisitLiteralExpr(l *parser.LiteralExpr) interface{} {

func (i *Interpreter) VisitUnaryExpr(u *parser.UnaryExpr) interface{} {
right := u.Right.Accept(i)

if u.Operator.TokenType == scanner.MINUS {
if value, ok := right.(float64); ok {
return -value
Expand Down Expand Up @@ -223,7 +240,6 @@ func (i *Interpreter) VisitVariableExpr(v *parser.VariableExpr) interface{} {

func (i *Interpreter) VisitAssignExpr(a *parser.AssignExpr) interface{} {
value := a.Value.Accept(i)

if dist, ok := i.locals[a]; ok {
i.current.assignAt(a.Name.Lexeme, value, dist)
} else {
Expand All @@ -235,7 +251,6 @@ func (i *Interpreter) VisitAssignExpr(a *parser.AssignExpr) interface{} {

func (i *Interpreter) VisitLogicalExpr(l *parser.LogicalExpr) interface{} {
left := l.Left.Accept(i)

if (l.Operator.TokenType == scanner.OR && isTruthy(left)) || !isTruthy(left) {
return left
}
Expand Down Expand Up @@ -290,6 +305,19 @@ func (i *Interpreter) VisitThisExpr(t *parser.ThisExpr) interface{} {
return i.global.get(t.Keyword)
}

func (i *Interpreter) VisitSuperExpr(s *parser.SuperExpr) interface{} {
dist := i.locals[s]
super := i.current.getAt("super", dist).(*class)
object := i.current.getAt("this", dist-1).(*instance)
method := super.findMethod(s.Method.Lexeme)
if method == nil {
message := fmt.Sprintf("undefined property '%s'", s.Method.Lexeme)
panic(fault.NewFault(s.Method.Line, message))
}

return method.bind(object)
}

func (i *Interpreter) checkNumberOperands(operator *scanner.Token, left interface{}, right interface{}) (float64, float64) {
if leftValue, leftOk := left.(float64); leftOk {
if rightValue, rightOk := right.(float64); rightOk {
Expand All @@ -310,4 +338,4 @@ func isTruthy(value interface{}) bool {
}

return true
}
}
10 changes: 10 additions & 0 deletions pkg/parser/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package parser
import "golox/pkg/scanner"



type Expr interface {
Accept(v ExprVisitor) interface{}
}
Expand Down Expand Up @@ -104,4 +105,13 @@ type ThisExpr struct {

func (t *ThisExpr) Accept(v ExprVisitor) interface{} {
return v.VisitThisExpr(t)
}

type SuperExpr struct {
Keyword *scanner.Token
Method *scanner.Token
}

func (s *SuperExpr) Accept(v ExprVisitor) interface{} {
return v.VisitSuperExpr(s)
}
37 changes: 20 additions & 17 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ func NewParser(tokens []scanner.Token) *Parser {

func (p *Parser) Parse() ([]Stmt, error) {
stmts := []Stmt{}

for p.tokens[p.current].TokenType != scanner.EOF {
stmts = append(stmts, p.declaration())
}
Expand Down Expand Up @@ -112,6 +111,15 @@ func (p *Parser) classDeclaration() *ClassStmt {
}
name := p.tokens[p.current-1]

var super *VariableExpr
if p.match(scanner.LESS) {
if !p.match(scanner.IDENTIFIER) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected superclass name after '<'"))
}
superName := p.tokens[p.current-1]
super = &VariableExpr{&superName}
}

if !p.match(scanner.LEFT_BRACE) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected '{' before class body"))
}
Expand All @@ -125,7 +133,7 @@ func (p *Parser) classDeclaration() *ClassStmt {
panic(fault.NewFault(p.tokens[p.current].Line, "expected '}' after class body"))
}

return &ClassStmt{&name, methods}
return &ClassStmt{&name, super, methods}
}

func (p *Parser) statement() Stmt {
Expand Down Expand Up @@ -158,7 +166,6 @@ func (p *Parser) statement() Stmt {

func (p *Parser) printStatement() *PrintStmt {
expr := p.expression()

if !p.match(scanner.SEMICOLON) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after print statement"))
}
Expand Down Expand Up @@ -248,7 +255,6 @@ func (p *Parser) whileStatement() *WhileStmt {

func (p *Parser) blockStatement() *BlockStmt {
stmts := []Stmt{}

for p.tokens[p.current].TokenType != scanner.RIGHT_BRACE && p.tokens[p.current].TokenType != scanner.EOF {
stmts = append(stmts, p.declaration())
}
Expand All @@ -262,7 +268,6 @@ func (p *Parser) blockStatement() *BlockStmt {

func (p *Parser) exprStatement() *ExprStmt {
expr := p.expression()

if !p.match(scanner.SEMICOLON) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after expression statement"))
}
Expand All @@ -272,7 +277,6 @@ func (p *Parser) exprStatement() *ExprStmt {

func (p *Parser) returnStatement() *ReturnStmt {
keyword := p.tokens[p.current-1]

var value Expr
if p.tokens[p.current].TokenType != scanner.SEMICOLON && p.tokens[p.current].TokenType != scanner.EOF {
value = p.expression()
Expand All @@ -291,7 +295,6 @@ func (p *Parser) expression() Expr {

func (p *Parser) assignment() Expr {
expr := p.or()

if p.match(scanner.EQUAL) {
equals := p.tokens[p.current-1]
value := p.assignment()
Expand All @@ -312,7 +315,6 @@ func (p *Parser) assignment() Expr {

func (p *Parser) or() Expr {
left := p.and()

for p.match(scanner.OR) {
operator := p.tokens[p.current-1]
right := p.and()
Expand All @@ -324,7 +326,6 @@ func (p *Parser) or() Expr {

func (p *Parser) and() Expr {
left := p.equality()

for p.match(scanner.AND) {
operator := p.tokens[p.current-1]
right := p.equality()
Expand All @@ -336,7 +337,6 @@ func (p *Parser) and() Expr {

func (p *Parser) equality() Expr {
left := p.comparison()

for p.match(scanner.BANG_EQUAL, scanner.EQUAL_EQUAL) {
operator := p.tokens[p.current-1]
right := p.comparison()
Expand All @@ -348,7 +348,6 @@ func (p *Parser) equality() Expr {

func (p *Parser) comparison() Expr {
left := p.term()

for p.match(scanner.GREATER, scanner.GREATER_EQUAL, scanner.LESS, scanner.LESS_EQUAL) {
operator := p.tokens[p.current-1]
right := p.term()
Expand All @@ -360,7 +359,6 @@ func (p *Parser) comparison() Expr {

func (p *Parser) term() Expr {
left := p.factor()

for p.match(scanner.MINUS, scanner.PLUS) {
operator := p.tokens[p.current-1]
right := p.factor()
Expand All @@ -372,7 +370,6 @@ func (p *Parser) term() Expr {

func (p *Parser) factor() Expr {
left := p.unary()

for p.match(scanner.SLASH, scanner.STAR) {
operator := p.tokens[p.current-1]
right := p.unary()
Expand All @@ -394,7 +391,6 @@ func (p *Parser) unary() Expr {

func (p *Parser) call() Expr {
expr := p.primary()

for {
if p.match(scanner.LEFT_PAREN) {
args, paren := p.arguments()
Expand Down Expand Up @@ -460,6 +456,15 @@ func (p *Parser) primary() Expr {
return &ThisExpr{previous}
}

if p.match(scanner.SUPER) {
keyword := p.tokens[p.current-1]
if !p.match(scanner.DOT) || !p.match(scanner.IDENTIFIER) {
panic(fault.NewFault(p.tokens[p.current].Line, "expected property access after 'super'"))
}
method := p.tokens[p.current-1]
return &SuperExpr{&keyword, &method}
}

if p.match(scanner.LEFT_PAREN) {
e := p.expression()
if !p.match(scanner.RIGHT_PAREN) {
Expand Down Expand Up @@ -491,9 +496,7 @@ func (p *Parser) match(types ...int) bool {

func (p *Parser) synchronize() {
if r := recover(); r != nil {
defer func() {
p.err = r.(error)
}()
defer func() { p.err = r.(error) }()

if p.tokens[p.current].TokenType != scanner.EOF {
p.current++
Expand Down
2 changes: 2 additions & 0 deletions pkg/parser/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package parser

import "golox/pkg/scanner"


type Stmt interface {
Accept(v StmtVisitor) interface{}
}
Expand Down Expand Up @@ -79,6 +80,7 @@ func (r *ReturnStmt) Accept(v StmtVisitor) interface{} {

type ClassStmt struct {
Name *scanner.Token
Super *VariableExpr
Methods []*FunStmt
}

Expand Down
1 change: 1 addition & 0 deletions pkg/parser/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type ExprVisitor interface {
VisitGetExpr(g *GetExpr) interface{}
VisitSetExpr(s *SetExpr) interface{}
VisitThisExpr(t *ThisExpr) interface{}
VisitSuperExpr(s *SuperExpr) interface{}
}

type StmtVisitor interface {
Expand Down
Loading

0 comments on commit dabd695

Please sign in to comment.