浏览代码

Error handling

simon 3 年之前
父节点
当前提交
55d2ce3f3c
共有 3 个文件被更改,包括 107 次插入10 次删除
  1. 44 10
      evaluator/evaluator.go
  2. 55 0
      evaluator/evaluator_test.go
  3. 8 0
      object/object.go

+ 44 - 10
evaluator/evaluator.go

@@ -1,6 +1,7 @@
 package evaluator
 
 import (
+	"fmt"
 	"github/runnignwater/monkey/ast"
 	"github/runnignwater/monkey/object"
 )
@@ -27,10 +28,19 @@ func Eval(node ast.Node) object.Object {
 		return nativeBooleanObject(node.Value)
 	case *ast.PrefixExpression:
 		right := Eval(node.Right)
+		if isError(right) {
+			return right
+		}
 		return evalPrefixExpression(node.Operator, right)
 	case *ast.InfixExpression:
 		left := Eval(node.Left)
+		if isError(left) {
+			return left
+		}
 		right := Eval(node.Right)
+		if isError(right) {
+			return right
+		}
 		return evalInfixExpression(node.Operator, left, right)
 
 	case *ast.BlockStatement:
@@ -40,20 +50,29 @@ func Eval(node ast.Node) object.Object {
 
 	case *ast.ReturnStatement:
 		val := Eval(node.ReturnValue)
+		if isError(val) {
+			return val
+		}
 		return &object.ReturnValue{Value: val}
 	}
 
 	return nil
 }
 
+func newError(format string, a ...interface{}) *object.Error {
+	return &object.Error{Msg: fmt.Sprintf(format, a...)}
+}
 func evalProgram(node *ast.Program) object.Object {
 	var result object.Object
 
 	for _, stmt := range node.Statements {
 		result = Eval(stmt)
 		// skip return 后面的语句
-		if rVal, ok := result.(*object.ReturnValue); ok {
-			return rVal.Value
+		switch result := result.(type) {
+		case *object.ReturnValue:
+			return result.Value
+		case *object.Error:
+			return result
 		}
 	}
 
@@ -75,7 +94,7 @@ func evalPrefixExpression(operator string, right object.Object) object.Object {
 	case "-":
 		return evalMinusPreOperatorExpression(right)
 	default:
-		return NULL
+		return newError("unknown operator: %s%s", operator, right.Type())
 	}
 }
 func evalInfixExpression(operator string, left object.Object, right object.Object) object.Object {
@@ -86,12 +105,17 @@ func evalInfixExpression(operator string, left object.Object, right object.Objec
 		return nativeBooleanObject(left == right)
 	case operator == "!=":
 		return nativeBooleanObject(left != right)
+	case left.Type() != right.Type():
+		return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type())
 	default:
-		return NULL
+		return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
 	}
 }
 func evalIfExpression(node *ast.IfExpression) object.Object {
 	condition := Eval(node.Condition)
+	if isError(condition) {
+		return condition
+	}
 	if isTruthy(condition) {
 		return Eval(node.Consequence)
 	} else if node.Alternative != nil {
@@ -100,14 +124,17 @@ func evalIfExpression(node *ast.IfExpression) object.Object {
 		return NULL
 	}
 }
-func evalBlockStatements(node *ast.BlockStatement) object.Object {
+func evalBlockStatements(block *ast.BlockStatement) object.Object {
 	var result object.Object
 
-	for _, stmt := range node.Statements {
+	for _, stmt := range block.Statements {
 		result = Eval(stmt)
 
-		if result != nil && result.Type() == object.ReturnValueObj {
-			return result
+		if result != nil {
+			rt := result.Type()
+			if rt == object.ReturnValueObj || rt == object.ErrorObj {
+				return result
+			}
 		}
 	}
 	return result
@@ -135,13 +162,13 @@ func evalIntegerInfixExpression(operator string,
 	case "!=":
 		return nativeBooleanObject(leftVal != rightVal)
 	default:
-		return NULL
+		return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
 	}
 }
 
 func evalMinusPreOperatorExpression(right object.Object) object.Object {
 	if right.Type() != object.IntegerObj {
-		return NULL
+		return newError("unknown operator: -%s", right.Type())
 	}
 	value := right.(*object.Integer).Value
 	return &object.Integer{Value: -value}
@@ -170,3 +197,10 @@ func isTruthy(obj object.Object) bool {
 		return true
 	}
 }
+
+func isError(obj object.Object) bool {
+	if obj != nil {
+		return obj.Type() == object.ErrorObj
+	}
+	return false
+}

+ 55 - 0
evaluator/evaluator_test.go

@@ -121,6 +121,61 @@ func TestReturnStatements(t *testing.T) {
 	}
 }
 
+func TestErrorHandling(t *testing.T) {
+	tests := []struct {
+		input       string
+		expectedMsg string
+	}{
+		{
+			"5 + true",
+			"type mismatch: INTEGER + BOOLEAN",
+		},
+		{
+			"5 + true; 5;",
+			"type mismatch: INTEGER + BOOLEAN",
+		},
+		{
+			"-true",
+			"unknown operator: -BOOLEAN",
+		},
+		{
+			"true + false",
+			"unknown operator: BOOLEAN + BOOLEAN",
+		},
+		{
+			"5;true + false;5",
+			"unknown operator: BOOLEAN + BOOLEAN",
+		},
+		{
+			"if ( 10 > 1) { true + false; }",
+			"unknown operator: BOOLEAN + BOOLEAN",
+		},
+		{
+			`
+			if (10 > 1) {
+				if (10 > 1) {
+					return true + false;
+				}
+				return 1;
+			}
+       		`,
+			"unknown operator: BOOLEAN + BOOLEAN",
+		},
+	}
+
+	for _, tt := range tests {
+		evaluated := testEval(tt.input)
+
+		errObj, ok := evaluated.(*object.Error)
+		if !ok {
+			t.Errorf("no error object returned. got=%T (%+v)", evaluated, evaluated)
+			continue
+		}
+		if errObj.Msg != tt.expectedMsg {
+			t.Errorf("wrong error message. expected=%q, got=%q", tt.expectedMsg, errObj.Msg)
+		}
+	}
+}
 func testNullObject(t *testing.T, obj object.Object) bool {
 	if obj != NULL {
 		t.Errorf("object is not NULL. got=%T (%+v)", obj, obj)

+ 8 - 0
object/object.go

@@ -10,6 +10,7 @@ const (
 	BooleanObj     = "BOOLEAN"
 	NullObj        = "NULL"
 	ReturnValueObj = "RETURN_VALUE"
+	ErrorObj       = "ERROR"
 )
 
 // Object source code as an Object
@@ -43,3 +44,10 @@ type ReturnValue struct {
 
 func (rv *ReturnValue) Type() ObjType   { return ReturnValueObj }
 func (rv *ReturnValue) Inspect() string { return rv.Value.Inspect() }
+
+type Error struct {
+	Msg string
+}
+
+func (e *Error) Type() ObjType   { return ErrorObj }
+func (e *Error) Inspect() string { return "ERROR: " + e.Msg }