// Author: simon // Author: ynwdlxm@163.com // Date: 2022/10/19 14:39 // Desc: Compiler package compiler import ( "fmt" "github/runnignwater/monkey/ast" "github/runnignwater/monkey/code" "github/runnignwater/monkey/object" "sort" ) type EmittedInstruction struct { Opcode code.Opcode Position int } type Compiler struct { constants []object.Object // slice that serves as our constant pool symbolTable *SymbolTable scopes []CompilationScope scopeIndex int } type CompilationScope struct { instructions code.Instructions // hold the generated bytecode lastInstruction EmittedInstruction previousInstruction EmittedInstruction } func New() *Compiler { mainScope := CompilationScope{ instructions: code.Instructions{}, lastInstruction: EmittedInstruction{}, previousInstruction: EmittedInstruction{}, } return &Compiler{ constants: []object.Object{}, symbolTable: NewSymbolTable(), scopes: []CompilationScope{mainScope}, scopeIndex: 0, } } func NewWithState(s *SymbolTable, constants []object.Object) *Compiler { compiler := New() compiler.symbolTable = s compiler.constants = constants return compiler } func (c *Compiler) Compile(node ast.Node) error { switch node := node.(type) { case *ast.Program: for _, s := range node.Statements { err := c.Compile(s) if err != nil { return err } } case *ast.ExpressionStatement: err := c.Compile(node.Expression) if err != nil { return err } c.emit(code.OpPop) case *ast.InfixExpression: if node.Operator == "<" { err := c.Compile(node.Right) if err != nil { return err } err = c.Compile(node.Left) if err != nil { return err } c.emit(code.OpGreaterThan) return nil } err := c.Compile(node.Left) if err != nil { return err } err = c.Compile(node.Right) if err != nil { return err } switch node.Operator { case "+": c.emit(code.OpAdd) case "-": c.emit(code.OpSub) case "*": c.emit(code.OpMul) case "/": c.emit(code.OpDiv) case ">": c.emit(code.OpGreaterThan) case "==": c.emit(code.OpEqual) case "!=": c.emit(code.OpNotEqual) default: return fmt.Errorf("unknown operator %s", node.Operator) } case *ast.PrefixExpression: err := c.Compile(node.Right) if err != nil { return err } switch node.Operator { case "!": c.emit(code.OpBang) case "-": c.emit(code.OpMinus) default: return fmt.Errorf("unknown operator %s", node.Operator) } case *ast.IfExpression: err := c.Compile(node.Condition) if err != nil { return err } // Emit an 'OpJumNotTruthy` with a bogus value jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999) err = c.Compile(node.Consequence) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } // Emit an `OpJump` with a bogus value jumpPos := c.emit(code.OpJump, 9999) afterConsequencePos := len(c.currentInstructions()) c.changOperand(jumpNotTruthyPos, afterConsequencePos) if node.Alternative == nil { c.emit(code.OpNull) } else { err := c.Compile(node.Alternative) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } } afterAlternativePos := len(c.currentInstructions()) c.changOperand(jumpPos, afterAlternativePos) case *ast.BlockStatement: for _, s := range node.Statements { err := c.Compile(s) if err != nil { return err } } case *ast.LetStatement: err := c.Compile(node.Value) if err != nil { return err } symbol := c.symbolTable.Define(node.Name.Value) c.emit(code.OpSetGlobal, symbol.Index) case *ast.Identifier: symbol, ok := c.symbolTable.Resolve(node.Value) if !ok { return fmt.Errorf("undefined variable %s", node.Value) } c.emit(code.OpGetGlobal, symbol.Index) case *ast.Boolean: if node.Value { c.emit(code.OpTrue) } else { c.emit(code.OpFalse) } case *ast.IntegerLiteral: integer := &object.Integer{Value: node.Value} c.emit(code.OpConstant, c.addConstant(integer)) case *ast.StringLiteral: str := &object.String{Value: node.Value} c.emit(code.OpConstant, c.addConstant(str)) case *ast.ArrayLiteral: for _, el := range node.Element { err := c.Compile(el) if err != nil { return err } } c.emit(code.OpArray, len(node.Element)) case *ast.HashLiteral: var keys []ast.Expression for k := range node.Pairs { keys = append(keys, k) } sort.Slice(keys, func(i, j int) bool { return keys[i].String() < keys[j].String() }) for _, k := range keys { err := c.Compile(k) if err != nil { return err } err = c.Compile(node.Pairs[k]) if err != nil { return err } } c.emit(code.OpHash, len(node.Pairs)*2) case *ast.IndexExpression: err := c.Compile(node.Left) if err != nil { return err } err = c.Compile(node.Index) if err != nil { return err } c.emit(code.OpIndex) case *ast.ReturnStatement: err := c.Compile(node.ReturnValue) if err != nil { return err } c.emit(code.OpReturnValue) case *ast.FunctionLiteral: c.enterScope() err := c.Compile(node.Body) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.replaceLastPopWithReturn() } if !c.lastInstructionIs(code.OpReturnValue) { c.emit(code.OpReturn) } instructions := c.leaveScope() compiledFn := &object.CompileFunction{Instructions: instructions} c.emit(code.OpConstant, c.addConstant(compiledFn)) case *ast.CallExpression: err := c.Compile(node.Function) if err != nil { return err } c.emit(code.OpCall) } return nil } func (c *Compiler) addConstant(obj object.Object) int { c.constants = append(c.constants, obj) return len(c.constants) - 1 } func (c *Compiler) addInstruction(ins []byte) int { posNewInstruction := len(c.currentInstructions()) updatedInstructions := append(c.currentInstructions(), ins...) c.scopes[c.scopeIndex].instructions = updatedInstructions return posNewInstruction } // emit generate an instruction and add it to the results, // either by printing it, writing it to a file or // by adding it to a collection in memory // // op code.Opcode // // operands [operand]int func (c *Compiler) emit(op code.Opcode, operands ...int) int { ins := code.Make(op, operands...) pos := c.addInstruction(ins) c.setLastInstruction(op, pos) return pos } func (c *Compiler) setLastInstruction(op code.Opcode, pos int) { previous := c.scopes[c.scopeIndex].lastInstruction last := EmittedInstruction{Opcode: op, Position: pos} c.scopes[c.scopeIndex].previousInstruction = previous c.scopes[c.scopeIndex].lastInstruction = last } func (c *Compiler) lastInstructionIsPop() bool { return c.scopes[c.scopeIndex].lastInstruction.Opcode == code.OpPop } func (c *Compiler) lastInstructionIs(op code.Opcode) bool { if len(c.currentInstructions()) == 0 { return false } return c.scopes[c.scopeIndex].lastInstruction.Opcode == op } func (c *Compiler) removeLastPop() { last := c.scopes[c.scopeIndex].lastInstruction previous := c.scopes[c.scopeIndex].previousInstruction oldIns := c.currentInstructions() newIns := oldIns[:last.Position] c.scopes[c.scopeIndex].instructions = newIns c.scopes[c.scopeIndex].lastInstruction = previous } func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) { ins := c.currentInstructions() for i := 0; i < len(newInstruction); i++ { ins[pos+i] = newInstruction[i] } } func (c *Compiler) changOperand(opPos int, operand int) { op := code.Opcode(c.currentInstructions()[opPos]) newInstruction := code.Make(op, operand) c.replaceInstruction(opPos, newInstruction) } func (c *Compiler) currentInstructions() code.Instructions { return c.scopes[c.scopeIndex].instructions } func (c *Compiler) enterScope() { scope := CompilationScope{ instructions: code.Instructions{}, lastInstruction: EmittedInstruction{}, previousInstruction: EmittedInstruction{}, } c.scopes = append(c.scopes, scope) c.scopeIndex++ } func (c *Compiler) leaveScope() code.Instructions { ins := c.currentInstructions() c.scopes = c.scopes[:len(c.scopes)-1] c.scopeIndex-- return ins } func (c *Compiler) replaceLastPopWithReturn() { lastPop := c.scopes[c.scopeIndex].lastInstruction.Position c.replaceInstruction(lastPop, code.Make(code.OpReturnValue)) c.scopes[c.scopeIndex].lastInstruction.Opcode = code.OpReturnValue } func (c *Compiler) ByteCode() *ByteCode { return &ByteCode{ Instructions: c.currentInstructions(), Constants: c.constants, } } type ByteCode struct { Instructions code.Instructions Constants []object.Object }