mirror of
https://github.com/status-im/consul.git
synced 2025-02-14 22:58:42 +00:00
Fixes bug in #15564 where gofmt would strip out the generated code warning comment because it was on the same line as the build tag.
197 lines
4.5 KiB
Go
197 lines
4.5 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"go/format"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
)
|
|
|
|
const (
	usage = "Usage: %s -input=/proto-dir-1 -input=/proto-dir-2 -output=/mappings.go\n"

	// fileHeader starts every generated file. Its first line follows the Go
	// convention for generated code (a line matching
	// `^// Code generated .* DO NOT EDIT\.$`, see `go help generate`) so that
	// linters, editors and code review tools recognize the file as generated.
	fileHeader = `// Code generated by protoc-gen-consul-rate-limit; DO NOT EDIT.

package middleware

import "github.com/hashicorp/consul/agent/consul/rate"

`

	// entTags is prepended to fileHeader in the enterprise file. It must end
	// with a newline and a blank line: without them the generated-code header
	// would be glued onto the "// +build" line, and gofmt would strip it as
	// part of the build constraint.
	entTags = `//go:build consulent
// +build consulent

`
)
|
|
|
|
func main() {
|
|
var (
|
|
inputPaths sliceFlags
|
|
outputPath string
|
|
)
|
|
flag.Var(&inputPaths, "input", "")
|
|
flag.StringVar(&outputPath, "output", "", "")
|
|
flag.Parse()
|
|
|
|
if len(inputPaths) == 0 || outputPath == "" {
|
|
fmt.Fprintf(os.Stderr, usage, os.Args[0])
|
|
os.Exit(1)
|
|
}
|
|
|
|
if err := run(inputPaths, outputPath); err != nil {
|
|
fmt.Fprintf(os.Stderr, "ERROR: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func run(inputPaths []string, outputPath string) error {
|
|
if !strings.HasSuffix(outputPath, ".go") {
|
|
return errors.New("-output path must end in .go")
|
|
}
|
|
|
|
oss, ent, err := collectSpecs(inputPaths)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ossSource, err := generateOSS(oss)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := os.WriteFile(outputPath, ossSource, 0666); err != nil {
|
|
return fmt.Errorf("failed to write output file: %s - %w", outputPath, err)
|
|
}
|
|
|
|
// ent should only be non-zero in the enterprise repository.
|
|
if len(ent) > 0 {
|
|
entSource, err := generateENT(ent)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := os.WriteFile(enterpriseFileName(outputPath), entSource, 0666); err != nil {
|
|
return fmt.Errorf("failed to write output file: %s - %w", outputPath, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// enterpriseFileName inserts the "_ent" suffix before the first "." of the
// file's base name, so multi-part extensions such as ".gen.go" stay intact.
// If the base name contains no ".", the suffix is appended instead of
// panicking. The directory part is passed through filepath.Join, which
// cleans it.
//
// Example:
//
//	enterpriseFileName("bar/baz/foo.gen.go") => "bar/baz/foo_ent.gen.go"
func enterpriseFileName(filename string) string {
	base := filepath.Base(filename)
	stem, ext := base, ""
	if dot := strings.Index(base, "."); dot >= 0 {
		stem, ext = base[:dot], base[dot:]
	}
	return filepath.Join(filepath.Dir(filename), stem+"_ent"+ext)
}
|
|
|
|
// spec is the rate-limit classification of one RPC method. Values are
// decoded from the JSON spec files read by collectSpecs.
type spec struct {
	MethodName    string
	OperationType string
	Enterprise    bool
}

// GoOperationType returns the Go expression (a rate.OperationType constant)
// that corresponds to s.OperationType. It panics on an unrecognized
// operation type.
func (s spec) GoOperationType() string {
	var expr string
	switch op := s.OperationType; op {
	case "OPERATION_TYPE_READ":
		expr = "rate.OperationTypeRead"
	case "OPERATION_TYPE_WRITE":
		expr = "rate.OperationTypeWrite"
	case "OPERATION_TYPE_EXEMPT":
		expr = "rate.OperationTypeExempt"
	default:
		panic(fmt.Sprintf("unknown rate limit operation type: %s", op))
	}
	return expr
}
|
|
|
|
func collectSpecs(inputPaths []string) ([]spec, []spec, error) {
|
|
var specs []spec
|
|
for _, protoPath := range inputPaths {
|
|
specFiles, err := filepath.Glob(filepath.Join(protoPath, "*", ".ratelimit.tmp"))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to glob directory: %s - %s", protoPath, err)
|
|
}
|
|
|
|
for _, file := range specFiles {
|
|
b, err := os.ReadFile(file)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to read ratelimit file: %w", err)
|
|
}
|
|
|
|
var fileSpecs []spec
|
|
if err := json.Unmarshal(b, &fileSpecs); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to unmarshal ratelimit file %s - %w", file, err)
|
|
}
|
|
specs = append(specs, fileSpecs...)
|
|
}
|
|
}
|
|
|
|
sort.Slice(specs, func(a, b int) bool {
|
|
return specs[a].MethodName < specs[b].MethodName
|
|
})
|
|
|
|
var oss, ent []spec
|
|
for _, spec := range specs {
|
|
if spec.Enterprise {
|
|
ent = append(ent, spec)
|
|
} else {
|
|
oss = append(oss, spec)
|
|
}
|
|
}
|
|
|
|
return oss, ent, nil
|
|
}
|
|
|
|
func generateOSS(specs []spec) ([]byte, error) {
|
|
var output bytes.Buffer
|
|
output.WriteString(fileHeader)
|
|
|
|
fmt.Fprintln(&output, `var rpcRateLimitSpecs = map[string]rate.OperationType{`)
|
|
for _, spec := range specs {
|
|
fmt.Fprintf(&output, `"%s": %s,`, spec.MethodName, spec.GoOperationType())
|
|
output.WriteString("\n")
|
|
}
|
|
output.WriteString("}")
|
|
|
|
formatted, err := format.Source(output.Bytes())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to format source: %w", err)
|
|
}
|
|
return formatted, nil
|
|
}
|
|
|
|
func generateENT(specs []spec) ([]byte, error) {
|
|
var output bytes.Buffer
|
|
output.WriteString(entTags)
|
|
output.WriteString(fileHeader)
|
|
|
|
output.WriteString("func init() {\n")
|
|
for _, spec := range specs {
|
|
fmt.Fprintf(&output, `rpcRateLimitSpecs["%s"] = %s`, spec.MethodName, spec.GoOperationType())
|
|
output.WriteString("\n")
|
|
}
|
|
output.WriteString("}")
|
|
|
|
formatted, err := format.Source(output.Bytes())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to format source: %w", err)
|
|
}
|
|
return formatted, nil
|
|
}
|
|
|
|
// sliceFlags collects every value of a repeatable command-line flag (it is
// registered with flag.Var), keeping them in the order Set receives them.
type sliceFlags []string

// Set records one more flag value. It never reports an error.
func (f *sliceFlags) Set(value string) error {
	updated := append(*f, value)
	*f = updated
	return nil
}

// String always returns the empty string; the collected values are not
// rendered back.
func (f *sliceFlags) String() string {
	return ""
}
|