mirror of
https://github.com/status-im/consul.git
synced 2025-01-24 12:40:17 +00:00
5fb9df1640
* Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
582 lines
13 KiB
Go
582 lines
13 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"log"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
)
|
|
|
|
var (
|
|
flagPath = flag.String("path", "", "path of file to load")
|
|
verbose = flag.Bool("v", false, "verbose output")
|
|
)
|
|
|
|
const (
|
|
annotationPrefix = "@consul-rpc-glue:"
|
|
outputFileSuffix = ".rpcglue.pb.go"
|
|
)
|
|
|
|
// main parses the command line and generates the glue file for the single
// .pb.go file named by -path, terminating the process with the error message
// on any failure.
func main() {
	flag.Parse()
	log.SetFlags(0)

	path := *flagPath
	if path == "" {
		log.Fatal("missing required -path argument")
	}

	err := run(path)
	if err != nil {
		log.Fatal(err)
	}
}
|
|
|
|
// run checks that path exists, is not a directory, and ends in ".pb.go", then
// generates its RPC glue via processFile.
//
// Errors from processFile are annotated with the path so the top-level message
// identifies the offending file; they are wrapped with %w so callers can still
// inspect the underlying cause with errors.Is / errors.As.
func run(path string) error {
	fi, err := os.Stat(path)
	if err != nil {
		return err
	}
	if fi.IsDir() {
		return fmt.Errorf("argument must be a file: %s", path)
	}
	if !strings.HasSuffix(path, ".pb.go") {
		return fmt.Errorf("file must end with .pb.go: %s", path)
	}

	if err := processFile(path); err != nil {
		return fmt.Errorf("error processing file %q: %w", path, err)
	}
	return nil
}
|
|
|
|
// processFile generates RPC glue for the protobuf-generated Go file at path.
//
// It parses the file with comments (annotations live in doc comments),
// collects every type whose doc comment carries an annotation, writes the
// expanded templates to a sibling file whose name replaces ".pb.go" with
// outputFileSuffix, and formats that file with `gofmt -s -w`. When the file
// has no annotated types nothing is written.
func processFile(path string) error {
	if *verbose {
		log.Printf("visiting file %q", path)
	}

	fset := token.NewFileSet()
	tree, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
	if err != nil {
		return err
	}

	var v visitor
	ast.Walk(&v, tree)
	if err := v.Err(); err != nil {
		return err
	}

	// Nothing annotated: write nothing (any existing output is left as is).
	if len(v.Types) == 0 {
		return nil
	}

	if *verbose {
		logVisitor(&v)
	}

	outFile := strings.TrimSuffix(path, ".pb.go") + outputFileSuffix
	return writeFormatted(outFile, renderGlue(&v))
}

// logVisitor logs what v collected: package, build tags, and for each
// annotated type the field named by every annotation part that was set.
func logVisitor(v *visitor) {
	log.Printf("Package: %s", v.Package)
	log.Printf("BuildTags: %v", v.BuildTags)
	log.Println()
	for _, typ := range v.Types {
		log.Printf("Type: %s", typ.Name)
		ann := typ.Annotation
		if ann.ReadRequest != "" {
			log.Printf(" ReadRequest from %s", ann.ReadRequest)
		}
		if ann.WriteRequest != "" {
			log.Printf(" WriteRequest from %s", ann.WriteRequest)
		}
		if ann.TargetDatacenter != "" {
			log.Printf(" TargetDatacenter from %s", ann.TargetDatacenter)
		}
		if ann.QueryOptions != "" {
			log.Printf(" QueryOptions from %s", ann.QueryOptions)
		}
		if ann.QueryMeta != "" {
			log.Printf(" QueryMeta from %s", ann.QueryMeta)
		}
		if ann.Datacenter != "" {
			log.Printf(" Datacenter from %s", ann.Datacenter)
		}
	}
}

