Parcourir la source

CallExpression literal parsing

simon il y a 3 ans
Parent
commit
8a1ab9460d
4 fichiers modifiés avec 133 ajouts et 39 suppressions
  1. 33 3
      ast/ast.go
  2. 62 32
      parser/parser.go
  3. 34 0
      parser/parser_test.go
  4. 4 4
      parser/parser_tracing.go

+ 33 - 3
ast/ast.go

@@ -27,6 +27,7 @@ type Expression interface {
 }
 
 // ---------------------implementation of Node----------------------------------BEGIN-----------------------------------
+
 // Program Root Node of our parser produces.
 type Program struct {
 	Statements []Statement
@@ -94,7 +95,7 @@ func (i *Identifier) TokenLiteral() string {
 // ReturnStatement return <expression>;
 type ReturnStatement struct {
 	Token       token.Token // the token.RETURN
-	returnValue Expression
+	ReturnValue Expression
 }
 
 func (rs *ReturnStatement) TokenLiteral() string {
@@ -104,8 +105,8 @@ func (rs *ReturnStatement) String() string {
 	var out bytes.Buffer
 
 	out.WriteString(rs.TokenLiteral() + " ")
-	if rs.returnValue != nil {
-		out.WriteString(rs.returnValue.String())
+	if rs.ReturnValue != nil {
+		out.WriteString(rs.ReturnValue.String())
 	}
 	out.WriteString(";")
 	return out.String()
@@ -283,3 +284,32 @@ func (fl *FunctionLiteral) String() string {
 }
 
 func (fl *FunctionLiteral) expressionNode() {}
+
+// CallExpression <expression>(<comma separated expressions>);
+type CallExpression struct {
+	Token     token.Token // the '(' token
+	Function  Expression
+	Arguments []Expression
+}
+
+func (ce *CallExpression) TokenLiteral() string {
+	return ce.Token.Literal
+}
+
+func (ce *CallExpression) String() string {
+	var out bytes.Buffer
+
+	args := []string{}
+	for _, a := range ce.Arguments {
+		args = append(args, a.String())
+	}
+
+	out.WriteString(ce.Function.String())
+	out.WriteString("(")
+	out.WriteString(strings.Join(args, ", "))
+	out.WriteString(")")
+
+	return out.String()
+}
+
+func (ce *CallExpression) expressionNode() {}

+ 62 - 32
parser/parser.go

@@ -27,7 +27,7 @@ const (
 	CALL        // myFunction(X)
 )
 
-// 指派 token 类型的优先级
+// precedences 指派 token 类型的优先级
 var precedences = map[token.TypeToken]int{
 	token.EQ:       EQUALS,
 	token.NOT_EQ:   EQUALS,
@@ -37,6 +37,7 @@ var precedences = map[token.TypeToken]int{
 	token.MINUS:    SUM,
 	token.ASTERISK: PRODUCT,
 	token.SLASH:    PRODUCT,
+	token.LPAREN:   CALL,
 }
 
 type (
@@ -81,6 +82,7 @@ func New(l *lexer.Lexer) *Parser {
 	p.registerInfix(token.NOT_EQ, p.parseInfixExpression)
 	p.registerInfix(token.GT, p.parseInfixExpression)
 	p.registerInfix(token.LT, p.parseInfixExpression)
+	p.registerInfix(token.LPAREN, p.parseCallExpression)
 
 	// Read two tokens, so curToken and peekToken are both set
 	p.nextToken()
@@ -89,6 +91,32 @@ func New(l *lexer.Lexer) *Parser {
 	return p
 }
 
+func (p *Parser) ParseProgram() *ast.Program {
+	program := &ast.Program{}
+	program.Statements = []ast.Statement{}
+
+	for !p.curTokenIs(token.EOF) {
+		stmt := p.parseStatement()
+		if stmt != nil {
+			program.Statements = append(program.Statements, stmt)
+		}
+		p.nextToken()
+
+	}
+	return program
+}
+
+func (p *Parser) parseStatement() ast.Statement {
+	switch p.curToken.Type {
+	case token.LET:
+		return p.parseLetStatement()
+	case token.RETURN:
+		return p.parseReturnStatement()
+	default:
+		return p.parseExpressionStatement()
+	}
+}
+
 func (p *Parser) parseIdentifier() ast.Expression {
 	return &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal}
 }
@@ -173,6 +201,13 @@ func (p *Parser) parseIntegerLiteral() ast.Expression {
 
 	return lit
 }
+func (p *Parser) parseCallExpression(left ast.Expression) ast.Expression {
+	defer untrace(trace("parseCallExpression"))
+	// add(2,3) --> Function: add
+	exp := &ast.CallExpression{Token: p.curToken, Function: left}
+	exp.Arguments = p.parseCallArguments()
+	return exp
+}
 func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression {
 	defer untrace(trace("parseInfixExpression"))
 	exp := &ast.InfixExpression{
@@ -233,32 +268,6 @@ func (p *Parser) nextToken() {
 	p.peekToken = p.l.NextToken()
 }
 
-func (p *Parser) ParseProgram() *ast.Program {
-	program := &ast.Program{}
-	program.Statements = []ast.Statement{}
-
-	for !p.curTokenIs(token.EOF) {
-		stmt := p.parseStatement()
-		if stmt != nil {
-			program.Statements = append(program.Statements, stmt)
-		}
-		p.nextToken()
-
-	}
-	return program
-}
-
-func (p *Parser) parseStatement() ast.Statement {
-	switch p.curToken.Type {
-	case token.LET:
-		return p.parseLetStatement()
-	case token.RETURN:
-		return p.parseReturnStatement()
-	default:
-		return p.parseExpressionStatement()
-	}
-}
-
 // let <identifier> = <expression>;
 func (p *Parser) parseLetStatement() *ast.LetStatement {
 	stmt := &ast.LetStatement{Token: p.curToken}
@@ -275,9 +284,8 @@ func (p *Parser) parseLetStatement() *ast.LetStatement {
 		return nil
 	}
 
-	// TODO: we're skipping the expression until we
-	//       we encounter a semicolon
-
+	p.nextToken()
+	stmt.Value = p.parseExpression(LOWEST)
 	// ;
 	for !p.curTokenIs(token.SEMICOLON) {
 		p.nextToken()
@@ -290,9 +298,8 @@ func (p *Parser) parseReturnStatement() *ast.ReturnStatement {
 	stmt := &ast.ReturnStatement{Token: p.curToken}
 
 	p.nextToken()
+	stmt.ReturnValue = p.parseExpression(LOWEST)
 
-	// TODO: we're skipping the expressions until we
-	//       encounter a semicolon
 	for !p.curTokenIs(token.SEMICOLON) {
 		p.nextToken()
 	}
@@ -405,3 +412,26 @@ func (p *Parser) ParseFunctionParameters() []*ast.Identifier {
 
 	return identifiers
 }
+
+func (p *Parser) parseCallArguments() []ast.Expression {
+	args := []ast.Expression{}
+
+	// the case (), then null args
+	if p.peekTokenIs(token.RPAREN) {
+		p.nextToken()
+		return args
+	}
+	p.nextToken()
+	args = append(args, p.parseExpression(LOWEST))
+
+	for p.peekTokenIs(token.COMMA) {
+		p.nextToken()
+		p.nextToken()
+		args = append(args, p.parseExpression(LOWEST))
+	}
+
+	if !p.expectPeek(token.RPAREN) {
+		return nil
+	}
+	return args
+}

+ 34 - 0
parser/parser_test.go

@@ -659,3 +659,37 @@ func TestFunctionParameterParsing(t *testing.T) {
 		}
 	}
 }
+
+func TestCallExpressionParsing(t *testing.T) {
+	input := `add(1, 2 * 3, 4+5);`
+
+	l := lexer.New(input)
+	p := New(l)
+	program := p.ParseProgram()
+	checkParseErrors(t, p)
+
+	if len(program.Statements) != 1 {
+		t.Fatalf("program.Statement does not contain %d statements. got=%d\n", 1, len(program.Statements))
+	}
+
+	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
+	if !ok {
+		t.Fatalf("stmt is not ast.ExpressionStatement. got=%T", program.Statements[0])
+	}
+
+	exp, ok := stmt.Expression.(*ast.CallExpression)
+	if !ok {
+		t.Fatalf("stmt.Expression is not ast.CallExpression. got=%T", stmt.Expression)
+	}
+
+	if !testIdentifier(t, exp.Function, "add") {
+		return
+	}
+	if len(exp.Arguments) != 3 {
+		t.Fatalf("wrong length arguments. got=%d", len(exp.Arguments))
+	}
+
+	testLiteralExpression(t, exp.Arguments[0], 1)
+	testInfixExpression(t, exp.Arguments[1], 2, "*", 3)
+	testInfixExpression(t, exp.Arguments[2], 4, "+", 5)
+}

+ 4 - 4
parser/parser_tracing.go

@@ -21,12 +21,12 @@ func incIdent() { traceLevel = traceLevel + 1 }
 func decIdent() { traceLevel = traceLevel - 1 }
 
 func trace(msg string) string {
-	incIdent()
-	tracePrint("BEGIN " + msg)
+	//incIdent()
+	//tracePrint("BEGIN " + msg)
 	return msg
 }
 
 func untrace(msg string) {
-	tracePrint("END " + msg)
-	decIdent()
+	//tracePrint("END " + msg)
+	//decIdent()
 }