瀏覽代碼

closure upvalues

runningwater 2 年之前
父節點
當前提交
a29924ec2a
共有 9 個文件被更改,包括 192 次插入20 次删除
  1. 5 2
      chunk.h
  2. 52 2
      compiler.c
  3. 19 0
      debug.c
  4. 9 0
      memory.c
  5. 26 1
      object.c
  6. 19 0
      object.h
  7. 11 0
      test/closures.lox
  8. 48 12
      vm.c
  9. 3 3
      vm.h

+ 5 - 2
chunk.h

@@ -26,8 +26,10 @@ typedef enum {
   OP_SET_GLOBAL,/// \brief setter
   OP_GET_LOCAL,
   OP_SET_LOCAL,/// \brief setter
-  OP_NOT,      /// \brief print !true; // "false"
-  OP_NEGATE,   /// \brief prefix -
+  OP_GET_UPVALUE,
+  OP_SET_UPVALUE,
+  OP_NOT,   /// \brief print !true; // "false"
+  OP_NEGATE,/// \brief prefix -
   OP_EQUAL,
   OP_GREATER,
   OP_LESS,
@@ -41,6 +43,7 @@ typedef enum {
   OP_LOOP,         /// <OP_LOOP ++>  往前跳
   OP_CALL,         /// <OP_CALL argCount>
   OP_RETURN,       ///<OP_RETURN>
+  OP_CLOSURE,
 } OpCode;
 
 //============================================================================

+ 52 - 2
compiler.c

@@ -47,7 +47,10 @@ typedef struct {
   Token name;
   int depth;
 } Local;
-
+typedef struct {
+  uint8_t index;
+  bool isLocal;
+} Upvalue;
 typedef enum {
   TYPE_FUNCTION,
   TYPE_SCRIPT,
@@ -60,6 +63,7 @@ typedef struct Compiler {
 
   Local locals[UINT8_COUNT];
   int localCount;/// how many locals are in scope
+  Upvalue upvalues[UINT8_COUNT];
   int scopeDepth;
 } Compiler;
 
@@ -197,7 +201,12 @@ static void function(FunctionType type) {
   block();
 
   ObjFunction *fun = endCompiler();
-  emitBytes(OP_CONSTANT, makeConstant(OBJ_VAL(fun)));
+  emitBytes(OP_CLOSURE, makeConstant(OBJ_VAL(fun)));
+
+  for (int i = 0; i < fun->upvalueCount; i++) {
+    emitByte(compiler.upvalues[i].isLocal ? 1 : 0);
+    emitByte(compiler.upvalues[i].index);
+  }
 }
 static void funDeclaration() {
   uint8_t global = parseVariable("Expect function name.");
@@ -519,12 +528,53 @@ static int resolveLocal(Compiler *compile, Token *name) {
   }
   return -1;
 }
+static int addUpvalue(Compiler *compiler, uint8_t index, bool isLocal) {
+  int upvalueCount = compiler->function->upvalueCount;
+
+  //! before we add a new upvalue,
+  //! we first check to see if the function already has an upvalue
+  //! that closes over that variable.
+  for (int i = 0; i < upvalueCount; i++) {
+    Upvalue *upvalue = &compiler->upvalues[i];
+    if (upvalue->index == index && upvalue->isLocal == isLocal) {
+      return i;
+    }
+  }
+
+  if (upvalueCount == UINT8_COUNT) {
+    error("Too many closure variable in function.");
+    return 0;
+  }
+
+  compiler->upvalues[upvalueCount].isLocal = isLocal;
+  compiler->upvalues[upvalueCount].index = index;
+  return compiler->function->upvalueCount++;
+}
+static int resolveUpvalue(Compiler *compiler, Token *name) {
+  if (compiler->enclosing == NULL) return -1;
+
+  int local = resolveLocal(compiler->enclosing, name);
+  if (local != -1) {
+    return addUpvalue(compiler, (uint8_t) local, true);
+  }
+
+  // 递归
+  int upvalue = resolveUpvalue(compiler->enclosing, name);
+  if (upvalue != -1) {
+    return addUpvalue(compiler, (uint8_t) upvalue, false);
+  }
+
+  return -1;
+}
 static void namedVariable(Token name, bool canAssign) {
   uint8_t getOp, setOp;
   int arg = resolveLocal(current, &name);
   if (arg != -1) {
     getOp = OP_GET_LOCAL;
     setOp = OP_SET_LOCAL;
+  } else if ((arg = resolveUpvalue(current, &name)) != -1) {
+    getOp = OP_GET_UPVALUE;
+    setOp = OP_SET_UPVALUE;
   } else {
     arg = identifierConstant(&name);
     getOp = OP_GET_GLOBAL;

+ 19 - 0
debug.c

@@ -3,6 +3,7 @@
 //
 
 #include "debug.h"
+#include "object.h"
 #include "value.h"
 #include <stdio.h>
 static int constantInstruction(const char *, Chunk *, int);
@@ -58,6 +59,24 @@ int disassembleInstruction(Chunk *chunk, int offset) {
     case OP_LOOP: return jumpInstruction("OP_LOOP", -1, chunk, offset);
     case OP_CALL: return byteInstruction("OP_CALL", chunk, offset);
     case OP_JUMP_IF_FALSE: return jumpInstruction("OP_JUMP_IF_FALSE", 1, chunk, offset);
+    case OP_CLOSURE: {
+      offset++;
+      uint8_t constant = chunk->code[offset++];
+      printf("%-16s %4d ", "OP_CLOSURE", constant);
+      printValue(chunk->constants.values[constant]);
+      printf("\n");
+
+      ObjFunction *function = AS_FUNCTION(chunk->constants.values[constant]);
+      for (int j = 0; j < function->upvalueCount; j++) {
+        int isLocal = chunk->code[offset++];
+        int index = chunk->code[offset++];
+        printf("%04d    |                     %s %d\n", offset - 2, isLocal ? "local" : "upvalue", index);
+      }
+
+      return offset;
+    }
+    case OP_GET_UPVALUE: return byteInstruction("OP_GET_UPVALUE", chunk, offset);
+    case OP_SET_UPVALUE: return byteInstruction("OP_SET_UPVALUE", chunk, offset);
     case OP_RETURN: return simpleInstruction("OP_RETURN", offset);
     default:
       printf("Unknown opcode %d\n", instruction);

+ 9 - 0
memory.c

@@ -44,6 +44,15 @@ static void freeObject(Obj *object) {
       FREE_ARRAY(char, ((ObjString *) object)->chars, ((ObjString *) object)->length + 1);
       FREE(ObjString, object);
       break;
+    case OBJ_CLOSURE: {
+      ObjClosure *closure = (ObjClosure *) object;
+      FREE_ARRAY(ObjUpvalue *, closure->upvalues, closure->upvalueCount);
+      FREE(ObjClosure, object);
+      break;
+    }
+    case OBJ_UPVALUE:
+      FREE(ObjUpvalue, object);
+      break;
   }
 }
 void freeObjects() {

+ 26 - 1
object.c

@@ -62,11 +62,18 @@ void printObject(Value value) {
     case OBJ_STRING:
       printf("%s", AS_CSTRING(value));
       break;
+    case OBJ_CLOSURE:
+      printFunction(AS_CLOSURE(value)->function);
+      break;
+    case OBJ_UPVALUE:
+      printf("upvalue");
+      break;
   }
 }
 ObjFunction *newFunction() {
   ObjFunction *function = ALLOCATE_OBJ(ObjFunction, OBJ_FUNCTION);
   function->arity = 0;
+  function->upvalueCount = 0;
   function->name = NULL;
   initChunk(&function->chunk);
   return function;
@@ -87,7 +94,25 @@ ObjNative *newNative(NativeFn function) {
   native->function = function;
   return native;
 }
-/// 分配内在空间
+ObjClosure *newClosure(ObjFunction *function) {
+  ObjUpvalue **upvalues = ALLOCATE(ObjUpvalue *, function->upvalueCount);
+
+  for (int i = 0; i < function->upvalueCount; i++) {
+    upvalues[i] = NULL;
+  }
+
+  ObjClosure *closure = ALLOCATE_OBJ(ObjClosure, OBJ_CLOSURE);
+  closure->function = function;
+  closure->upvalues = upvalues;
+  closure->upvalueCount = function->upvalueCount;
+  return closure;
+}
+ObjUpvalue *newUpvalue(Value *slot) {
+  ObjUpvalue *upvalue = ALLOCATE_OBJ(ObjUpvalue, OBJ_UPVALUE);
+  upvalue->location = slot;
+  return upvalue;
+}
+/// 分配内存空间
 /// \param size 空间大小
 /// \param type 对象类型
 /// \return Obj*

+ 19 - 0
object.h

@@ -17,10 +17,12 @@
 
 #define OBJ_TYPE(value) (AS_OBJ(value)->type)
 
+#define IS_CLOSURE(value) isObjType(value, OBJ_CLOSURE)
 #define IS_FUNCTION(value) isObjType(value, OBJ_FUNCTION)
 #define IS_NATIVE(value) isObjType(value, OBJ_NATIVE)
 #define IS_STRING(value) isObjType(value, OBJ_STRING)
 
+#define AS_CLOSURE(value) ((ObjClosure *) AS_OBJ(value))
 #define AS_FUNCTION(value) ((ObjFunction *) AS_OBJ(value))
 #define AS_NATIVE(value) (((ObjNative *) AS_OBJ(value))->function)
 #define AS_STRING(value) ((ObjString *) AS_OBJ(value))
@@ -30,6 +32,8 @@ typedef enum {
   OBJ_STRING,
   OBJ_FUNCTION,
   OBJ_NATIVE,
+  OBJ_CLOSURE,
+  OBJ_UPVALUE,
 } ObjType;
 
 struct Obj {
@@ -40,6 +44,7 @@ struct Obj {
 typedef struct {
   struct Obj obj;
   int arity;//  stores the number of parameters the function expects
+  int upvalueCount;
   Chunk chunk;
   ObjString *name;
 } ObjFunction;
@@ -58,10 +63,24 @@ struct ObjString {
   uint32_t hash;// 缓存 hash 值
 };
 
+typedef struct ObjUpvalue {
+  struct Obj obj;
+  Value *location;
+} ObjUpvalue;
+
+typedef struct {
+  struct Obj obj;
+  ObjFunction *function;
+  ObjUpvalue **upvalues;
+  int upvalueCount;
+} ObjClosure;
+
+ObjClosure *newClosure(ObjFunction *function);
 ObjFunction *newFunction();
 ObjNative *newNative(NativeFn function);
 ObjString *takeString(char *chars, int length);
 ObjString *copyString(const char *chars, int length);
+ObjUpvalue *newUpvalue(Value *slot);
 void printObject(Value value);
 
 static inline bool isObjType(Value value, ObjType type) {

+ 11 - 0
test/closures.lox

@@ -0,0 +1,11 @@
+fun outer() {
+  var x = "outside";
+  fun inner() {
+    print x;
+  }
+
+  return inner;
+}
+
+var closure = outer();
+closure();

+ 48 - 12
vm.c

@@ -31,8 +31,8 @@ static void runtimeError(const char *format, ...) {
   fputs("\n", stderr);
   for (int i = vm.frameCount - 1; i >= 0; i--) {
     CallFrame *frame = &vm.frames[i];
-    ObjFunction *function = frame->function;
-    size_t instruction = frame->ip - frame->function->chunk.code - 1;
+    ObjFunction *function = frame->closure->function;
+    size_t instruction = frame->ip - function->chunk.code - 1;
     int line = function->chunk.lines[instruction];
     fprintf(stderr, "[line %d] in ", line);
     if (function->name == NULL) {
@@ -70,9 +70,9 @@ void initVM() {
 static Value peek(int distance) {
   return vm.stackTop[-1 - distance];
 }
-static bool call(ObjFunction *function, int argCount) {
-  if (argCount != function->arity) {
-    runtimeError("Expected %d arguments but got %d", function->arity, argCount);
+static bool call(ObjClosure *closure, int argCount) {
+  if (argCount != closure->function->arity) {
+    runtimeError("Expected %d arguments but got %d", closure->function->arity, argCount);
     return false;
   }
 
@@ -82,16 +82,16 @@ static bool call(ObjFunction *function, int argCount) {
   }
 
   CallFrame *frame = &vm.frames[vm.frameCount++];
-  frame->function = function;
-  frame->ip = function->chunk.code;
+  frame->closure = closure;
+  frame->ip = closure->function->chunk.code;
   frame->slots = vm.stackTop - argCount - 1;
   return true;
 }
 static bool callValue(Value callee, int argCount) {
   if (IS_OBJ(callee)) {
     switch (OBJ_TYPE(callee)) {
-      case OBJ_FUNCTION:
-        return call(AS_FUNCTION(callee), argCount);
+      case OBJ_CLOSURE:
+        return call(AS_CLOSURE(callee), argCount);
       case OBJ_NATIVE: {
         NativeFn fn = AS_NATIVE(callee);
         Value result = fn(argCount, vm.stackTop - argCount);
@@ -124,6 +124,12 @@ static void concatenate() {
   ObjString *result = takeString(chars, length);
   push(OBJ_VAL(result));
 }
+static ObjUpvalue *captureUpvalue(Value *local) {
+  ObjUpvalue *createdUpvalue = newUpvalue(local);
+  return createdUpvalue;
+}
+/// VM run function - exec opcode
+/// \return
 static InterpretResult run() {
   CallFrame *frame = &vm.frames[vm.frameCount - 1];
 
@@ -132,7 +138,9 @@ static InterpretResult run() {
 #define READ_BYTE() (*frame->ip++)
 //! reads the next byte from the bytecode
 //! treats the resulting number as an index
-#define READ_CONSTANT() (frame->function->chunk.constants.values[READ_BYTE()])
+#define READ_CONSTANT() \
+  (frame->closure->function->chunk.constants.values[READ_BYTE()])
+
 #define READ_SHORT() \
   (frame->ip += 2, (uint16_t) ((frame->ip[-2] << 8) | frame->ip[-1]))
 #define READ_STRING() AS_STRING(READ_CONSTANT())
@@ -164,7 +172,7 @@ static InterpretResult run() {
     //! <Stack tracing> end
 
     printf("The Instruction: \n");
-    disassembleInstruction(&frame->function->chunk, (int) (frame->ip - frame->function->chunk.code));
+    disassembleInstruction(&frame->closure->function->chunk, (int) (frame->ip - frame->closure->function->chunk.code));
 #endif
     uint8_t opCode = READ_BYTE();
     switch (opCode) {
@@ -293,6 +301,31 @@ static InterpretResult run() {
         frame = &vm.frames[vm.frameCount - 1];
         break;
       }
+      case OP_CLOSURE: {
+        ObjFunction *function = AS_FUNCTION(READ_CONSTANT());
+        ObjClosure *closure = newClosure(function);
+        push(OBJ_VAL(closure));
+        for (int i = 0; i < closure->upvalueCount; i++) {
+          uint8_t isLocal = READ_BYTE();
+          uint8_t index = READ_BYTE();
+          if (isLocal) {
+            closure->upvalues[i] = captureUpvalue(frame->slots + index);
+          } else {
+            closure->upvalues[i] = frame->closure->upvalues[index];
+          }
+        }
+        break;
+      }
+      case OP_GET_UPVALUE: {
+        uint8_t slot = READ_BYTE();
+        push(*frame->closure->upvalues[slot]->location);
+        break;
+      }
+      case OP_SET_UPVALUE: {
+        uint8_t slot = READ_BYTE();
+        *frame->closure->upvalues[slot]->location = peek(0);
+        break;
+      }
       case OP_RETURN: {
         Value result = pop();
         vm.frameCount--;
@@ -329,12 +362,15 @@ InterpretResult interpret(const char *source) {
   if (function == NULL) return INTERPRET_COMPILE_ERROR;
 
   push(OBJ_VAL(function));
+  ObjClosure *closure = newClosure(function);
+  pop();
+  push(OBJ_VAL(closure));
 
   //  CallFrame *frame = &vm.frames[vm.frameCount++];
   //  frame->function = function;
   //  frame->ip = function->chunk.code;
   //  frame->slots = vm.stack;
-  call(function, 0);
+  call(closure, 0);
 
   return run();
 }

+ 3 - 3
vm.h

@@ -11,20 +11,20 @@
 #ifndef CLOX__VM_H_
 #define CLOX__VM_H_
 
-#include "object.h"
 #include "compiler.h"
+#include "object.h"
 #include "table.h"
 
 #define FRAMES_MAX 64
 #define STACK_MAX (FRAMES_MAX * UINT8_COUNT)
 
 typedef struct {
-  ObjFunction *function;
+  ObjClosure *closure;
   uint8_t *ip;
   Value *slots;
 } CallFrame;
 
-typedef struct {
+typedef struct VM {
   CallFrame frames[FRAMES_MAX];
   int frameCount;