133 lines
2.7 KiB
Go
Raw Normal View History

2018-06-21 10:49:35 -04:00
package jws
import (
"bytes"
"encoding/json"
)
// Flat serializes the JWS to its "flattened" form per
// https://tools.ietf.org/html/rfc7515#section-7.2.2
func (j *jws) Flat(key interface{}) ([]byte, error) {
if len(j.sb) < 1 {
return nil, ErrNotEnoughMethods
}
if err := j.sign(key); err != nil {
return nil, err
}
return json.Marshal(struct {
Payload rawBase64 `json:"payload"`
sigHead
}{
Payload: j.plcache,
sigHead: j.sb[0],
})
}
// General serializes the JWS into its "general" form per
// https://tools.ietf.org/html/rfc7515#section-7.2.1
//
// If only one key is passed it's used for all the provided
// crypto.SigningMethods. Otherwise, len(keys) must equal the number
// of crypto.SigningMethods added.
func (j *jws) General(keys ...interface{}) ([]byte, error) {
if err := j.sign(keys...); err != nil {
return nil, err
}
return json.Marshal(struct {
Payload rawBase64 `json:"payload"`
Signatures []sigHead `json:"signatures"`
}{
Payload: j.plcache,
Signatures: j.sb,
})
}
// Compact serializes the JWS into its "compact" form per
// https://tools.ietf.org/html/rfc7515#section-7.1
func (j *jws) Compact(key interface{}) ([]byte, error) {
if len(j.sb) < 1 {
return nil, ErrNotEnoughMethods
}
if err := j.sign(key); err != nil {
return nil, err
}
sig, err := j.sb[0].Signature.Base64()
if err != nil {
return nil, err
}
return format(
j.sb[0].Protected,
j.plcache,
sig,
), nil
}
// sign signs each index of j's sb member.
func (j *jws) sign(keys ...interface{}) error {
if err := j.cache(); err != nil {
return err
}
if len(keys) < 1 ||
len(keys) > 1 && len(keys) != len(j.sb) {
return ErrNotEnoughKeys
}
if len(keys) == 1 {
k := keys[0]
keys = make([]interface{}, len(j.sb))
for i := range keys {
keys[i] = k
}
}
for i := range j.sb {
if err := j.sb[i].cache(); err != nil {
return err
}
raw := format(j.sb[i].Protected, j.plcache)
sig, err := j.sb[i].method.Sign(raw, keys[i])
if err != nil {
return err
}
j.sb[i].Signature = sig
}
return nil
}
// cache marshals the payload, but only if it's changed since the last cache.
func (j *jws) cache() (err error) {
if !j.clean {
j.plcache, err = j.payload.Base64()
j.clean = err == nil
}
return err
}
// cache marshals the protected and unprotected headers, but only if
// they've changed since their last cache.
func (s *sigHead) cache() (err error) {
if !s.clean {
s.Protected, err = s.protected.Base64()
if err != nil {
return err
}
s.Unprotected, err = s.unprotected.Base64()
if err != nil {
return err
}
}
s.clean = true
return nil
}
// format formats a slice of bytes in the order given, joining
// them with a period.
func format(a ...[]byte) []byte {
return bytes.Join(a, []byte{'.'})
}