瀏覽代碼

Compiling Expression -- Comparison Operation

simon 3 年之前
父節點
當前提交
3dbaa2f4a6
共有 5 個文件被更改,包括 160 次插入14 次删除
  1. 22 14
      code/code.go
  2. 18 0
      compiler/compiler.go
  3. 60 0
      compiler/compiler_test.go
  4. 48 0
      vm/vm.go
  5. 12 0
      vm/vm_test.go

+ 22 - 14
code/code.go

@@ -62,13 +62,18 @@ type Opcode byte
 
 const (
 	OpConstant Opcode = iota
-	OpAdd
-	OpSub
-	OpMul
-	OpDiv
+	OpAdd             // +
+	OpSub             // -
+	OpMul             // *
+	OpDiv             // /
 	OpPop
-	OpTrue
-	OpFalse
+
+	OpTrue  // true
+	OpFalse // false
+
+	OpEqual       // ==
+	OpNotEqual    // !=
+	OpGreaterThan // >
 )
 
 // Definition For debugging and testing purposes
@@ -84,14 +89,17 @@ type Definition struct {
 }
 
 var definitions = map[Opcode]*Definition{
-	OpConstant: {"OpConstant", []int{2}},
-	OpAdd:      {"OpAdd", []int{}}, // doesn't have any operands
-	OpSub:      {"OpSub", []int{}},
-	OpMul:      {"OpMul", []int{}},
-	OpDiv:      {"OpDiv", []int{}},
-	OpPop:      {"OpPop", []int{}},
-	OpTrue:     {"OpTrue", []int{}},
-	OpFalse:    {"OpFalse", []int{}},
+	OpConstant:    {"OpConstant", []int{2}},
+	OpAdd:         {"OpAdd", []int{}}, // doesn't have any operands
+	OpSub:         {"OpSub", []int{}},
+	OpMul:         {"OpMul", []int{}},
+	OpDiv:         {"OpDiv", []int{}},
+	OpPop:         {"OpPop", []int{}},
+	OpTrue:        {"OpTrue", []int{}},
+	OpFalse:       {"OpFalse", []int{}},
+	OpEqual:       {"OpEqual", []int{}},
+	OpNotEqual:    {"OpNotEqual", []int{}},
+	OpGreaterThan: {"OpGreaterThan", []int{}},
 }
 
 func Lookup(op byte) (*Definition, error) {

+ 18 - 0
compiler/compiler.go

@@ -42,6 +42,18 @@ func (c *Compiler) Compile(node ast.Node) error {
 		c.emit(code.OpPop)
 
 	case *ast.InfixExpression:
+		if node.Operator == "<" {
+			err := c.Compile(node.Right)
+			if err != nil {
+				return err
+			}
+			err = c.Compile(node.Left)
+			if err != nil {
+				return err
+			}
+			c.emit(code.OpGreaterThan)
+			return nil
+		}
 		err := c.Compile(node.Left)
 		if err != nil {
 			return err
@@ -60,6 +72,12 @@ func (c *Compiler) Compile(node ast.Node) error {
 			c.emit(code.OpMul)
 		case "/":
 			c.emit(code.OpDiv)
+		case ">":
+			c.emit(code.OpGreaterThan)
+		case "==":
+			c.emit(code.OpEqual)
+		case "!=":
+			c.emit(code.OpNotEqual)
 		default:
 			return fmt.Errorf("unknown operator %s", node.Operator)
 		}

+ 60 - 0
compiler/compiler_test.go

@@ -103,6 +103,66 @@ func TestBooleanExpression(t *testing.T) {
 				code.Make(code.OpPop),
 			},
 		},
+		{
+			input:             "1 > 2",
+			expectedConstants: []interface{}{1, 2},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 0),
+				code.Make(code.OpConstant, 1),
+				code.Make(code.OpGreaterThan),
+				code.Make(code.OpPop),
+			},
+		},
+		{
+			input:             "1 < 2",
+			expectedConstants: []interface{}{2, 1},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 0),
+				code.Make(code.OpConstant, 1),
+				code.Make(code.OpGreaterThan),
+				code.Make(code.OpPop),
+			},
+		},
+		{
+			input:             "1 == 2",
+			expectedConstants: []interface{}{1, 2},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 0),
+				code.Make(code.OpConstant, 1),
+				code.Make(code.OpEqual),
+				code.Make(code.OpPop),
+			},
+		},
+		{
+			input:             "1 != 2",
+			expectedConstants: []interface{}{1, 2},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpConstant, 0),
+				code.Make(code.OpConstant, 1),
+				code.Make(code.OpNotEqual),
+				code.Make(code.OpPop),
+			},
+		},
+		{
+			input:             "true == false",
+			expectedConstants: []interface{}{},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpTrue),
+				code.Make(code.OpFalse),
+				code.Make(code.OpEqual),
+				code.Make(code.OpPop),
+			},
+		},
+		{
+			input:             "true != false",
+			expectedConstants: []interface{}{},
+			expectedInstructions: []code.Instructions{
+				code.Make(code.OpTrue),
+				code.Make(code.OpFalse),
+				code.Make(code.OpNotEqual),
+				code.Make(code.OpPop),
+			},
+		},
 	}
 
 	runCompilerTests(t, tests)

