2022-04-06 10:36:06 -04:00

299 lines
6.6 KiB
Go

package parse
import (
"bufio"
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/scanner"
"go/token"
"io"
"os"
"strings"
"unicode"
"golang.org/x/tools/imports"
)
// header is the banner that Generics places at the very start of the
// generated output, ahead of the code produced for each type set.
var header = []byte(`
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
`)
// Markers and names used while rewriting generic source line by line.
var (
// packageKeyword and importKeyword are matched as line prefixes by Generics
// to find package clauses and import declarations.
packageKeyword = []byte("package")
importKeyword = []byte("import")
// openBrace and closeBrace are matched as line suffixes by Generics to
// detect the start and end of a grouped import block. Despite the names
// they hold parentheses, not braces.
openBrace = []byte("(")
closeBrace = []byte(")")
// genericPackage is the package qualifier that marks a type declaration as
// a generic placeholder (checked against the selector in generateSpecific).
genericPackage = "generic"
// Source lines containing genericType or genericNumber are dropped by
// generateSpecific.
genericType = "generic.Type"
genericNumber = "generic.Number"
// linefeed is used as a cutset by makeLine: every trailing '\r' or '\n'
// character is trimmed, not just the exact "\r\n" sequence.
linefeed = "\r\n"
)
// unwantedLinePrefixes lists line prefixes that Generics strips from the
// combined output.
var unwantedLinePrefixes = [][]byte{
[]byte("//go:generate genny "),
}
// subIntoLiteral substitutes specificType for typeTemplate inside a single
// literal or identifier.
//
// A literal that is exactly the template is replaced by specificType as-is.
// A literal that merely contains the template gets the wordified, upper-cased
// form of the specific type for every occurrence; if the original literal was
// not exported and the result now begins with that upper-cased form, the
// leading occurrence is swapped for the non-exported wordified form.
// Literals that do not mention the template are returned unchanged.
func subIntoLiteral(lit, typeTemplate, specificType string) string {
	switch {
	case lit == typeTemplate:
		return specificType
	case !strings.Contains(lit, typeTemplate):
		return lit
	}

	upper := wordify(specificType, true)
	replaced := strings.Replace(lit, typeTemplate, upper, -1)

	// Exported literals, and literals where the template was not at the
	// front, keep the upper-cased form everywhere.
	if isExported(lit) || !strings.HasPrefix(replaced, upper) {
		return replaced
	}
	return strings.Replace(replaced, upper, wordify(specificType, false), 1)
}
// subTypeIntoComment substitutes specificType for typeTemplate in every
// whitespace-separated word of a comment line.
//
// The words are re-joined with single spaces and the result always carries a
// trailing space (an empty or all-whitespace line yields ""), so the original
// spacing of the comment is not preserved.
func subTypeIntoComment(line, typeTemplate, specificType string) string {
	// A Builder avoids re-copying the accumulated text for every word.
	var subbed strings.Builder
	for _, w := range strings.Fields(line) {
		subbed.WriteString(subIntoLiteral(w, typeTemplate, specificType))
		subbed.WriteByte(' ')
	}
	return subbed.String()
}
// subTypeIntoLine does the heavy lifting of taking a line of our code and
// substituting a specific type into it in place of our generic type.
//
// The line is tokenised with go/scanner (comments included). Comment tokens
// are rewritten word by word via subTypeIntoComment, literal tokens (which
// include identifiers) via subIntoLiteral, and every other token is emitted
// using its canonical string form. Each emitted token is followed by a single
// space, so the original spacing of the line is not preserved.
func subTypeIntoLine(line, typeTemplate, specificType string) string {
	src := []byte(line)
	fset := token.NewFileSet()
	file := fset.AddFile("", fset.Base(), len(src))

	var s scanner.Scanner
	// A nil error handler means scan errors are not reported.
	s.Init(file, src, nil, scanner.ScanComments)

	// A Builder avoids re-copying the accumulated output for every token.
	var output strings.Builder
	for {
		_, tok, lit := s.Scan()
		if tok == token.EOF {
			break
		}
		switch {
		case tok == token.COMMENT:
			output.WriteString(subTypeIntoComment(lit, typeTemplate, specificType))
		case tok.IsLiteral():
			output.WriteString(subIntoLiteral(lit, typeTemplate, specificType))
		default:
			output.WriteString(tok.String())
		}
		output.WriteByte(' ')
	}
	return output.String()
}
// generateSpecific produces one specialised copy of the generic source read
// from in, replacing each generic type name found in typeSet with its
// specific type.
//
// typeSet maps generic names to specific types, e.g.
// {"KeyType": "int", "ValueType": "string"}.
//
// The source is read twice: first it is parsed to verify that every type
// declared as generic.<something> has an entry in typeSet, then it is
// rewritten line by line. Lines mentioning generic.Type or generic.Number are
// dropped, together with the single "//" comment line directly above them.
//
// It returns *errSource when the file does not parse,
// *errMissingSpecificType when a generic type has no entry in typeSet, and a
// plain wrapped error when the input cannot be rewound or read.
func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]string) ([]byte, error) {
	// ensure we are at the beginning of the file
	//
	// NOTE(review): os.SEEK_SET is deprecated in favour of io.SeekStart. It is
	// kept because these appear to be the only uses of the os import in this
	// file, and Go rejects unused imports.
	if _, err := in.Seek(0, os.SEEK_SET); err != nil {
		return nil, fmt.Errorf("rewinding %s: %w", filename, err)
	}

	// parse the source file
	fs := token.NewFileSet()
	file, err := parser.ParseFile(fs, filename, in, 0)
	if err != nil {
		return nil, &errSource{Err: err}
	}

	// make sure every generic.Type is represented in the types
	// argument.
	for _, decl := range file.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range gen.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			sel, ok := ts.Type.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			pkg, ok := sel.X.(*ast.Ident)
			if !ok || pkg.Name != genericPackage {
				continue
			}
			if _, ok := typeSet[ts.Name.Name]; !ok {
				return nil, &errMissingSpecificType{GenericType: ts.Name.Name}
			}
		}
	}

	// rewind again for the line-by-line rewrite
	if _, err := in.Seek(0, os.SEEK_SET); err != nil {
		return nil, fmt.Errorf("rewinding %s: %w", filename, err)
	}

	var buf bytes.Buffer
	// comment holds a "//" line whose output is deferred by one line, so that
	// it can be discarded if the next line turns out to be a generic type
	// declaration.
	comment := ""
	// Named lines rather than scanner so the imported go/scanner package is
	// not shadowed.
	lines := bufio.NewScanner(in)
	for lines.Scan() {
		line := lines.Text()

		// does this line contain generic.Type?
		if strings.Contains(line, genericType) || strings.Contains(line, genericNumber) {
			comment = ""
			continue
		}

		// NOTE(review): map iteration order is random, so when one generic
		// name is a substring of another (e.g. "Type" and "KeyType") the
		// result depends on which is substituted first.
		for t, specificType := range typeSet {
			if strings.Contains(line, t) {
				line = subTypeIntoLine(line, t, specificType)
			}
		}

		if comment != "" {
			buf.WriteString(makeLine(comment))
			comment = ""
		}

		// is this line a comment?
		// TODO: should we handle /* */ comments?
		if strings.HasPrefix(line, "//") {
			// record this line to print later
			comment = line
			continue
		}

		// write the line
		buf.WriteString(makeLine(line))
	}
	// Scan stops silently on a read error or an over-long line; report it
	// rather than returning truncated output.
	if err := lines.Err(); err != nil {
		return nil, fmt.Errorf("reading %s: %w", filename, err)
	}

	// A comment on the last line of the input is still pending here; write it
	// out so it is not lost.
	if comment != "" {
		buf.WriteString(makeLine(comment))
	}

	// write it out
	return buf.Bytes(), nil
}
// Generics parses the source file and generates the bytes replacing the
// generic types for the keys map with the specific types (its value).
//
// The source in is specialised once per entry of typeSets and the results are
// concatenated after the generated-file header. The combined text is then
// cleaned line by line: only the first package clause is kept, import
// declarations (single-line and grouped) are removed, and lines starting with
// any of unwantedLinePrefixes are dropped. If pkgName is non-empty the package
// clause is rewritten to use it. Finally the result is passed through
// imports.Process.
//
// Errors from generateSpecific are returned unchanged; a failure of
// imports.Process is returned as *errImports.
func Generics(filename, pkgName string, in io.ReadSeeker, typeSets []map[string]string) ([]byte, error) {
	// Build into a private buffer. Appending directly to the package-level
	// header slice may reuse spare capacity of its backing array, which is
	// shared by every call.
	var generated bytes.Buffer
	generated.Write(header)
	for _, typeSet := range typeSets {
		// generate the specifics
		parsed, err := generateSpecific(filename, in, typeSet)
		if err != nil {
			return nil, err
		}
		generated.Write(parsed)
	}

	// clean up the code line by line
	packageFound := false
	insideImportBlock := false
	var clean bytes.Buffer
	// Named lines rather than scanner so the imported go/scanner package is
	// not shadowed.
	lines := bufio.NewScanner(&generated)
	for lines.Scan() {
		line := lines.Bytes()

		// end of imports block?
		if insideImportBlock {
			if bytes.HasSuffix(line, closeBrace) {
				insideImportBlock = false
			}
			continue
		}

		if bytes.HasPrefix(line, packageKeyword) {
			// keep only the first package clause
			if packageFound {
				continue
			}
			packageFound = true
		} else if bytes.HasPrefix(line, importKeyword) {
			if bytes.HasSuffix(line, openBrace) {
				insideImportBlock = true
			}
			continue
		}

		// check all unwantedLinePrefixes - and skip them
		skipline := false
		for _, prefix := range unwantedLinePrefixes {
			if bytes.HasPrefix(line, prefix) {
				skipline = true
				break
			}
		}
		if skipline {
			continue
		}

		clean.WriteString(makeLine(lines.Text()))
	}
	// Scan stops silently on an over-long line; report it rather than
	// emitting truncated code.
	if err := lines.Err(); err != nil {
		return nil, fmt.Errorf("cleaning generated code for %s: %w", filename, err)
	}

	output := clean.Bytes()

	// change package name
	if pkgName != "" {
		output = changePackage(bytes.NewReader(output), pkgName)
	}

	// fix the imports
	output, err := imports.Process(filename, output, nil)
	if err != nil {
		return nil, &errImports{Err: err}
	}
	return output, nil
}
// makeLine strips every trailing '\r' and '\n' character from s and returns
// the remainder terminated by a single newline.
func makeLine(s string) string {
	trimmed := strings.TrimRight(s, linefeed)
	return fmt.Sprintln(trimmed)
}
// isAlphaNumeric reports whether r is an underscore, a Unicode letter or a
// Unicode digit.
func isAlphaNumeric(r rune) bool {
	switch {
	case r == '_':
		return true
	case unicode.IsLetter(r):
		return true
	default:
		return unicode.IsDigit(r)
	}
}
// wordify turns a type into a nice word for function and type
// names etc.
//
// Trailing '{' and '}' characters, leading '*' and '&' characters and all
// dots are removed (so "*pkg.Type" becomes "pkgType" and "interface{}"
// becomes "interface"). When exported is true the first character of what
// remains is upper-cased; when false the word is returned as-is. If nothing
// remains after trimming, the empty string is returned.
func wordify(s string, exported bool) string {
	s = strings.TrimRight(s, "{}")
	s = strings.TrimLeft(s, "*&")
	s = strings.Replace(s, ".", "", -1)
	// Nothing to capitalise: either not requested, or the input consisted
	// solely of trimmed characters (indexing s[0] would panic).
	if !exported || s == "" {
		return s
	}
	// Upper-case the first rune, not the first byte, so a multi-byte leading
	// character is not corrupted.
	runes := []rune(s)
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
// changePackage copies the source read from r, replacing the package name in
// the first package clause with pkgName, and returns the result.
//
// A package clause is a line that starts with the word "package" followed by
// whitespace and a name; anything after the name (such as a trailing comment)
// is kept, separated by single spaces. Only the first such line is rewritten.
// Every output line is terminated with "\n".
//
// NOTE(review): the scanner's error is not surfaced because the signature has
// no error return; an over-long line would end the copy early.
func changePackage(r io.Reader, pkgName string) []byte {
	var out bytes.Buffer
	sc := bufio.NewScanner(r)
	done := false
	for sc.Scan() {
		s := sc.Text()
		if !done && strings.HasPrefix(s, "package") {
			// Split on any whitespace so tabs or repeated spaces neither
			// panic nor select an empty field, and require the first word to
			// be exactly "package" so e.g. "packages ..." is left alone.
			fields := strings.Fields(s)
			if len(fields) >= 2 && fields[0] == "package" {
				fields[1] = pkgName
				s = strings.Join(fields, " ")
				done = true
			}
		}
		fmt.Fprintln(&out, s)
	}
	return out.Bytes()
}
// isExported reports whether lit begins with an upper-case letter. The empty
// string is not exported.
//
// The check is delegated to ast.IsExported so that the first rune, rather
// than the first byte, is examined; inspecting only the first byte
// misclassifies identifiers that start with a multi-byte character.
func isExported(lit string) bool {
	return ast.IsExported(lit)
}