// renderGlue assembles the unformatted source of the glue file: copied build
// tags, the generated-code header, package clause, shared imports, and one
// template expansion per annotation part on each annotated type.
func renderGlue(v *visitor) []byte {
	var buf bytes.Buffer

	if len(v.BuildTags) > 0 {
		for _, line := range v.BuildTags {
			buf.WriteString(line + "\n")
		}
		// Build constraints must be followed by a blank line.
		buf.WriteString("\n")
	}
	buf.WriteString("// Code generated by proto-gen-rpc-glue. DO NOT EDIT.\n\n")
	buf.WriteString("package " + v.Package + "\n")
	buf.WriteString(`
import (
	"time"

	"github.com/hashicorp/consul/agent/structs"
)

// Reference imports to suppress errors if they are not otherwise used.
var _ structs.RPCInfo
var _ time.Month

`)

	for _, typ := range v.Types {
		ann := typ.Annotation
		if ann.WriteRequest != "" {
			fmt.Fprintf(&buf, tmplWriteRequest, typ.Name, ann.WriteRequest)
		}
		if ann.ReadRequest != "" {
			fmt.Fprintf(&buf, tmplReadRequest, typ.Name, ann.ReadRequest)
		}
		if ann.TargetDatacenter != "" {
			fmt.Fprintf(&buf, tmplTargetDatacenter, typ.Name, ann.TargetDatacenter)
		}
		if ann.QueryOptions != "" {
			fmt.Fprintf(&buf, tmplQueryOptions, typ.Name, ann.QueryOptions)
		}
		if ann.QueryMeta != "" {
			fmt.Fprintf(&buf, tmplQueryMeta, typ.Name, ann.QueryMeta)
		}
		if ann.Datacenter != "" {
			fmt.Fprintf(&buf, tmplDatacenter, typ.Name, ann.Datacenter)
		}
	}

	return buf.Bytes()
}

