2020-01-09 21:52:19 +01:00

1313 lines
32 KiB
Go

package tengo
import (
"fmt"
"io"
"io/ioutil"
"path/filepath"
"reflect"
"strings"
"github.com/d5/tengo/v2/parser"
"github.com/d5/tengo/v2/token"
)
// compilationScope represents a compiled instructions and the last two
// instructions that were emitted.
type compilationScope struct {
Instructions []byte
SymbolInit map[string]bool
SourceMap map[int]parser.Pos
}
// loop represents a loop construct that the compiler uses to track the current
// loop.
type loop struct {
Continues []int
Breaks []int
}
// CompilerError represents a compiler error.
type CompilerError struct {
FileSet *parser.SourceFileSet
Node parser.Node
Err error
}
func (e *CompilerError) Error() string {
filePos := e.FileSet.Position(e.Node.Pos())
return fmt.Sprintf("Compile Error: %s\n\tat %s", e.Err.Error(), filePos)
}
// Compiler compiles the AST into a bytecode.
type Compiler struct {
file *parser.SourceFile
parent *Compiler
modulePath string
constants []Object
symbolTable *SymbolTable
scopes []compilationScope
scopeIndex int
modules *ModuleMap
compiledModules map[string]*CompiledFunction
allowFileImport bool
loops []*loop
loopIndex int
trace io.Writer
indent int
}
// NewCompiler creates a Compiler.
func NewCompiler(
file *parser.SourceFile,
symbolTable *SymbolTable,
constants []Object,
modules *ModuleMap,
trace io.Writer,
) *Compiler {
mainScope := compilationScope{
SymbolInit: make(map[string]bool),
SourceMap: make(map[int]parser.Pos),
}
// symbol table
if symbolTable == nil {
symbolTable = NewSymbolTable()
}
// add builtin functions to the symbol table
for idx, fn := range builtinFuncs {
symbolTable.DefineBuiltin(idx, fn.Name)
}
// builtin modules
if modules == nil {
modules = NewModuleMap()
}
return &Compiler{
file: file,
symbolTable: symbolTable,
constants: constants,
scopes: []compilationScope{mainScope},
scopeIndex: 0,
loopIndex: -1,
trace: trace,
modules: modules,
compiledModules: make(map[string]*CompiledFunction),
}
}
// Compile compiles the AST node.
func (c *Compiler) Compile(node parser.Node) error {
if c.trace != nil {
if node != nil {
defer untracec(tracec(c, fmt.Sprintf("%s (%s)",
node.String(), reflect.TypeOf(node).Elem().Name())))
} else {
defer untracec(tracec(c, "<nil>"))
}
}
switch node := node.(type) {
case *parser.File:
for _, stmt := range node.Stmts {
if err := c.Compile(stmt); err != nil {
return err
}
}
case *parser.ExprStmt:
if err := c.Compile(node.Expr); err != nil {
return err
}
c.emit(node, parser.OpPop)
case *parser.IncDecStmt:
op := token.AddAssign
if node.Token == token.Dec {
op = token.SubAssign
}
return c.compileAssign(node, []parser.Expr{node.Expr},
[]parser.Expr{&parser.IntLit{Value: 1}}, op)
case *parser.ParenExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
case *parser.BinaryExpr:
if node.Token == token.LAnd || node.Token == token.LOr {
return c.compileLogical(node)
}
if node.Token == token.Less {
if err := c.Compile(node.RHS); err != nil {
return err
}
if err := c.Compile(node.LHS); err != nil {
return err
}
c.emit(node, parser.OpBinaryOp, int(token.Greater))
return nil
} else if node.Token == token.LessEq {
if err := c.Compile(node.RHS); err != nil {
return err
}
if err := c.Compile(node.LHS); err != nil {
return err
}
c.emit(node, parser.OpBinaryOp, int(token.GreaterEq))
return nil
}
if err := c.Compile(node.LHS); err != nil {
return err
}
if err := c.Compile(node.RHS); err != nil {
return err
}
switch node.Token {
case token.Add:
c.emit(node, parser.OpBinaryOp, int(token.Add))
case token.Sub:
c.emit(node, parser.OpBinaryOp, int(token.Sub))
case token.Mul:
c.emit(node, parser.OpBinaryOp, int(token.Mul))
case token.Quo:
c.emit(node, parser.OpBinaryOp, int(token.Quo))
case token.Rem:
c.emit(node, parser.OpBinaryOp, int(token.Rem))
case token.Greater:
c.emit(node, parser.OpBinaryOp, int(token.Greater))
case token.GreaterEq:
c.emit(node, parser.OpBinaryOp, int(token.GreaterEq))
case token.Equal:
c.emit(node, parser.OpEqual)
case token.NotEqual:
c.emit(node, parser.OpNotEqual)
case token.And:
c.emit(node, parser.OpBinaryOp, int(token.And))
case token.Or:
c.emit(node, parser.OpBinaryOp, int(token.Or))
case token.Xor:
c.emit(node, parser.OpBinaryOp, int(token.Xor))
case token.AndNot:
c.emit(node, parser.OpBinaryOp, int(token.AndNot))
case token.Shl:
c.emit(node, parser.OpBinaryOp, int(token.Shl))
case token.Shr:
c.emit(node, parser.OpBinaryOp, int(token.Shr))
default:
return c.errorf(node, "invalid binary operator: %s",
node.Token.String())
}
case *parser.IntLit:
c.emit(node, parser.OpConstant,
c.addConstant(&Int{Value: node.Value}))
case *parser.FloatLit:
c.emit(node, parser.OpConstant,
c.addConstant(&Float{Value: node.Value}))
case *parser.BoolLit:
if node.Value {
c.emit(node, parser.OpTrue)
} else {
c.emit(node, parser.OpFalse)
}
case *parser.StringLit:
if len(node.Value) > MaxStringLen {
return c.error(node, ErrStringLimit)
}
c.emit(node, parser.OpConstant,
c.addConstant(&String{Value: node.Value}))
case *parser.CharLit:
c.emit(node, parser.OpConstant,
c.addConstant(&Char{Value: node.Value}))
case *parser.UndefinedLit:
c.emit(node, parser.OpNull)
case *parser.UnaryExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
switch node.Token {
case token.Not:
c.emit(node, parser.OpLNot)
case token.Sub:
c.emit(node, parser.OpMinus)
case token.Xor:
c.emit(node, parser.OpBComplement)
case token.Add:
// do nothing?
default:
return c.errorf(node,
"invalid unary operator: %s", node.Token.String())
}
case *parser.IfStmt:
// open new symbol table for the statement
c.symbolTable = c.symbolTable.Fork(true)
defer func() {
c.symbolTable = c.symbolTable.Parent(false)
}()
if node.Init != nil {
if err := c.Compile(node.Init); err != nil {
return err
}
}
if err := c.Compile(node.Cond); err != nil {
return err
}
// first jump placeholder
jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0)
if err := c.Compile(node.Body); err != nil {
return err
}
if node.Else != nil {
// second jump placeholder
jumpPos2 := c.emit(node, parser.OpJump, 0)
// update first jump offset
curPos := len(c.currentInstructions())
c.changeOperand(jumpPos1, curPos)
if err := c.Compile(node.Else); err != nil {
return err
}
// update second jump offset
curPos = len(c.currentInstructions())
c.changeOperand(jumpPos2, curPos)
} else {
// update first jump offset
curPos := len(c.currentInstructions())
c.changeOperand(jumpPos1, curPos)
}
case *parser.ForStmt:
return c.compileForStmt(node)
case *parser.ForInStmt:
return c.compileForInStmt(node)
case *parser.BranchStmt:
if node.Token == token.Break {
curLoop := c.currentLoop()
if curLoop == nil {
return c.errorf(node, "break not allowed outside loop")
}
pos := c.emit(node, parser.OpJump, 0)
curLoop.Breaks = append(curLoop.Breaks, pos)
} else if node.Token == token.Continue {
curLoop := c.currentLoop()
if curLoop == nil {
return c.errorf(node, "continue not allowed outside loop")
}
pos := c.emit(node, parser.OpJump, 0)
curLoop.Continues = append(curLoop.Continues, pos)
} else {
panic(fmt.Errorf("invalid branch statement: %s",
node.Token.String()))
}
case *parser.BlockStmt:
if len(node.Stmts) == 0 {
return nil
}
c.symbolTable = c.symbolTable.Fork(true)
defer func() {
c.symbolTable = c.symbolTable.Parent(false)
}()
for _, stmt := range node.Stmts {
if err := c.Compile(stmt); err != nil {
return err
}
}
case *parser.AssignStmt:
err := c.compileAssign(node, node.LHS, node.RHS, node.Token)
if err != nil {
return err
}
case *parser.Ident:
symbol, _, ok := c.symbolTable.Resolve(node.Name)
if !ok {
return c.errorf(node, "unresolved reference '%s'", node.Name)
}
switch symbol.Scope {
case ScopeGlobal:
c.emit(node, parser.OpGetGlobal, symbol.Index)
case ScopeLocal:
c.emit(node, parser.OpGetLocal, symbol.Index)
case ScopeBuiltin:
c.emit(node, parser.OpGetBuiltin, symbol.Index)
case ScopeFree:
c.emit(node, parser.OpGetFree, symbol.Index)
}
case *parser.ArrayLit:
for _, elem := range node.Elements {
if err := c.Compile(elem); err != nil {
return err
}
}
c.emit(node, parser.OpArray, len(node.Elements))
case *parser.MapLit:
for _, elt := range node.Elements {
// key
if len(elt.Key) > MaxStringLen {
return c.error(node, ErrStringLimit)
}
c.emit(node, parser.OpConstant,
c.addConstant(&String{Value: elt.Key}))
// value
if err := c.Compile(elt.Value); err != nil {
return err
}
}
c.emit(node, parser.OpMap, len(node.Elements)*2)
case *parser.SelectorExpr: // selector on RHS side
if err := c.Compile(node.Expr); err != nil {
return err
}
if err := c.Compile(node.Sel); err != nil {
return err
}
c.emit(node, parser.OpIndex)
case *parser.IndexExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
if err := c.Compile(node.Index); err != nil {
return err
}
c.emit(node, parser.OpIndex)
case *parser.SliceExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
if node.Low != nil {
if err := c.Compile(node.Low); err != nil {
return err
}
} else {
c.emit(node, parser.OpNull)
}
if node.High != nil {
if err := c.Compile(node.High); err != nil {
return err
}
} else {
c.emit(node, parser.OpNull)
}
c.emit(node, parser.OpSliceIndex)
case *parser.FuncLit:
c.enterScope()
for _, p := range node.Type.Params.List {
s := c.symbolTable.Define(p.Name)
// function arguments is not assigned directly.
s.LocalAssigned = true
}
if err := c.Compile(node.Body); err != nil {
return err
}
// code optimization
c.optimizeFunc(node)
freeSymbols := c.symbolTable.FreeSymbols()
numLocals := c.symbolTable.MaxSymbols()
instructions, sourceMap := c.leaveScope()
for _, s := range freeSymbols {
switch s.Scope {
case ScopeLocal:
if !s.LocalAssigned {
// Here, the closure is capturing a local variable that's
// not yet assigned its value. One example is a local
// recursive function:
//
// func() {
// foo := func(x) {
// // ..
// return foo(x-1)
// }
// }
//
// which translate into
//
// 0000 GETL 0
// 0002 CLOSURE ? 1
// 0006 DEFL 0
//
// . So the local variable (0) is being captured before
// it's assigned the value.
//
// Solution is to transform the code into something like
// this:
//
// func() {
// foo := undefined
// foo = func(x) {
// // ..
// return foo(x-1)
// }
// }
//
// that is equivalent to
//
// 0000 NULL
// 0001 DEFL 0
// 0003 GETL 0
// 0005 CLOSURE ? 1
// 0009 SETL 0
//
c.emit(node, parser.OpNull)
c.emit(node, parser.OpDefineLocal, s.Index)
s.LocalAssigned = true
}
c.emit(node, parser.OpGetLocalPtr, s.Index)
case ScopeFree:
c.emit(node, parser.OpGetFreePtr, s.Index)
}
}
compiledFunction := &CompiledFunction{
Instructions: instructions,
NumLocals: numLocals,
NumParameters: len(node.Type.Params.List),
VarArgs: node.Type.Params.VarArgs,
SourceMap: sourceMap,
}
if len(freeSymbols) > 0 {
c.emit(node, parser.OpClosure,
c.addConstant(compiledFunction), len(freeSymbols))
} else {
c.emit(node, parser.OpConstant, c.addConstant(compiledFunction))
}
case *parser.ReturnStmt:
if c.symbolTable.Parent(true) == nil {
// outside the function
return c.errorf(node, "return not allowed outside function")
}
if node.Result == nil {
c.emit(node, parser.OpReturn, 0)
} else {
if err := c.Compile(node.Result); err != nil {
return err
}
c.emit(node, parser.OpReturn, 1)
}
case *parser.CallExpr:
if err := c.Compile(node.Func); err != nil {
return err
}
for _, arg := range node.Args {
if err := c.Compile(arg); err != nil {
return err
}
}
c.emit(node, parser.OpCall, len(node.Args))
case *parser.ImportExpr:
if node.ModuleName == "" {
return c.errorf(node, "empty module name")
}
if mod := c.modules.Get(node.ModuleName); mod != nil {
v, err := mod.Import(node.ModuleName)
if err != nil {
return err
}
switch v := v.(type) {
case []byte: // module written in Tengo
compiled, err := c.compileModule(node,
node.ModuleName, node.ModuleName, v)
if err != nil {
return err
}
c.emit(node, parser.OpConstant, c.addConstant(compiled))
c.emit(node, parser.OpCall, 0)
case Object: // builtin module
c.emit(node, parser.OpConstant, c.addConstant(v))
default:
panic(fmt.Errorf("invalid import value type: %T", v))
}
} else if c.allowFileImport {
moduleName := node.ModuleName
if !strings.HasSuffix(moduleName, ".tengo") {
moduleName += ".tengo"
}
modulePath, err := filepath.Abs(moduleName)
if err != nil {
return c.errorf(node, "module file path error: %s",
err.Error())
}
if err := c.checkCyclicImports(node, modulePath); err != nil {
return err
}
moduleSrc, err := ioutil.ReadFile(moduleName)
if err != nil {
return c.errorf(node, "module file read error: %s",
err.Error())
}
compiled, err := c.compileModule(node,
moduleName, modulePath, moduleSrc)
if err != nil {
return err
}
c.emit(node, parser.OpConstant, c.addConstant(compiled))
c.emit(node, parser.OpCall, 0)
} else {
return c.errorf(node, "module '%s' not found", node.ModuleName)
}
case *parser.ExportStmt:
// export statement must be in top-level scope
if c.scopeIndex != 0 {
return c.errorf(node, "export not allowed inside function")
}
// export statement is simply ignore when compiling non-module code
if c.parent == nil {
break
}
if err := c.Compile(node.Result); err != nil {
return err
}
c.emit(node, parser.OpImmutable)
c.emit(node, parser.OpReturn, 1)
case *parser.ErrorExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
c.emit(node, parser.OpError)
case *parser.ImmutableExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
c.emit(node, parser.OpImmutable)
case *parser.CondExpr:
if err := c.Compile(node.Cond); err != nil {
return err
}
// first jump placeholder
jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0)
if err := c.Compile(node.True); err != nil {
return err
}
// second jump placeholder
jumpPos2 := c.emit(node, parser.OpJump, 0)
// update first jump offset
curPos := len(c.currentInstructions())
c.changeOperand(jumpPos1, curPos)
if err := c.Compile(node.False); err != nil {
return err
}
// update second jump offset
curPos = len(c.currentInstructions())
c.changeOperand(jumpPos2, curPos)
}
return nil
}
// Bytecode returns a compiled bytecode.
func (c *Compiler) Bytecode() *Bytecode {
return &Bytecode{
FileSet: c.file.Set(),
MainFunction: &CompiledFunction{
Instructions: append(c.currentInstructions(), parser.OpSuspend),
SourceMap: c.currentSourceMap(),
},
Constants: c.constants,
}
}
// EnableFileImport enables or disables module loading from local files.
// Local file modules are disabled by default.
func (c *Compiler) EnableFileImport(enable bool) {
c.allowFileImport = enable
}
func (c *Compiler) compileAssign(
node parser.Node,
lhs, rhs []parser.Expr,
op token.Token,
) error {
numLHS, numRHS := len(lhs), len(rhs)
if numLHS > 1 || numRHS > 1 {
return c.errorf(node, "tuple assignment not allowed")
}
// resolve and compile left-hand side
ident, selectors := resolveAssignLHS(lhs[0])
numSel := len(selectors)
if op == token.Define && numSel > 0 {
// using selector on new variable does not make sense
return c.errorf(node, "operator ':=' not allowed with selector")
}
symbol, depth, exists := c.symbolTable.Resolve(ident)
if op == token.Define {
if depth == 0 && exists {
return c.errorf(node, "'%s' redeclared in this block", ident)
}
symbol = c.symbolTable.Define(ident)
} else {
if !exists {
return c.errorf(node, "unresolved reference '%s'", ident)
}
}
// +=, -=, *=, /=
if op != token.Assign && op != token.Define {
if err := c.Compile(lhs[0]); err != nil {
return err
}
}
// compile RHSs
for _, expr := range rhs {
if err := c.Compile(expr); err != nil {
return err
}
}
switch op {
case token.AddAssign:
c.emit(node, parser.OpBinaryOp, int(token.Add))
case token.SubAssign:
c.emit(node, parser.OpBinaryOp, int(token.Sub))
case token.MulAssign:
c.emit(node, parser.OpBinaryOp, int(token.Mul))
case token.QuoAssign:
c.emit(node, parser.OpBinaryOp, int(token.Quo))
case token.RemAssign:
c.emit(node, parser.OpBinaryOp, int(token.Rem))
case token.AndAssign:
c.emit(node, parser.OpBinaryOp, int(token.And))
case token.OrAssign:
c.emit(node, parser.OpBinaryOp, int(token.Or))
case token.AndNotAssign:
c.emit(node, parser.OpBinaryOp, int(token.AndNot))
case token.XorAssign:
c.emit(node, parser.OpBinaryOp, int(token.Xor))
case token.ShlAssign:
c.emit(node, parser.OpBinaryOp, int(token.Shl))
case token.ShrAssign:
c.emit(node, parser.OpBinaryOp, int(token.Shr))
}
// compile selector expressions (right to left)
for i := numSel - 1; i >= 0; i-- {
if err := c.Compile(selectors[i]); err != nil {
return err
}
}
switch symbol.Scope {
case ScopeGlobal:
if numSel > 0 {
c.emit(node, parser.OpSetSelGlobal, symbol.Index, numSel)
} else {
c.emit(node, parser.OpSetGlobal, symbol.Index)
}
case ScopeLocal:
if numSel > 0 {
c.emit(node, parser.OpSetSelLocal, symbol.Index, numSel)
} else {
if op == token.Define && !symbol.LocalAssigned {
c.emit(node, parser.OpDefineLocal, symbol.Index)
} else {
c.emit(node, parser.OpSetLocal, symbol.Index)
}
}
// mark the symbol as local-assigned
symbol.LocalAssigned = true
case ScopeFree:
if numSel > 0 {
c.emit(node, parser.OpSetSelFree, symbol.Index, numSel)
} else {
c.emit(node, parser.OpSetFree, symbol.Index)
}
default:
panic(fmt.Errorf("invalid assignment variable scope: %s",
symbol.Scope))
}
return nil
}
func (c *Compiler) compileLogical(node *parser.BinaryExpr) error {
// left side term
if err := c.Compile(node.LHS); err != nil {
return err
}
// jump position
var jumpPos int
if node.Token == token.LAnd {
jumpPos = c.emit(node, parser.OpAndJump, 0)
} else {
jumpPos = c.emit(node, parser.OpOrJump, 0)
}
// right side term
if err := c.Compile(node.RHS); err != nil {
return err
}
c.changeOperand(jumpPos, len(c.currentInstructions()))
return nil
}
func (c *Compiler) compileForStmt(stmt *parser.ForStmt) error {
c.symbolTable = c.symbolTable.Fork(true)
defer func() {
c.symbolTable = c.symbolTable.Parent(false)
}()
// init statement
if stmt.Init != nil {
if err := c.Compile(stmt.Init); err != nil {
return err
}
}
// pre-condition position
preCondPos := len(c.currentInstructions())
// condition expression
postCondPos := -1
if stmt.Cond != nil {
if err := c.Compile(stmt.Cond); err != nil {
return err
}
// condition jump position
postCondPos = c.emit(stmt, parser.OpJumpFalsy, 0)
}
// enter loop
loop := c.enterLoop()
// body statement
if err := c.Compile(stmt.Body); err != nil {
c.leaveLoop()
return err
}
c.leaveLoop()
// post-body position
postBodyPos := len(c.currentInstructions())
// post statement
if stmt.Post != nil {
if err := c.Compile(stmt.Post); err != nil {
return err
}
}
// back to condition
c.emit(stmt, parser.OpJump, preCondPos)
// post-statement position
postStmtPos := len(c.currentInstructions())
if postCondPos >= 0 {
c.changeOperand(postCondPos, postStmtPos)
}
// update all break/continue jump positions
for _, pos := range loop.Breaks {
c.changeOperand(pos, postStmtPos)
}
for _, pos := range loop.Continues {
c.changeOperand(pos, postBodyPos)
}
return nil
}
func (c *Compiler) compileForInStmt(stmt *parser.ForInStmt) error {
c.symbolTable = c.symbolTable.Fork(true)
defer func() {
c.symbolTable = c.symbolTable.Parent(false)
}()
// for-in statement is compiled like following:
//
// for :it := iterator(iterable); :it.next(); {
// k, v := :it.get() // DEFINE operator
//
// ... body ...
// }
//
// ":it" is a local variable but will be conflict with other user variables
// because character ":" is not allowed.
// init
// :it = iterator(iterable)
itSymbol := c.symbolTable.Define(":it")
if err := c.Compile(stmt.Iterable); err != nil {
return err
}
c.emit(stmt, parser.OpIteratorInit)
if itSymbol.Scope == ScopeGlobal {
c.emit(stmt, parser.OpSetGlobal, itSymbol.Index)
} else {
c.emit(stmt, parser.OpDefineLocal, itSymbol.Index)
}
// pre-condition position
preCondPos := len(c.currentInstructions())
// condition
// :it.HasMore()
if itSymbol.Scope == ScopeGlobal {
c.emit(stmt, parser.OpGetGlobal, itSymbol.Index)
} else {
c.emit(stmt, parser.OpGetLocal, itSymbol.Index)
}
c.emit(stmt, parser.OpIteratorNext)
// condition jump position
postCondPos := c.emit(stmt, parser.OpJumpFalsy, 0)
// enter loop
loop := c.enterLoop()
// assign key variable
if stmt.Key.Name != "_" {
keySymbol := c.symbolTable.Define(stmt.Key.Name)
if itSymbol.Scope == ScopeGlobal {
c.emit(stmt, parser.OpGetGlobal, itSymbol.Index)
} else {
c.emit(stmt, parser.OpGetLocal, itSymbol.Index)
}
c.emit(stmt, parser.OpIteratorKey)
if keySymbol.Scope == ScopeGlobal {
c.emit(stmt, parser.OpSetGlobal, keySymbol.Index)
} else {
c.emit(stmt, parser.OpDefineLocal, keySymbol.Index)
}
}
// assign value variable
if stmt.Value.Name != "_" {
valueSymbol := c.symbolTable.Define(stmt.Value.Name)
if itSymbol.Scope == ScopeGlobal {
c.emit(stmt, parser.OpGetGlobal, itSymbol.Index)
} else {
c.emit(stmt, parser.OpGetLocal, itSymbol.Index)
}
c.emit(stmt, parser.OpIteratorValue)
if valueSymbol.Scope == ScopeGlobal {
c.emit(stmt, parser.OpSetGlobal, valueSymbol.Index)
} else {
c.emit(stmt, parser.OpDefineLocal, valueSymbol.Index)
}
}
// body statement
if err := c.Compile(stmt.Body); err != nil {
c.leaveLoop()
return err
}
c.leaveLoop()
// post-body position
postBodyPos := len(c.currentInstructions())
// back to condition
c.emit(stmt, parser.OpJump, preCondPos)
// post-statement position
postStmtPos := len(c.currentInstructions())
c.changeOperand(postCondPos, postStmtPos)
// update all break/continue jump positions
for _, pos := range loop.Breaks {
c.changeOperand(pos, postStmtPos)
}
for _, pos := range loop.Continues {
c.changeOperand(pos, postBodyPos)
}
return nil
}
func (c *Compiler) checkCyclicImports(
node parser.Node,
modulePath string,
) error {
if c.modulePath == modulePath {
return c.errorf(node, "cyclic module import: %s", modulePath)
} else if c.parent != nil {
return c.parent.checkCyclicImports(node, modulePath)
}
return nil
}
func (c *Compiler) compileModule(
node parser.Node,
moduleName, modulePath string,
src []byte,
) (*CompiledFunction, error) {
if err := c.checkCyclicImports(node, modulePath); err != nil {
return nil, err
}
compiledModule, exists := c.loadCompiledModule(modulePath)
if exists {
return compiledModule, nil
}
modFile := c.file.Set().AddFile(moduleName, -1, len(src))
p := parser.NewParser(modFile, src, nil)
file, err := p.ParseFile()
if err != nil {
return nil, err
}
// inherit builtin functions
symbolTable := NewSymbolTable()
for _, sym := range c.symbolTable.BuiltinSymbols() {
symbolTable.DefineBuiltin(sym.Index, sym.Name)
}
// no global scope for the module
symbolTable = symbolTable.Fork(false)
// compile module
moduleCompiler := c.fork(modFile, modulePath, symbolTable)
if err := moduleCompiler.Compile(file); err != nil {
return nil, err
}
// code optimization
moduleCompiler.optimizeFunc(node)
compiledFunc := moduleCompiler.Bytecode().MainFunction
compiledFunc.NumLocals = symbolTable.MaxSymbols()
c.storeCompiledModule(modulePath, compiledFunc)
return compiledFunc, nil
}
func (c *Compiler) loadCompiledModule(
modulePath string,
) (mod *CompiledFunction, ok bool) {
if c.parent != nil {
return c.parent.loadCompiledModule(modulePath)
}
mod, ok = c.compiledModules[modulePath]
return
}
func (c *Compiler) storeCompiledModule(
modulePath string,
module *CompiledFunction,
) {
if c.parent != nil {
c.parent.storeCompiledModule(modulePath, module)
}
c.compiledModules[modulePath] = module
}
func (c *Compiler) enterLoop() *loop {
loop := &loop{}
c.loops = append(c.loops, loop)
c.loopIndex++
if c.trace != nil {
c.printTrace("LOOPE", c.loopIndex)
}
return loop
}
func (c *Compiler) leaveLoop() {
if c.trace != nil {
c.printTrace("LOOPL", c.loopIndex)
}
c.loops = c.loops[:len(c.loops)-1]
c.loopIndex--
}
func (c *Compiler) currentLoop() *loop {
if c.loopIndex >= 0 {
return c.loops[c.loopIndex]
}
return nil
}
func (c *Compiler) currentInstructions() []byte {
return c.scopes[c.scopeIndex].Instructions
}
func (c *Compiler) currentSourceMap() map[int]parser.Pos {
return c.scopes[c.scopeIndex].SourceMap
}
func (c *Compiler) enterScope() {
scope := compilationScope{
SymbolInit: make(map[string]bool),
SourceMap: make(map[int]parser.Pos),
}
c.scopes = append(c.scopes, scope)
c.scopeIndex++
c.symbolTable = c.symbolTable.Fork(false)
if c.trace != nil {
c.printTrace("SCOPE", c.scopeIndex)
}
}
func (c *Compiler) leaveScope() (
instructions []byte,
sourceMap map[int]parser.Pos,
) {
instructions = c.currentInstructions()
sourceMap = c.currentSourceMap()
c.scopes = c.scopes[:len(c.scopes)-1]
c.scopeIndex--
c.symbolTable = c.symbolTable.Parent(true)
if c.trace != nil {
c.printTrace("SCOPL", c.scopeIndex)
}
return
}
func (c *Compiler) fork(
file *parser.SourceFile,
modulePath string,
symbolTable *SymbolTable,
) *Compiler {
child := NewCompiler(file, symbolTable, nil, c.modules, c.trace)
child.modulePath = modulePath // module file path
child.parent = c // parent to set to current compiler
return child
}
func (c *Compiler) error(node parser.Node, err error) error {
return &CompilerError{
FileSet: c.file.Set(),
Node: node,
Err: err,
}
}
func (c *Compiler) errorf(
node parser.Node,
format string,
args ...interface{},
) error {
return &CompilerError{
FileSet: c.file.Set(),
Node: node,
Err: fmt.Errorf(format, args...),
}
}
func (c *Compiler) addConstant(o Object) int {
if c.parent != nil {
// module compilers will use their parent's constants array
return c.parent.addConstant(o)
}
c.constants = append(c.constants, o)
if c.trace != nil {
c.printTrace(fmt.Sprintf("CONST %04d %s", len(c.constants)-1, o))
}
return len(c.constants) - 1
}
func (c *Compiler) addInstruction(b []byte) int {
posNewIns := len(c.currentInstructions())
c.scopes[c.scopeIndex].Instructions = append(
c.currentInstructions(), b...)
return posNewIns
}
func (c *Compiler) replaceInstruction(pos int, inst []byte) {
copy(c.currentInstructions()[pos:], inst)
if c.trace != nil {
c.printTrace(fmt.Sprintf("REPLC %s",
FormatInstructions(
c.scopes[c.scopeIndex].Instructions[pos:], pos)[0]))
}
}
func (c *Compiler) changeOperand(opPos int, operand ...int) {
op := c.currentInstructions()[opPos]
inst := MakeInstruction(op, operand...)
c.replaceInstruction(opPos, inst)
}
// optimizeFunc performs some code-level optimization for the current function
// instructions. It also removes unreachable (dead code) instructions and adds
// "returns" instruction if needed.
func (c *Compiler) optimizeFunc(node parser.Node) {
// any instructions between RETURN and the function end
// or instructions between RETURN and jump target position
// are considered as unreachable.
// pass 1. identify all jump destinations
dsts := make(map[int]bool)
iterateInstructions(c.scopes[c.scopeIndex].Instructions,
func(pos int, opcode parser.Opcode, operands []int) bool {
switch opcode {
case parser.OpJump, parser.OpJumpFalsy,
parser.OpAndJump, parser.OpOrJump:
dsts[operands[0]] = true
}
return true
})
// pass 2. eliminate dead code
var newInsts []byte
posMap := make(map[int]int) // old position to new position
var dstIdx int
var deadCode bool
iterateInstructions(c.scopes[c.scopeIndex].Instructions,
func(pos int, opcode parser.Opcode, operands []int) bool {
switch {
case opcode == parser.OpReturn:
if deadCode {
return true
}
deadCode = true
case dsts[pos]:
dstIdx++
deadCode = false
case deadCode:
return true
}
posMap[pos] = len(newInsts)
newInsts = append(newInsts,
MakeInstruction(opcode, operands...)...)
return true
})
// pass 3. update jump positions
var lastOp parser.Opcode
var appendReturn bool
endPos := len(c.scopes[c.scopeIndex].Instructions)
iterateInstructions(newInsts,
func(pos int, opcode parser.Opcode, operands []int) bool {
switch opcode {
case parser.OpJump, parser.OpJumpFalsy, parser.OpAndJump,
parser.OpOrJump:
newDst, ok := posMap[operands[0]]
if ok {
copy(newInsts[pos:],
MakeInstruction(opcode, newDst))
} else if endPos == operands[0] {
// there's a jump instruction that jumps to the end of
// function compiler should append "return".
appendReturn = true
} else {
panic(fmt.Errorf("invalid jump position: %d", newDst))
}
}
lastOp = opcode
return true
})
if lastOp != parser.OpReturn {
appendReturn = true
}
// pass 4. update source map
newSourceMap := make(map[int]parser.Pos)
for pos, srcPos := range c.scopes[c.scopeIndex].SourceMap {
newPos, ok := posMap[pos]
if ok {
newSourceMap[newPos] = srcPos
}
}
c.scopes[c.scopeIndex].Instructions = newInsts
c.scopes[c.scopeIndex].SourceMap = newSourceMap
// append "return"
if appendReturn {
c.emit(node, parser.OpReturn, 0)
}
}
func (c *Compiler) emit(
node parser.Node,
opcode parser.Opcode,
operands ...int,
) int {
filePos := parser.NoPos
if node != nil {
filePos = node.Pos()
}
inst := MakeInstruction(opcode, operands...)
pos := c.addInstruction(inst)
c.scopes[c.scopeIndex].SourceMap[pos] = filePos
if c.trace != nil {
c.printTrace(fmt.Sprintf("EMIT %s",
FormatInstructions(
c.scopes[c.scopeIndex].Instructions[pos:], pos)[0]))
}
return pos
}
func (c *Compiler) printTrace(a ...interface{}) {
const (
dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
n = len(dots)
)
i := 2 * c.indent
for i > n {
_, _ = fmt.Fprint(c.trace, dots)
i -= n
}
_, _ = fmt.Fprint(c.trace, dots[0:i])
_, _ = fmt.Fprintln(c.trace, a...)
}
func resolveAssignLHS(
expr parser.Expr,
) (name string, selectors []parser.Expr) {
switch term := expr.(type) {
case *parser.SelectorExpr:
name, selectors = resolveAssignLHS(term.Expr)
selectors = append(selectors, term.Sel)
return
case *parser.IndexExpr:
name, selectors = resolveAssignLHS(term.Expr)
selectors = append(selectors, term.Index)
case *parser.Ident:
name = term.Name
}
return
}
func iterateInstructions(
b []byte,
fn func(pos int, opcode parser.Opcode, operands []int) bool,
) {
for i := 0; i < len(b); i++ {
numOperands := parser.OpcodeOperands[b[i]]
operands, read := parser.ReadOperands(numOperands, b[i+1:])
if !fn(i, b[i], operands) {
break
}
i += read
}
}
func tracec(c *Compiler, msg string) *Compiler {
c.printTrace(msg, "{")
c.indent++
return c
}
func untracec(c *Compiler) {
c.indent--
c.printTrace("}")
}