consul/internal/tools/protoc-gen-consul-rate-limit/main.go

151 lines
4.3 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

// protoc-gen-consul-rate-limit
//
// This protoc plugin maintains the mapping of gRPC method names to
// a specification of how they should be rate-limited. This is used by the gRPC
// InTapHandle function (see agent/grpc-middleware/rate.go) to enforce relevant
// limits without having to call the handler. It works in two phases:
//
// # Phase 1
//
// Buf/protoc invokes this plugin for each .proto file. We extract the rate limit
// specification from an annotation on the RPC:
//
//	service Foo {
//	  rpc Bar(BarRequest) returns (BarResponse) {
//	    option (hashicorp.consul.internal.ratelimit.spec) = {
//	      operation_type: OPERATION_TYPE_WRITE,
//	      operation_category: OPERATION_CATEGORY_ACL
//	    };
//	  }
//	}
//
// We write a JSON array of the limits to protobuf/package/path/.ratelimit.tmp:
//
//	[
//	  {
//	    "MethodName": "/Foo/Bar",
//	    "OperationType": "OPERATION_TYPE_WRITE",
//	    "OperationCategory": "OPERATION_CATEGORY_ACL",
//	    "Enterprise": false
//	  }
//	]
//
// # Phase 2
//
// The protobuf.sh script (invoked by make proto) runs our postprocess script
// which reads all of the .ratelimit.tmp files in proto and proto-public and
// generates a single Go map in agent/grpc-middleware/rate_limit_mappings.gen.go.
package main
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/consul/proto-public/annotations/ratelimit"
)
const (
	// outputFileName is the name of the file the plugin emits next to the
	// .proto files it processes; it holds the JSON-encoded rate limit specs.
	outputFileName = ".ratelimit.tmp"

	// missingSpecTmpl is the error text returned when an RPC carries no
	// rate-limit annotation. Its three %s verbs are filled, in order, with the
	// method name, the service name, and the method name again.
	missingSpecTmpl = `RPC %s is missing rate-limit specification, fix it with:
import "proto-public/annotations/ratelimit/ratelimit.proto";
service %s {
rpc %s(...) returns (...) {
option (hashicorp.consul.internal.ratelimit.spec) = {
operation_type: OPERATION_TYPE_READ | OPERATION_TYPE_WRITE | OPERATION_TYPE_EXEMPT,
operation_category: OPERATION_CATEGORY_ACL | OPERATION_CATEGORY_PEER_STREAM | OPERATION_CATEGORY_CONNECT_CA | OPERATION_CATEGORY_PARTITION | OPERATION_CATEGORY_PEERING | OPERATION_CATEGORY_SERVER_DISCOVERY | OPERATION_CATEGORY_DATAPLANE | OPERATION_CATEGORY_DNS | OPERATION_CATEGORY_SUBSCRIBE | OPERATION_CATEGORY_OPERATOR | OPERATION_CATEGORY_RESOURCE | OPERATION_CATEGORY_CONFIGENTRY,
};
}
}
`

	// enterpriseBuildTag is the marker searched for in a .proto file's source
	// to decide whether its RPCs are flagged as enterprise.
	enterpriseBuildTag = "//go:build consulent"
)
// rateLimitSpec is one element of the JSON array written to .ratelimit.tmp.
// The JSON keys are pinned with explicit tags (identical to the field names)
// so that renaming a Go field cannot silently change the on-disk schema.
type rateLimitSpec struct {
	// MethodName is the RPC in gRPC path form, "/<service full name>/<method>".
	MethodName string `json:"MethodName"`

	// OperationType is the string form of the annotation's operation_type.
	OperationType string `json:"OperationType"`

	// OperationCategory is the string form of the annotation's
	// operation_category.
	OperationCategory string `json:"OperationCategory"`

	// Enterprise is true when the defining .proto file contains the
	// enterprise build tag.
	Enterprise bool `json:"Enterprise"`
}
func main() {
var opts protogen.Options
opts.Run(func(plugin *protogen.Plugin) error {
for _, path := range plugin.Request.FileToGenerate {
file, ok := plugin.FilesByPath[path]
if !ok {
return fmt.Errorf("failed to get file descriptor: %s", path)
}
specs, err := rateLimitSpecs(file)
if err != nil {
return err
}
if len(specs) == 0 {
continue
2023-01-04 16:07:02 +00:00
}
outputPath := filepath.Join(filepath.Dir(path), outputFileName)
output := plugin.NewGeneratedFile(outputPath, "")
if err := json.NewEncoder(output).Encode(specs); err != nil {
return err
}
}
return nil
})
}
func rateLimitSpecs(file *protogen.File) ([]rateLimitSpec, error) {
enterprise, err := isEnterpriseFile(file)
if err != nil {
return nil, err
}
var specs []rateLimitSpec
for _, service := range file.Services {
for _, method := range service.Methods {
spec := rateLimitSpec{
// Format the method name in gRPC/HTTP path format.
MethodName: fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()),
Enterprise: enterprise,
}
// Read the rate limit spec from the method options.
options := method.Desc.Options()
if !proto.HasExtension(options, ratelimit.E_Spec) {
err := fmt.Errorf(missingSpecTmpl,
method.Desc.Name(),
service.Desc.Name(),
method.Desc.Name())
return nil, err
}
def := proto.GetExtension(options, ratelimit.E_Spec).(*ratelimit.Spec)
spec.OperationType = def.OperationType.String()
spec.OperationCategory = def.OperationCategory.String()
2023-01-04 16:07:02 +00:00
specs = append(specs, spec)
}
}
return specs, nil
}
// isEnterpriseFile reports whether the .proto source behind file contains the
// enterprise build tag. The source is read from disk at the path recorded in
// the file's descriptor; a read failure is returned wrapped.
func isEnterpriseFile(file *protogen.File) (bool, error) {
	contents, readErr := os.ReadFile(file.Desc.Path())
	if readErr != nil {
		return false, fmt.Errorf("failed to read proto file: %w", readErr)
	}

	tag := []byte(enterpriseBuildTag)
	return bytes.Contains(contents, tag), nil
}