Преглед на файлове

Functions Local Bindings -- in the compiler

runningwater преди 3 години
родител
ревизия
0aa321e310
променени са 12 файла, в които са добавени 279 реда и са изтрити 14 реда
  1. 13 0
      code/code.go
  2. 25 0
      code/code_test.go
  3. 13 2
      compiler/compiler.go
  4. 87 0
      compiler/compiler_test.go
  5. 20 1
      compiler/symbol_table.go
  6. 107 0
      compiler/symbol_table_test.go
  7. 2 2
      evaluator/builtins.go
  8. 1 1
      lexer/lexer.go
  9. 1 1
      lexer/lexer_test.go
  10. 2 2
      parser/parser.go
  11. 4 4
      parser/parser_tracing.go
  12. 4 1
      token/token.go

+ 13 - 0
code/code.go

@@ -46,6 +46,9 @@ const (
 	OpCall
 	OpReturnValue
 	OpReturn // no return value
+
+	OpGetLocal
+	OpSetLocal
 )
 
 type Instructions []byte
@@ -135,6 +138,9 @@ var definitions = map[Opcode]*Definition{
 	OpCall:        {"OpCall", []int{}},
 	OpReturnValue: {"OpReturnValue", []int{}},
 	OpReturn:      {"OpReturn", []int{}},
+
+	OpGetLocal: {"OpGetLocal", []int{1}},
+	OpSetLocal: {"OpSetLocal", []int{1}},
 }
 
 func Lookup(op byte) (*Definition, error) {
@@ -164,6 +170,8 @@ func Make(op Opcode, operands ...int) []byte {
 		switch width {
 		case 2:
 			binary.BigEndian.PutUint16(instruction[offset:], uint16(o))
+		case 1:
+			instruction[offset] = byte(o)
 		}
 		offset += width
 	}
@@ -180,6 +188,8 @@ func ReadOperands(def *Definition, ins Instructions) ([]int, int) {
 		switch width {
 		case 2:
 			operands[i] = int(ReadUint16(ins[offset:]))
+		case 1:
+			operands[i] = int(ReadUint8(ins[offset:]))
 		}
 
 		offset += width
@@ -191,3 +201,6 @@ func ReadOperands(def *Definition, ins Instructions) ([]int, int) {
 func ReadUint16(ins Instructions) uint16 {
 	return binary.BigEndian.Uint16(ins)
 }
+func ReadUint8(ins Instructions) uint8 {
+	return uint8(ins[0])
+}

+ 25 - 0
code/code_test.go

@@ -10,6 +10,7 @@ func TestMake(t *testing.T) {
 	}{
 		{OpConstant, []int{65534}, []byte{byte(OpConstant), 0xFF, 0xFE}},
 		{OpAdd, []int{}, []byte{byte(OpAdd)}},
+		{OpGetLocal, []int{255}, []byte{byte(OpGetLocal), 255}},
 	}
 
 	for _, tt := range tests {
@@ -48,6 +49,29 @@ func TestInstructionsString(t *testing.T) {
 	}
 }
 
+func TestInstructionsStringGetLocal(t *testing.T) {
+	instructions := []Instructions{
+		Make(OpAdd),
+		Make(OpGetLocal, 1),
+		Make(OpConstant, 2),
+		Make(OpConstant, 65535),
+	}
+
+	expected := `0000 OpAdd
+0001 OpGetLocal 1
+0003 OpConstant 2
+0006 OpConstant 65535
+`
+
+	conCatted := Instructions{}
+	for _, ins := range instructions {
+		conCatted = append(conCatted, ins...)
+	}
+	if conCatted.String() != expected {
+		t.Errorf("instructions wrong formatted.\nwant=%q\n got=%q", expected, conCatted.String())
+	}
+}
+
 func TestReadOperands(t *testing.T) {
 	tests := []struct {
 		op        Opcode
@@ -55,6 +79,7 @@ func TestReadOperands(t *testing.T) {
 		bytesRead int
 	}{
 		{OpConstant, []int{65535}, 2},
+		{OpGetLocal, []int{255}, 1},
 	}
 
 	for _, tt := range tests {

+ 13 - 2
compiler/compiler.go

@@ -173,14 +173,22 @@ func (c *Compiler) Compile(node ast.Node) error {
 			return err
 		}
 		symbol := c.symbolTable.Define(node.Name.Value)
-		c.emit(code.OpSetGlobal, symbol.Index)
+		if symbol.Scope == GlobalScope {
+			c.emit(code.OpSetGlobal, symbol.Index)
+		} else {
+			c.emit(code.OpSetLocal, symbol.Index)
+		}
 
 	case *ast.Identifier:
 		symbol, ok := c.symbolTable.Resolve(node.Value)
 		if !ok {
 			return fmt.Errorf("undefined variable %s", node.Value)
 		}
-		c.emit(code.OpGetGlobal, symbol.Index)
+		if symbol.Scope == GlobalScope {
+			c.emit(code.OpGetGlobal, symbol.Index)
+		} else {
+			c.emit(code.OpGetLocal, symbol.Index)
+		}
 
 	case *ast.Boolean:
 		if node.Value {
@@ -354,6 +362,8 @@ func (c *Compiler) enterScope() {
 	}
 	c.scopes = append(c.scopes, scope)
 	c.scopeIndex++
+
+	c.symbolTable = NewEnclosedSymbolTable(c.symbolTable)
 }
 
 func (c *Compiler) leaveScope() code.Instructions {
@@ -362,6 +372,7 @@ func (c *Compiler) leaveScope() code.Instructions {
 	c.scopes = c.scopes[:len(c.scopes)-1]
 	c.scopeIndex--
 
+	c.symbolTable = c.symbolTable.Outer
 	return ins
 }
 

+ 87 - 0
compiler/compiler_test.go

@@ -264,6 +264,76 @@ func TestGlobalLetStatements(t *testing.T) {
 	runCompilerTests(t, tests)
 }
 
+func TestLetStatement(t *testing.T) {
+	tests := []compilerTestCase{
+		{
+			input: `let num = 55; fn(){num};`,
+			expectedConstants: []interface{}{
+				55,
+				[]code.Instructions{
+					code.Make(code.OpGetGlobal, 0),
+					code.Make(code.OpReturnValue),
+				},
+			},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 0),
+				code.Make(code.OpSetGlobal, 0),
+				code.Make(code.OpConstant, 1),
+				code.Make(code.OpPop),
+			},
+		},
+		{
+			input: `
+			fn () {
+				let num = 55;
+				num;
+			}
+			`,
+			expectedConstants: []interface{}{
+				55,
+				[]code.Instructions{
+					code.Make(code.OpConstant, 0),
+					code.Make(code.OpSetLocal, 0),
+					code.Make(code.OpGetLocal, 0),
+					code.Make(code.OpReturnValue),
+				},
+			},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 1),
+				code.Make(code.OpPop),
+			},
+		},
+		{
+			input: `
+			fn () {
+				let a = 55;
+				let b = 77;
+				a + b;
+ 			}
+			`,
+			expectedConstants: []interface{}{
+				55, 77,
+				[]code.Instructions{
+					code.Make(code.OpConstant, 0),
+					code.Make(code.OpSetLocal, 0),
+					code.Make(code.OpConstant, 1),
+					code.Make(code.OpSetLocal, 1),
+					code.Make(code.OpGetLocal, 0),
+					code.Make(code.OpGetLocal, 1),
+					code.Make(code.OpAdd),
+					code.Make(code.OpReturnValue),
+				},
+			},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 2),
+				code.Make(code.OpPop),
+			},
+		},
+	}
+
+	runCompilerTests(t, tests)
+}
+
 func TestStringExpression(t *testing.T) {
 	tests := []compilerTestCase{
 		{
@@ -474,8 +544,12 @@ func TestCompilerScopes(t *testing.T) {
 	if compiler.scopeIndex != 0 {
 		t.Errorf("scopeIndex wrong. got=%d, want=%d", compiler.scopeIndex, 0)
 	}
+
+	globalSymbolTable := compiler.symbolTable
+
 	compiler.emit(code.OpMul)
 
+	// enterScope ======================================================================================================
 	compiler.enterScope()
 	if compiler.scopeIndex != 1 {
 		t.Errorf("scopeIndex wrong. got=%d, want=%d", compiler.scopeIndex, 0)
@@ -491,10 +565,23 @@ func TestCompilerScopes(t *testing.T) {
 	if last.Opcode != code.OpSub {
 		t.Errorf("lastInstruction.Opcde wrong. got=%d, want=%d", last.Opcode, code.OpSub)
 	}
+	if compiler.symbolTable.Outer != globalSymbolTable {
+		t.Errorf("compiler did not enclose symboTable")
+	}
+
 	compiler.leaveScope()
 	if compiler.scopeIndex != 0 {
 		t.Errorf("scopeIndex wrong. got=%d, want=%d", compiler.scopeIndex, 0)
 	}
+	// leaveScope ======================================================================================================
+
+	if compiler.symbolTable != globalSymbolTable {
+		t.Errorf("compiler did not restore global symbol table")
+	}
+
+	if compiler.symbolTable.Outer != nil {
+		t.Errorf("compiler modified global symbol table incorrectly")
+	}
 
 	compiler.emit(code.OpAdd)
 	if len(compiler.scopes[compiler.scopeIndex].instructions) != 2 {

+ 20 - 1
compiler/symbol_table.go

@@ -4,6 +4,7 @@ type SymbolScope string
 
 const (
 	GlobalScope SymbolScope = "GLOBAL"
+	LocalScope  SymbolScope = "LOCAL"
 )
 
 type Symbol struct {
@@ -13,6 +14,8 @@ type Symbol struct {
 }
 
 type SymbolTable struct {
+	Outer *SymbolTable
+
 	store          map[string]Symbol
 	numDefinitions int
 }
@@ -22,8 +25,20 @@ func NewSymbolTable() *SymbolTable {
 	return &SymbolTable{store: s}
 }
 
+func NewEnclosedSymbolTable(outer *SymbolTable) *SymbolTable {
+	s := NewSymbolTable()
+	s.Outer = outer
+	return s
+}
+
 func (s *SymbolTable) Define(name string) Symbol {
-	symbol := Symbol{Name: name, Index: s.numDefinitions, Scope: GlobalScope}
+	symbol := Symbol{Name: name, Index: s.numDefinitions}
+	if s.Outer == nil {
+		symbol.Scope = GlobalScope
+	} else {
+		symbol.Scope = LocalScope
+	}
+
 	s.store[name] = symbol
 	s.numDefinitions++
 	return symbol
@@ -31,5 +46,9 @@ func (s *SymbolTable) Define(name string) Symbol {
 
 func (s *SymbolTable) Resolve(name string) (Symbol, bool) {
 	obj, ok := s.store[name]
+	if !ok && s.Outer != nil {
+		obj, ok = s.Outer.Resolve(name)
+		return obj, ok
+	}
 	return obj, ok
 }

+ 107 - 0
compiler/symbol_table_test.go

@@ -6,6 +6,10 @@ func TestDefine(t *testing.T) {
 	expected := map[string]Symbol{
 		"a": {"a", GlobalScope, 0},
 		"b": {"b", GlobalScope, 1},
+		"c": {"c", LocalScope, 0},
+		"d": {"d", LocalScope, 1},
+		"e": {"e", LocalScope, 0},
+		"f": {"f", LocalScope, 1},
 	}
 
 	global := NewSymbolTable()
@@ -18,6 +22,30 @@ func TestDefine(t *testing.T) {
 	if b != expected["b"] {
 		t.Errorf("expected b=%+v, got=%+v", expected["b"], b)
 	}
+
+	firstLocal := NewEnclosedSymbolTable(global)
+
+	c := firstLocal.Define("c")
+	if c != expected["c"] {
+		t.Errorf("expected c=%+v, got=%+v", expected["c"], c)
+	}
+
+	d := firstLocal.Define("d")
+	if d != expected["d"] {
+		t.Errorf("expected d=%+v, got=%+v", expected["d"], d)
+	}
+
+	secondLocal := NewEnclosedSymbolTable(firstLocal)
+
+	e := secondLocal.Define("e")
+	if e != expected["e"] {
+		t.Errorf("expeted e=%+v, got=%+v", expected["e"], e)
+	}
+
+	f := secondLocal.Define("f")
+	if f != expected["f"] {
+		t.Errorf("expected f=%+v, got=%+v", expected["f"], f)
+	}
 }
 
 func TestResolveGlobal(t *testing.T) {
@@ -42,3 +70,82 @@ func TestResolveGlobal(t *testing.T) {
 		}
 	}
 }
+
+func TestResolveLocal(t *testing.T) {
+	global := NewSymbolTable()
+	global.Define("a")
+	global.Define("b")
+
+	local := NewEnclosedSymbolTable(global)
+	local.Define("c")
+	local.Define("d")
+
+	expected := []Symbol{
+		{Name: "a", Scope: GlobalScope, Index: 0},
+		{Name: "b", Scope: GlobalScope, Index: 1},
+		{Name: "c", Scope: LocalScope, Index: 0},
+		{Name: "d", Scope: LocalScope, Index: 1},
+	}
+
+	for _, sym := range expected {
+		result, ok := local.Resolve(sym.Name)
+		if !ok {
+			t.Errorf("name %s not resolvable", sym.Name)
+			continue
+		}
+		if result != sym {
+			t.Errorf("expected %s to resolve to %+v, got=%+v", sym.Name, sym, result)
+		}
+	}
+}
+
+func TestResolveNestedLocal(t *testing.T) {
+	global := NewSymbolTable()
+	global.Define("a")
+	global.Define("b")
+
+	firstLocal := NewEnclosedSymbolTable(global)
+	firstLocal.Define("c")
+	firstLocal.Define("d")
+
+	secondLocal := NewEnclosedSymbolTable(firstLocal)
+	secondLocal.Define("e")
+	secondLocal.Define("f")
+
+	tests := []struct {
+		table           *SymbolTable
+		expectedSymbols []Symbol
+	}{
+		{
+			firstLocal,
+			[]Symbol{
+				{"a", GlobalScope, 0},
+				{"b", GlobalScope, 1},
+				{"c", LocalScope, 0},
+				{"d", LocalScope, 1},
+			},
+		},
+		{
+			secondLocal,
+			[]Symbol{
+				{"a", GlobalScope, 0},
+				{"b", GlobalScope, 1},
+				{"e", LocalScope, 0},
+				{"f", LocalScope, 1},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		for _, sym := range tt.expectedSymbols {
+			result, ok := tt.table.Resolve(sym.Name)
+			if !ok {
+				t.Errorf("name %s not resolvable", sym.Name)
+				continue
+			}
+			if result != sym {
+				t.Errorf("expected %s resolve to %+v, got=%+v", sym.Name, sym, result)
+			}
+		}
+	}
+}

+ 2 - 2
evaluator/builtins.go

@@ -90,9 +90,9 @@ var builtins = map[string]*object.Builtin{
 	// >> let a = [1, 2, 3, 4];
 	// >> let b = push(a, 5);
 	// >> a
-	//[1, 2, 3, 4]
+	// [1, 2, 3, 4]
 	// >> b
-	//[1, 2, 3, 4, 5]
+	// [1, 2, 3, 4, 5]
 	"push": {
 		Fn: func(args ...object.Object) object.Object {
 			if len(args) != 2 {

+ 1 - 1
lexer/lexer.go

@@ -101,7 +101,7 @@ func (l *Lexer) NextToken() token.Token {
 		if l.peekChar() == '=' {
 			ch := l.ch
 			l.readChar()
-			tok = token.Token{Type: token.NOT_EQ, Literal: string(ch) + string(l.ch)}
+			tok = token.Token{Type: token.NotEq, Literal: string(ch) + string(l.ch)}
 		} else {
 			tok = newToken(token.BANG, l.ch)
 		}

+ 1 - 1
lexer/lexer_test.go

@@ -105,7 +105,7 @@ func TestNextToken(t *testing.T) {
 		{token.INT, "10"},
 		{token.SEMICOLON, ";"},
 		{token.INT, "10"},
-		{token.NOT_EQ, "!="},
+		{token.NotEq, "!="},
 		{token.INT, "9"},
 		{token.SEMICOLON, ";"},
 		{token.STRING, "foobar"},

+ 2 - 2
parser/parser.go

@@ -31,7 +31,7 @@ const (
 // precedences 指派 token 类型的优先级
 var precedences = map[token.TypeToken]int{
 	token.EQ:       EQUALS,
-	token.NOT_EQ:   EQUALS,
+	token.NotEq:    EQUALS,
 	token.GT:       LESSGREATER,
 	token.LT:       LESSGREATER,
 	token.PLUS:     SUM,
@@ -84,7 +84,7 @@ func New(l *lexer.Lexer) *Parser {
 	p.registerInfix(token.ASTERISK, p.parseInfixExpression)
 	p.registerInfix(token.SLASH, p.parseInfixExpression)
 	p.registerInfix(token.EQ, p.parseInfixExpression)
-	p.registerInfix(token.NOT_EQ, p.parseInfixExpression)
+	p.registerInfix(token.NotEq, p.parseInfixExpression)
 	p.registerInfix(token.GT, p.parseInfixExpression)
 	p.registerInfix(token.LT, p.parseInfixExpression)
 	p.registerInfix(token.LPAREN, p.parseCallExpression)

+ 4 - 4
parser/parser_tracing.go

@@ -21,12 +21,12 @@ func incIdent() { traceLevel = traceLevel + 1 }
 func decIdent() { traceLevel = traceLevel - 1 }
 
 func trace(msg string) string {
-	//incIdent()
-	//tracePrint("BEGIN " + msg)
+	// incIdent()
+	// tracePrint("BEGIN " + msg)
 	return msg
 }
 
 func untrace(msg string) {
-	//tracePrint("END " + msg)
-	//decIdent()
+	// tracePrint("END " + msg)
+	// decIdent()
 }

+ 4 - 1
token/token.go

@@ -18,6 +18,7 @@ const (
 	STRING = "STRING"
 
 	// Operators
+
 	ASSIGN   = "="
 	PLUS     = "+"
 	MINUS    = "-"
@@ -27,9 +28,10 @@ const (
 	LT       = "<"
 	GT       = ">"
 	EQ       = "=="
-	NOT_EQ   = "!="
+	NotEq    = "!="
 
 	// Delimiters
+
 	COMMA     = ","
 	SEMICOLON = ";"
 	COLON     = ":"
@@ -43,6 +45,7 @@ const (
 	RBRACKET = "]"
 
 	// Keywords
+
 	FUNCTION = "FUNCTION"
 	LET      = "LET"
 	TRUE     = "TRUE"