+ 48 - 0
vm/vm.go

@@ -71,6 +71,11 @@ func (vm *VM) Run() error {
 			if err != nil {
 				return err
 			}
+		case code.OpEqual, code.OpNotEqual, code.OpGreaterThan:
+			err := vm.executeComparison(op)
+			if err != nil {
+				return err
+			}
 		case code.OpPop:
 			vm.pop()
 		}
@@ -109,6 +114,23 @@ func (vm *VM) executeBinaryOperation(op code.Opcode) error {
 	return fmt.Errorf("unsupported types of binary operation: %s %s", leftType, rightType)
 }
 
+func (vm *VM) executeComparison(op code.Opcode) error {
+	right := vm.pop()
+	left := vm.pop()
+
+	if left.Type() == object.IntegerObj || right.Type() == object.IntegerObj {
+		return vm.executeIntegerComparison(op, left, right)
+	}
+	switch op {
+	case code.OpEqual:
+		return vm.push(nativeBoolToBooleanObject(right == left))
+	case code.OpNotEqual:
+		return vm.push(nativeBoolToBooleanObject(right != left))
+	default:
+		return fmt.Errorf("unknown operator: %d (%s %s)", op, left.Type(), right.Type())
+	}
+}
+
 func (vm *VM) LastPopStackElem() object.Object {
 	return vm.stack[vm.sp]
 }
@@ -137,3 +159,29 @@ func (vm *VM) executeBinaryIntegerOperation(
 
 	return vm.push(&object.Integer{Value: result})
 }
+
+func (vm *VM) executeIntegerComparison(
+	op code.Opcode,
+	left, right object.Object,
+) error {
+	leftValue := left.(*object.Integer).Value
+	rightValue := right.(*object.Integer).Value
+
+	switch op {
+	case code.OpEqual:
+		return vm.push(nativeBoolToBooleanObject(rightValue == leftValue))
+	case code.OpNotEqual:
+		return vm.push(nativeBoolToBooleanObject(rightValue != leftValue))
+	case code.OpGreaterThan:
+		return vm.push(nativeBoolToBooleanObject(leftValue > rightValue))
+	default:
+		return fmt.Errorf("unknown operator: %d", op)
+	}
+}
+
+func nativeBoolToBooleanObject(input bool) *object.Boolean {
+	if input {
+		return True
+	}
+	return False
+}

+ 12 - 0
vm/vm_test.go

@@ -39,6 +39,18 @@ func TestBooleanExpression(t *testing.T) {
 	tests := []vmTestCase{
 		{"true", true},
 		{"false", false},
+		{"1 < 2", true},
+		{"1 > 2", false},
+		{"1 < 1", false},
+		{"1 > 1", false},
+		{"1 == 1", true},
+		{"1 != 1", false},
+		{"true == true", true},
+		{"false == false", true},
+		{"true == false", false},
+		{"false != true", true},
+		{"(1 < 2) == true", true},
+		{"(1 < 2) == false", false},
 	}
 
 	runVmTests(t, tests)