ソースを参照

Local Functions and Closures

runningwater 2 年 前
コミット
725676d576

+ 10 - 11
README.md

@@ -95,17 +95,21 @@
                 declaration  → funDecl
                              | varDecl
                              | statement ;
-          
-                 statement      → exprStmt
-                                 | ifStmt
-                                 | printStmt 
-                                 | block;
+
+                statement      → exprStmt
+                               | forStmt
+                               | ifStmt
+                               | printStmt
+                               | returnStmt
+                               | whileStmt
+                               | block ;
                  funDecl       → "fun" function;
                  function      →  IDENTIFIER "(" parameters? ")" block ;
                  block         → "{" declaration* "}"
                  ifStmt        → "if" "(" expression ")" statement ( "else" statement)? ;
             
                 varDecl  → "var" IDENTIFIER ( "=" expression )? ";" ;
+                returnStmt → "return" expression? ";";
               ```
           - Syntax tree -- Logical Operators
             ```
@@ -115,12 +119,7 @@
               logic_or       → logic_and ( "or" logic_and )* ;
               logic_and      → equality ( "and" equality )* ;
             ```
-            statement      → exprStmt
-            | forStmt
-            | ifStmt
-            | printStmt
-            | whileStmt
-            | block ;
+
 
             whileStmt   → "while" "(" expression ")" statement ;
             forStmt     → "for" " (" (varDecl | exprStmt | ";") 

+ 9 - 1
src/main/java/com/craftinginterpreters/lox/Interpreter.java

@@ -181,7 +181,7 @@ public class Interpreter implements Expr.Visitor<Object>, Stmt.Visitor<Void> {
 
     @Override
     public Void visitFunctionStmt(Stmt.Function stmt) {
-        LoxFunction function = new LoxFunction(stmt);
+        LoxFunction function = new LoxFunction(stmt, environment);
         environment.define(stmt.name.lexeme, function);
         return null;
     }
@@ -204,6 +204,14 @@ public class Interpreter implements Expr.Visitor<Object>, Stmt.Visitor<Void> {
     }
 
     @Override
+    public Void visitReturnStmt(Stmt.Return stmt) {
+        Object value = null;
+        if (stmt.value != null) value = evaluate(stmt.value);
+
+        throw new Return(value);
+    }
+
+    @Override
     public Void visitVarStmt(Stmt.Var stmt) {
         Object value = null;
         if (stmt.initializer != null) {

+ 10 - 4
src/main/java/com/craftinginterpreters/lox/LoxFunction.java

@@ -11,9 +11,11 @@ import java.util.List;
  */
 public class LoxFunction implements LoxCallable {
     private final Stmt.Function declaration;
+    private final Environment closure;
 
-    public LoxFunction(Stmt.Function declaration) {
+    public LoxFunction(Stmt.Function declaration, Environment closure) {
         this.declaration = declaration;
+        this.closure = closure;
     }
 
     @Override
@@ -23,18 +25,22 @@ public class LoxFunction implements LoxCallable {
 
     @Override
     public Object call(Interpreter interpreter, List<Object> arguments) {
-        Environment environment = new Environment(interpreter.globals);
+        Environment environment = new Environment(closure);
         for (int i = 0; i < declaration.params.size(); i++) {
             environment.define(declaration.params.get(i).lexeme, arguments.get(i));
         }
-        interpreter.executeBlock(declaration.body, environment);
+        try {
+            interpreter.executeBlock(declaration.body, environment);
+        } catch (Return returnVal) {
+            return returnVal.value;
+        }
         return null;
     }
 
     /**
      * <code>
      * fun add(a,b) {
-     *   print a + b;
+     * print a + b;
      * }
      * <p>
      * print add; // "<fn add>".

+ 14 - 1
src/main/java/com/craftinginterpreters/lox/Parser.java

@@ -54,11 +54,12 @@ public class Parser {
         }
     }
 
-    // statement  → exprStmt | forStmt | ifStmt | printStmt | whileStmt | block;
+    // statement  → exprStmt | forStmt | ifStmt | printStmt | returnStmt | whileStmt | block;
     private Stmt statement() {
         if (match(FOR)) return forStatement();
         if (match(IF)) return ifStatement();
         if (match(PRINT)) return printStatement();
+        if (match(RETURN)) return returnStatement();
         if (match(WHILE)) return whileStatement();
         if (match(LEFT_BRACE)) return new Stmt.Block(block());
         return expressionStatement();
@@ -173,6 +174,18 @@ public class Parser {
         return new Stmt.Print(value);
     }
 
+    // returnStatement → "return" expression? ";" ;
+    private Stmt returnStatement() {
+        Token keyword = previous();
+        Expr value = null;
+        if (!check(SEMICOLON)) {
+            value = expression();
+        }
+
+        consume(SEMICOLON, "Expect ';' after return value.");
+        return new Stmt.Return(keyword, value);
+    }
+
     // varDecl  → "var" IDENTIFIER ( "=" expression )? ";" ;
     private Stmt varDecl() {
         Token name = consume(IDENTIFIER, "Expect variable name.");

+ 17 - 0
src/main/java/com/craftinginterpreters/lox/Return.java

@@ -0,0 +1,17 @@
+/* Copyright (C) 2019-2023 Hangzhou HSH Co. Ltd.
+ * All right reserved.*/
+package com.craftinginterpreters.lox;
+
+/**
+ * @author simon
+ * @date 2023-07-26 15:07
+ * @desc
+ */
+public class Return extends RuntimeException {
+    final Object value;
+
+    public Return(Object value) {
+        super(null, null, false, false);
+        this.value = value;
+    }
+}

+ 41 - 22
src/main/java/com/craftinginterpreters/lox/Stmt.java

@@ -4,12 +4,9 @@ import java.util.List;
 
 /**
  * @author GenerateAst
- * @date 2023-07-25 17:57
+ * @date 2023-07-26 10:50
  */
 abstract class Stmt {
-    abstract <R> R accept(Visitor<R> visitor);
-
-
     interface Visitor<R> {
         R visitBlockStmt(Block stmt);
 
@@ -21,15 +18,16 @@ abstract class Stmt {
 
         R visitPrintStmt(Print stmt);
 
+        R visitReturnStmt(Return stmt);
+
         R visitVarStmt(Var stmt);
 
         R visitWhileStmt(While stmt);
 
     }
 
-    static class Block extends Stmt {
-        final List<Stmt> statements;
 
+    static class Block extends Stmt {
         Block(List<Stmt> statements) {
             this.statements = statements;
         }
@@ -38,11 +36,11 @@ abstract class Stmt {
         <R> R accept(Visitor<R> visitor) {
             return visitor.visitBlockStmt(this);
         }
+
+        final List<Stmt> statements;
     }
 
     static class Expression extends Stmt {
-        final Expr expression;
-
         Expression(Expr expression) {
             this.expression = expression;
         }
@@ -51,12 +49,11 @@ abstract class Stmt {
         <R> R accept(Visitor<R> visitor) {
             return visitor.visitExpressionStmt(this);
         }
+
+        final Expr expression;
     }
 
     static class Function extends Stmt {
-        final Token name;
-        final List<Token> params;
-        final List<Stmt> body;
         Function(Token name, List<Token> params, List<Stmt> body) {
             this.name = name;
             this.params = params;
@@ -67,12 +64,13 @@ abstract class Stmt {
         <R> R accept(Visitor<R> visitor) {
             return visitor.visitFunctionStmt(this);
         }
+
+        final Token name;
+        final  List<Token> params;
+        final  List<Stmt> body;
     }
 
     static class If extends Stmt {
-        final Expr condition;
-        final Stmt thenBranch;
-        final Stmt elseBranch;
         If(Expr condition, Stmt thenBranch, Stmt elseBranch) {
             this.condition = condition;
             this.thenBranch = thenBranch;
@@ -83,11 +81,13 @@ abstract class Stmt {
         <R> R accept(Visitor<R> visitor) {
             return visitor.visitIfStmt(this);
         }
+
+        final Expr condition;
+        final  Stmt thenBranch;
+        final  Stmt elseBranch;
     }
 
     static class Print extends Stmt {
-        final Expr expression;
-
         Print(Expr expression) {
             this.expression = expression;
         }
@@ -96,12 +96,26 @@ abstract class Stmt {
         <R> R accept(Visitor<R> visitor) {
             return visitor.visitPrintStmt(this);
         }
+
+        final Expr expression;
     }
 
-    static class Var extends Stmt {
-        final Token name;
-        final Expr initializer;
+    static class Return extends Stmt {
+        Return(Token keyword, Expr value) {
+            this.keyword = keyword;
+            this.value = value;
+        }
+
+        @Override
+        <R> R accept(Visitor<R> visitor) {
+            return visitor.visitReturnStmt(this);
+        }
+
+        final Token keyword;
+        final  Expr value;
+    }
 
+    static class Var extends Stmt {
         Var(Token name, Expr initializer) {
             this.name = name;
             this.initializer = initializer;
@@ -111,12 +125,12 @@ abstract class Stmt {
         <R> R accept(Visitor<R> visitor) {
             return visitor.visitVarStmt(this);
         }
+
+        final Token name;
+        final  Expr initializer;
     }
 
     static class While extends Stmt {
-        final Expr condition;
-        final Stmt body;
-
         While(Expr condition, Stmt body) {
             this.condition = condition;
             this.body = body;
@@ -126,5 +140,10 @@ abstract class Stmt {
         <R> R accept(Visitor<R> visitor) {
             return visitor.visitWhileStmt(this);
         }
+
+        final Expr condition;
+        final  Stmt body;
     }
+
+    abstract <R> R accept(Visitor<R> visitor);
 }

+ 1 - 0
src/main/java/com/craftinginterpreters/tool/GenerateAst.java

@@ -58,6 +58,7 @@ public class GenerateAst {
                 "Function    : Token name, List<Token> params, List<Stmt> body",
                 "If          : Expr condition, Stmt thenBranch, Stmt elseBranch",
                 "Print       : Expr expression",
+                "Return      : Token keyword, Expr value",
                 "Var         : Token name, Expr initializer",
                 "While       : Expr condition, Stmt body"
         ));