Преглед на файлове

return statement object evaluator

simon преди 3 години
родител
ревизия
dd82147c66
променени са 3 файла, в които са добавени 62 реда и са изтрити 8 реда
  1. 25 5
      evaluator/evaluator.go
  2. 26 0
      evaluator/evaluator_test.go
  3. 11 3
      object/object.go

+ 25 - 5
evaluator/evaluator.go

@@ -15,7 +15,7 @@ func Eval(node ast.Node) object.Object {
 	switch node := node.(type) {
 	// Statements
 	case *ast.Program:
-		return evalStatements(node.Statements)
+		return evalProgram(node)
 
 	case *ast.ExpressionStatement:
 		return Eval(node.Expression)
@@ -34,20 +34,29 @@ func Eval(node ast.Node) object.Object {
 		return evalInfixExpression(node.Operator, left, right)
 
 	case *ast.BlockStatement:
-		return evalStatements(node.Statements)
+		return evalBlockStatements(node)
 	case *ast.IfExpression:
 		return evalIfExpression(node)
+
+	case *ast.ReturnStatement:
+		val := Eval(node.ReturnValue)
+		return &object.ReturnValue{Value: val}
 	}
 
 	return nil
 }
 
-func evalStatements(stmts []ast.Statement) object.Object {
+func evalProgram(node *ast.Program) object.Object {
 	var result object.Object
 
-	for _, statement := range stmts {
-		result = Eval(statement)
+	for _, stmt := range node.Statements {
+		result = Eval(stmt)
+		// skip return 后面的语句
+		if rVal, ok := result.(*object.ReturnValue); ok {
+			return rVal.Value
+		}
 	}
+
 	return result
 }
 
@@ -91,7 +100,18 @@ func evalIfExpression(node *ast.IfExpression) object.Object {
 		return NULL
 	}
 }
+func evalBlockStatements(node *ast.BlockStatement) object.Object {
+	var result object.Object
+
+	for _, stmt := range node.Statements {
+		result = Eval(stmt)
 
+		if result != nil && result.Type() == object.ReturnValueObj {
+			return result
+		}
+	}
+	return result
+}
 func evalIntegerInfixExpression(operator string,
 	left object.Object, right object.Object) object.Object {
 	leftVal := left.(*object.Integer).Value

+ 26 - 0
evaluator/evaluator_test.go

@@ -95,6 +95,32 @@ func TestIfElseExpression(t *testing.T) {
 	}
 }
 
+func TestReturnStatements(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected int64
+	}{
+		{"return 10;", 10},
+		{"return 10;9", 10},
+		{"return 2*5;9;", 10},
+		{"9;return 3 *5; 9;", 15},
+		{
+			`if ( 10 > 1) {
+				if( 10 > 1) {
+					return 10;
+				}
+				return 1;
+				}
+           `,
+			10,
+		},
+	}
+	for _, tt := range tests {
+		evaluated := testEval(tt.input)
+		testIntegerObject(t, evaluated, tt.expected)
+	}
+}
+
 func testNullObject(t *testing.T, obj object.Object) bool {
 	if obj != NULL {
 		t.Errorf("object is not NULL. got=%T (%+v)", obj, obj)

+ 11 - 3
object/object.go

@@ -6,9 +6,10 @@ import "fmt"
 type ObjType string
 
 const (
-	IntegerObj = "INTEGER"
-	BooleanObj = "BOOLEAN"
-	NullObj    = "NULL"
+	IntegerObj     = "INTEGER"
+	BooleanObj     = "BOOLEAN"
+	NullObj        = "NULL"
+	ReturnValueObj = "RETURN_VALUE"
 )
 
 // Object source code as an Object
@@ -35,3 +36,10 @@ type Null struct{}
 
 func (n *Null) Type() ObjType   { return NullObj }
 func (n *Null) Inspect() string { return "null" }
+
+type ReturnValue struct {
+	Value Object
+}
+
+func (rv *ReturnValue) Type() ObjType   { return ReturnValueObj }
+func (rv *ReturnValue) Inspect() string { return rv.Value.Inspect() }