// writeFormatted writes src to outFile and then formats it in place with
// `gofmt -s -w`, forwarding gofmt's stderr to ours.
func writeFormatted(outFile string, src []byte) error {
	if err := os.WriteFile(outFile, src, 0644); err != nil {
		return err
	}

	cmd := exec.Command("gofmt", "-s", "-w", outFile)
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("error running 'gofmt -s -w %q': %w", outFile, err)
	}
	return nil
}
|
|
|
|
// TypeInfo pairs an annotated type declared in the input file with the
// annotation parsed from its doc comment.
type TypeInfo struct {
	// Name is the declared type name; it becomes the receiver type (%[1]s)
	// of the generated methods.
	Name       string
	Annotation Annotation
}
|
|
|
|
// visitor is the ast.Visitor used by processFile to collect, from one parsed
// file, the package name, the leading build-tag lines, and every annotated
// type in declaration order.
type visitor struct {
	Package   string
	BuildTags []string
	Types     []TypeInfo

	// Errs accumulates annotation parse errors so all of them can be
	// reported at once; see Err.
	Errs []error
}
|
|
|
|
func (v *visitor) Err() error {
|
|
switch len(v.Errs) {
|
|
case 0:
|
|
return nil
|
|
case 1:
|
|
return v.Errs[0]
|
|
default:
|
|
//
|
|
var s []string
|
|
for _, e := range v.Errs {
|
|
s = append(s, e.Error())
|
|
}
|
|
return errors.New(strings.Join(s, "; "))
|
|
}
|
|
}
|
|
|
|
var _ ast.Visitor = (*visitor)(nil)
|
|
|
|
func (v *visitor) Visit(node ast.Node) ast.Visitor {
|
|
if node == nil {
|
|
return v
|
|
}
|
|
|
|
switch x := node.(type) {
|
|
case *ast.File:
|
|
v.Package = x.Name.Name
|
|
v.BuildTags = getRawBuildTags(x)
|
|
for _, d := range x.Decls {
|
|
gd, ok := d.(*ast.GenDecl)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if gd.Doc == nil {
|
|
continue
|
|
} else if len(gd.Specs) != 1 {
|
|
continue
|
|
}
|
|
spec := gd.Specs[0]
|
|
|
|
typeSpec, ok := spec.(*ast.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
ann, err := getAnnotation(gd.Doc.List)
|
|
if err != nil {
|
|
v.Errs = append(v.Errs, err)
|
|
continue
|
|
} else if ann.IsZero() {
|
|
continue
|
|
}
|
|
|
|
v.Types = append(v.Types, TypeInfo{
|
|
Name: typeSpec.Name.Name,
|
|
Annotation: ann,
|
|
})
|
|
|
|
}
|
|
}
|
|
return v
|
|
}
|
|
|
|
// Annotation is the parsed form of a type's @consul-rpc-glue annotation. Each
// of the first six fields holds the field name given for that part of the
// annotation (the part's own name when written bare, e.g. "ReadRequest"); an
// empty string means the part was not present.
type Annotation struct {
	QueryMeta        string
	QueryOptions     string
	ReadRequest      string
	WriteRequest     string
	TargetDatacenter string
	Datacenter       string

	// NOTE(review): the *TODO fields are never populated by getAnnotation in
	// this file; they only participate in IsZero.
	ReadTODO       string
	LeaderReadTODO string
	WriteTODO      string
}

// IsZero reports whether no field of the annotation is set.
func (a Annotation) IsZero() bool {
	var none Annotation
	return a == none
}
|
|
|
|
// getAnnotation parses the @consul-rpc-glue annotation, if any, from a type's
// doc comment lines.
//
// The annotation text is a comma-separated list of parts; surrounding
// whitespace on each part is ignored. A part is either a bare kind
// ("ReadRequest", "WriteRequest", "TargetDatacenter", "QueryOptions",
// "QueryMeta", "Datacenter"), meaning the message field has the kind's name,
// or "Kind=Field" naming the field explicitly.
//
// With no annotation line it returns the zero Annotation and a nil error. It
// returns an error for an unknown kind, for "Kind=" with an empty field name,
// and for kinds that cannot be combined on one type.
func getAnnotation(doc []*ast.Comment) (Annotation, error) {
	raw, ok := getRawStructAnnotation(doc)
	if !ok {
		return Annotation{}, nil
	}

	var ann Annotation
	for _, part := range strings.Split(raw, ",") {
		part = strings.TrimSpace(part)

		// "Kind=Field" names the field explicitly; a bare "Kind" means the
		// field shares the kind's name.
		kind, field := part, part
		if i := strings.IndexByte(part, '='); i >= 0 {
			kind, field = part[:i], part[i+1:]
		}

		var dst *string
		switch kind {
		case "ReadRequest":
			dst = &ann.ReadRequest
		case "WriteRequest":
			dst = &ann.WriteRequest
		case "TargetDatacenter":
			dst = &ann.TargetDatacenter
		case "QueryOptions":
			dst = &ann.QueryOptions
		case "QueryMeta":
			dst = &ann.QueryMeta
		case "Datacenter":
			dst = &ann.Datacenter
		default:
			return Annotation{}, fmt.Errorf("unexpected annotation part: %s", part)
		}

		// An empty field is indistinguishable from "not requested", so the
		// part would otherwise be dropped without any diagnostic.
		if field == "" {
			return Annotation{}, fmt.Errorf("missing field name in annotation part: %s", part)
		}
		*dst = field
	}

	// These kinds expand to templates declaring the same methods (IsRead,
	// Token, ... and RequestDatacenter respectively), so combining them would
	// generate duplicate method declarations that can never compile.
	requestKinds := 0
	for _, f := range []string{ann.ReadRequest, ann.WriteRequest, ann.QueryOptions} {
		if f != "" {
			requestKinds++
		}
	}
	if requestKinds > 1 {
		return Annotation{}, fmt.Errorf("annotation %q: only one of ReadRequest, WriteRequest or QueryOptions may be used", raw)
	}
	if ann.TargetDatacenter != "" && ann.Datacenter != "" {
		return Annotation{}, fmt.Errorf("annotation %q: TargetDatacenter and Datacenter cannot be combined", raw)
	}

	return ann, nil
}
|
|
|
|
// getRawStructAnnotation looks through doc comment lines for the first one
// that, after stripping leading slashes and surrounding whitespace, begins
// with annotationPrefix. It returns the whitespace-trimmed text following the
// prefix and true, or "" and false when no line carries the prefix.
func getRawStructAnnotation(doc []*ast.Comment) (string, bool) {
	for _, c := range doc {
		body := strings.TrimSpace(strings.TrimLeft(c.Text, "/"))
		if !strings.HasPrefix(body, annotationPrefix) {
			continue
		}
		return strings.TrimSpace(body[len(annotationPrefix):]), true
	}
	return "", false
}
|
|
|
|
// getRawBuildTags returns, verbatim, the build-constraint lines ("//go:build"
// and "// +build") that open the file's first comment group, so they can be
// copied into the generated file. Collection stops at the first line of that
// group that is not a constraint. It returns nil when the file has no
// comments or its first comment group does not start with a constraint.
func getRawBuildTags(file *ast.File) []string {
	if len(file.Comments) == 0 {
		return nil
	}

	isConstraint := func(raw string) bool {
		s := strings.TrimSpace(strings.TrimLeft(raw, "/"))
		return strings.HasPrefix(s, "go:build ") || strings.HasPrefix(s, "+build")
	}

	var tags []string
	lines := file.Comments[0].List
	for i := 0; i < len(lines) && isConstraint(lines[i].Text); i++ {
		tags = append(tags, lines[i].Text)
	}
	return tags
}
|
|
|
|
// tmplWriteRequest is expanded by processFile for types annotated with
// WriteRequest: %[1]s is the annotated type name and %[2]s the name of its
// write-request field. The generated methods report the RPC as a write
// (IsRead and AllowStaleRead return false) and delegate token and timeout
// handling to that field. HasTimedOut, TokenSecret and Token return zero
// values when either msg or the field is nil; SetTokenSecret still requires a
// non-nil field (see the TODO inside the template).
const tmplWriteRequest = `
// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return false
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return false
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`
|
|
|
|
// tmplReadRequest is expanded by processFile for types annotated with
// ReadRequest: %[1]s is the annotated type name and %[2]s the name of its
// read-request field. The generated methods report the RPC as a read and
// delegate stale-read, token and timeout handling to that field. HasTimedOut,
// TokenSecret and Token return zero values when either msg or the field is
// nil; AllowStaleRead and SetTokenSecret still require a non-nil field (see
// the TODOs inside the template).
const tmplReadRequest = `
// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	// TODO: initialize if nil
	return msg.%[2]s.AllowStaleRead()
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}
`
|
|
|
|
// tmplTargetDatacenter is expanded by processFile for types annotated with
// TargetDatacenter: %[1]s is the annotated type name and %[2]s the name of the
// field whose GetDatacenter() supplies the request's datacenter. The generated
// RequestDatacenter returns "" when msg or that field is nil.
const tmplTargetDatacenter = `
// RequestDatacenter implements structs.RPCInfo
func (msg *%[1]s) RequestDatacenter() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.GetDatacenter()
}
`
|
|
|
|
// tmplDatacenter is expanded by processFile for types annotated with
// Datacenter: %[1]s is the annotated type name and %[2]s the name of the
// message's own string field holding the datacenter ("Datacenter" for the
// bare annotation, or the field given in "Datacenter=Field"). The generated
// RequestDatacenter returns "" for a nil message.
//
// The field reference uses %[2]s, like every other template; hard-coding
// msg.Datacenter would silently ignore an explicit "Datacenter=Field".
const tmplDatacenter = `
// RequestDatacenter implements structs.RPCInfo
func (msg *%[1]s) RequestDatacenter() string {
	if msg == nil {
		return ""
	}
	return msg.%[2]s
}
`
|
|
|
|
// tmplQueryOptions is expanded by processFile for types annotated with
// QueryOptions: %[1]s is the annotated type name and %[2]s the name of its
// query-options field. Besides the structs.RPCInfo methods it generates
// BlockingTimeout and the Get* accessors used for blocking queries, all
// delegating to that field. HasTimedOut, TokenSecret, Token and the Get*
// accessors return zero values when msg or the field is nil; AllowStaleRead,
// BlockingTimeout and SetTokenSecret still dereference msg without a guard.
const tmplQueryOptions = `
// IsRead implements structs.RPCInfo
func (msg *%[1]s) IsRead() bool {
	return true
}

// AllowStaleRead implements structs.RPCInfo
func (msg *%[1]s) AllowStaleRead() bool {
	return msg.%[2]s.AllowStaleRead()
}

// BlockingTimeout implements pool.BlockableQuery
func (msg *%[1]s) BlockingTimeout(maxQueryTime, defaultQueryTime time.Duration) time.Duration {
	maxTime := structs.DurationFromProto(msg.%[2]s.GetMaxQueryTime())
	o := structs.QueryOptions{
		MaxQueryTime: maxTime,
		MinQueryIndex: msg.%[2]s.GetMinQueryIndex(),
	}
	return o.BlockingTimeout(maxQueryTime, defaultQueryTime)
}

// HasTimedOut implements structs.RPCInfo
func (msg *%[1]s) HasTimedOut(start time.Time, rpcHoldTimeout time.Duration, a time.Duration, b time.Duration) (bool, error) {
	if msg == nil || msg.%[2]s == nil {
		return false, nil
	}
	return msg.%[2]s.HasTimedOut(start, rpcHoldTimeout, a, b)
}

// SetTokenSecret implements structs.RPCInfo
func (msg *%[1]s) SetTokenSecret(s string) {
	// TODO: initialize if nil
	msg.%[2]s.SetTokenSecret(s)
}

// TokenSecret implements structs.RPCInfo
func (msg *%[1]s) TokenSecret() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.TokenSecret()
}

// Token implements structs.RPCInfo
func (msg *%[1]s) Token() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}
	return msg.%[2]s.Token
}

// GetToken is required to implement blockingQueryOptions
func (msg *%[1]s) GetToken() string {
	if msg == nil || msg.%[2]s == nil {
		return ""
	}

	return msg.%[2]s.GetToken()
}

// GetMinQueryIndex is required to implement blockingQueryOptions
func (msg *%[1]s) GetMinQueryIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}

	return msg.%[2]s.GetMinQueryIndex()
}

// GetMaxQueryTime is required to implement blockingQueryOptions
func (msg *%[1]s) GetMaxQueryTime() (time.Duration, error) {
	if msg == nil || msg.%[2]s == nil {
		return 0, nil
	}

	return structs.DurationFromProto(msg.%[2]s.GetMaxQueryTime()), nil
}

// GetRequireConsistent is required to implement blockingQueryOptions
func (msg *%[1]s) GetRequireConsistent() bool {
	if msg == nil || msg.%[2]s == nil {
		return false
	}
	return msg.%[2]s.RequireConsistent
}
`
|
|
|
|
// tmplQueryMeta is expanded by processFile for types annotated with
// QueryMeta: %[1]s is the annotated type name and %[2]s the name of its
// query-meta field. Each generated method forwards to the same-named method
// on that field; when msg or the field is nil the setters do nothing and
// GetIndex returns 0.
const tmplQueryMeta = `
// SetLastContact is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetLastContact(d time.Duration) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetLastContact(d)
}

// SetKnownLeader is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetKnownLeader(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetKnownLeader(b)
}

// GetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) GetIndex() uint64 {
	if msg == nil || msg.%[2]s == nil {
		return 0
	}
	return msg.%[2]s.GetIndex()
}

// SetIndex is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetIndex(i uint64) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetIndex(i)
}

// SetResultsFilteredByACLs is required to implement blockingQueryResponseMeta
func (msg *%[1]s) SetResultsFilteredByACLs(b bool) {
	if msg == nil || msg.%[2]s == nil {
		return
	}
	msg.%[2]s.SetResultsFilteredByACLs(b)
}
`
|