Skip to content

Commit

Permalink
Implement per-object custom typed fields (#391)
Browse files Browse the repository at this point in the history
* preliminary per-object typed claim handling

* Tests for issue 384

* Edit tests

* unify style
* use table-driven tests
* remove test for field name

* appease linter

* Tweak things to be a bit more generic

This is preparation to make things available for other packages

* slightly optimize

* Use stronger words of warning

I think this is useful, but also very dangerous. here be dragons

* Return an error if a typed claim was specified, but token does not support it

* Add missing definitions

* Implement TypedField in jwk

* consolidate definitions

* Implement typed fields via jwk.Parse (i.e. jwk.Set)

* Add notes in Changes

* appease linter

Co-authored-by: Mikhail Fludkov <[email protected]>
  • Loading branch information
lestrrat and floodkoff authored Jun 1, 2021
1 parent b56d088 commit 63a21f9
Show file tree
Hide file tree
Showing 20 changed files with 725 additions and 36 deletions.
17 changes: 17 additions & 0 deletions Changes
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
Changes
=======

v1.2.1
[New features]
* Option `jwt.WithTypedClaim()` and `jwk.WithTypedField()` have been added.
They allow a per-object custom conversion from their JSON representation
to a Go object, much like `RegisterCustomField`.
The difference is that whereas `RegisterCustomField` has global effect,
these typed fields only take effect in the call where the option was
explicitly passed.

`jws` and `jwe` does not have these options because
(1) JWS and JWE messages don't generally carry much in terms of custom data
(2) This requires changes in function signatures.

Only use these options when you absolutely need to. While it is a powerful
tool, they do have many caveats, and abusing these features will have
negative effects. See the documentation for details

v1.2.0 30 Apr 2021

This is a security fix release with minor incompatibilities from earlier version
Expand Down
26 changes: 26 additions & 0 deletions internal/json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,29 @@ func EncodeAudience(enc *Encoder, aud []string) error {
}
return enc.Encode(val)
}

// DecodeCtx is an interface for objects that needs that extra something
// when decoding JSON into an object.
type DecodeCtx interface {
Registry() *Registry
}

// DecodeCtxContainer is used to differentiate objects that can carry extra
// decoding hints and those who can't.
type DecodeCtxContainer interface {
DecodeCtx() DecodeCtx
SetDecodeCtx(DecodeCtx)
}

// stock decodeCtx. should cover 80% of the cases
type decodeCtx struct {
registry *Registry
}

func NewDecodeCtx(r *Registry) DecodeCtx {
return &decodeCtx{registry: r}
}

func (dc *decodeCtx) Registry() *Registry {
return dc.registry
}
58 changes: 52 additions & 6 deletions jwk/ecdsa_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type ecdsaPrivateKey struct {
y []byte
privateParams map[string]interface{}
mu *sync.RWMutex
dc DecodeCtx
}

func NewECDSAPrivateKey() ECDSAPrivateKey {
Expand Down Expand Up @@ -412,6 +413,18 @@ func (k *ecdsaPrivateKey) Clone() (Key, error) {
return cloneKey(k)
}

func (k *ecdsaPrivateKey) DecodeCtx() DecodeCtx {
k.mu.RLock()
defer k.mu.RUnlock()
return k.dc
}

func (k *ecdsaPrivateKey) SetDecodeCtx(dc DecodeCtx) {
k.mu.Lock()
defer k.mu.Unlock()
k.dc = dc
}

func (h *ecdsaPrivateKey) UnmarshalJSON(buf []byte) error {
h.algorithm = nil
h.crv = nil
Expand Down Expand Up @@ -506,11 +519,21 @@ LOOP:
return errors.Wrapf(err, `failed to decode value for key %s`, ECDSAYKey)
}
default:
if dc := h.dc; dc != nil {
if localReg := dc.Registry(); localReg != nil {
decoded, err := localReg.Decode(dec, tok)
if err == nil {
h.setNoLock(tok, decoded)
continue
}
}
}
decoded, err := registry.Decode(dec, tok)
if err != nil {
return err
if err == nil {
h.setNoLock(tok, decoded)
continue
}
h.setNoLock(tok, decoded)
return errors.Wrapf(err, `could not decode field %s`, tok)
}
default:
return errors.Errorf(`invalid token %T`, tok)
Expand Down Expand Up @@ -619,6 +642,7 @@ type ecdsaPublicKey struct {
y []byte
privateParams map[string]interface{}
mu *sync.RWMutex
dc DecodeCtx
}

func NewECDSAPublicKey() ECDSAPublicKey {
Expand Down Expand Up @@ -960,6 +984,18 @@ func (k *ecdsaPublicKey) Clone() (Key, error) {
return cloneKey(k)
}

func (k *ecdsaPublicKey) DecodeCtx() DecodeCtx {
k.mu.RLock()
defer k.mu.RUnlock()
return k.dc
}

func (k *ecdsaPublicKey) SetDecodeCtx(dc DecodeCtx) {
k.mu.Lock()
defer k.mu.Unlock()
k.dc = dc
}

func (h *ecdsaPublicKey) UnmarshalJSON(buf []byte) error {
h.algorithm = nil
h.crv = nil
Expand Down Expand Up @@ -1049,11 +1085,21 @@ LOOP:
return errors.Wrapf(err, `failed to decode value for key %s`, ECDSAYKey)
}
default:
if dc := h.dc; dc != nil {
if localReg := dc.Registry(); localReg != nil {
decoded, err := localReg.Decode(dec, tok)
if err == nil {
h.setNoLock(tok, decoded)
continue
}
}
}
decoded, err := registry.Decode(dec, tok)
if err != nil {
return err
if err == nil {
h.setNoLock(tok, decoded)
continue
}
h.setNoLock(tok, decoded)
return errors.Wrapf(err, `could not decode field %s`, tok)
}
default:
return errors.Errorf(`invalid token %T`, tok)
Expand Down
5 changes: 5 additions & 0 deletions jwk/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/lestrrat-go/iter/arrayiter"
"github.com/lestrrat-go/iter/mapiter"
"github.com/lestrrat-go/jwx/internal/iter"
"github.com/lestrrat-go/jwx/internal/json"
)

// KeyUsageType is used to denote what this key should be used for
Expand Down Expand Up @@ -82,6 +83,7 @@ type Set interface {
type set struct {
keys []Key
mu sync.RWMutex
dc DecodeCtx
}

type HeaderVisitor = iter.MapVisitor
Expand All @@ -103,3 +105,6 @@ type PublicKeyer interface {
type HTTPClient interface {
Do(*http.Request) (*http.Response, error)
}

type DecodeCtx = json.DecodeCtx
type KeyWithDecodeCtx = json.DecodeCtxContainer
32 changes: 29 additions & 3 deletions jwk/internal/cmd/genheader/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ func generateHeader(kt keyType) error {
}
fmt.Fprintf(&buf, "\nprivateParams map[string]interface{}")
fmt.Fprintf(&buf, "\nmu *sync.RWMutex")
fmt.Fprintf(&buf, "\ndc DecodeCtx")
fmt.Fprintf(&buf, "\n}")

fmt.Fprintf(&buf, "\n\nfunc New%[1]s() %[1]s {", ifName)
Expand Down Expand Up @@ -813,6 +814,18 @@ func generateHeader(kt keyType) error {
fmt.Fprintf(&buf, "\nreturn cloneKey(k)")
fmt.Fprintf(&buf, "\n}")

fmt.Fprintf(&buf, "\n\nfunc (k *%s) DecodeCtx() DecodeCtx {", structName)
fmt.Fprintf(&buf, "\nk.mu.RLock()")
fmt.Fprintf(&buf, "\ndefer k.mu.RUnlock()")
fmt.Fprintf(&buf, "\nreturn k.dc")
fmt.Fprintf(&buf, "\n}")

fmt.Fprintf(&buf, "\n\nfunc (k *%s) SetDecodeCtx(dc DecodeCtx) {", structName)
fmt.Fprintf(&buf, "\nk.mu.Lock()")
fmt.Fprintf(&buf, "\ndefer k.mu.Unlock()")
fmt.Fprintf(&buf, "\nk.dc = dc")
fmt.Fprintf(&buf, "\n}")

fmt.Fprintf(&buf, "\n\nfunc (h *%s) UnmarshalJSON(buf []byte) error {", structName)
for _, f := range ht.allHeaders {
fmt.Fprintf(&buf, "\nh.%s = nil", f.name)
Expand Down Expand Up @@ -876,11 +889,24 @@ func generateHeader(kt keyType) error {
}
}
fmt.Fprintf(&buf, "\ndefault:")
fmt.Fprintf(&buf, "\ndecoded, err := registry.Decode(dec, tok)")
fmt.Fprintf(&buf, "\nif err != nil {")
fmt.Fprintf(&buf, "\nreturn err")
// This looks like bad code, but we're unrolling things for maximum
// runtime efficiency
fmt.Fprintf(&buf, "\nif dc := h.dc; dc != nil {")
fmt.Fprintf(&buf, "\nif localReg := dc.Registry(); localReg != nil {")
fmt.Fprintf(&buf, "\ndecoded, err := localReg.Decode(dec, tok)")
fmt.Fprintf(&buf, "\nif err == nil {")
fmt.Fprintf(&buf, "\nh.setNoLock(tok, decoded)")
fmt.Fprintf(&buf, "\ncontinue")
fmt.Fprintf(&buf, "\n}")
fmt.Fprintf(&buf, "\n}")
fmt.Fprintf(&buf, "\n}")

fmt.Fprintf(&buf, "\ndecoded, err := registry.Decode(dec, tok)")
fmt.Fprintf(&buf, "\nif err == nil {")
fmt.Fprintf(&buf, "\nh.setNoLock(tok, decoded)")
fmt.Fprintf(&buf, "\ncontinue")
fmt.Fprintf(&buf, "\n}")
fmt.Fprintf(&buf, "\nreturn errors.Wrapf(err, `could not decode field %%s`, tok)")
fmt.Fprintf(&buf, "\n}")
fmt.Fprintf(&buf, "\ndefault:")
fmt.Fprintf(&buf, "\nreturn errors.Errorf(`invalid token %%T`, tok)")
Expand Down
40 changes: 40 additions & 0 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,23 @@ func parsePEMEncodedRawKey(src []byte) (interface{}, []byte, error) {
// parameters are performed, etc.
func ParseKey(data []byte, options ...ParseOption) (Key, error) {
var parsePEM bool
var localReg *json.Registry
for _, option := range options {
//nolint:forcetypeassert
switch option.Ident() {
case identPEM{}:
parsePEM = option.Value().(bool)
case identLocalRegistry{}:
// in reality you can only pass either withLocalRegistry or
// WithTypedField, but since withLocalRegistry is used only by us,
// we skip checking
localReg = option.Value().(*json.Registry)
case identTypedField{}:
pair := option.Value().(typedFieldPair)
if localReg == nil {
localReg = json.NewRegistry()
}
localReg.Register(pair.Name, pair.Value)
}
}

Expand Down Expand Up @@ -407,6 +419,16 @@ func ParseKey(data []byte, options ...ParseOption) (Key, error) {
return nil, errors.Errorf(`invalid key type from JSON (%s)`, hint.Kty)
}

if localReg != nil {
dcKey, ok := key.(KeyWithDecodeCtx)
if !ok {
return nil, errors.Errorf(`typed field was requested, but the key (%T) does not support DecodeCtx`, key)
}
dc := json.NewDecodeCtx(localReg)
dcKey.SetDecodeCtx(dc)
defer func() { dcKey.SetDecodeCtx(nil) }()
}

if err := json.Unmarshal(data, key); err != nil {
return nil, errors.Wrapf(err, `failed to unmarshal JSON into key (%T)`, key)
}
Expand All @@ -430,15 +452,23 @@ func ParseKey(data []byte, options ...ParseOption) (Key, error) {
// for `jwk.ParseKey()`.
func Parse(src []byte, options ...ParseOption) (Set, error) {
var parsePEM bool
var localReg *json.Registry
for _, option := range options {
//nolint:forcetypeassert
switch option.Ident() {
case identPEM{}:
parsePEM = option.Value().(bool)
case identTypedField{}:
pair := option.Value().(typedFieldPair)
if localReg == nil {
localReg = json.NewRegistry()
}
localReg.Register(pair.Name, pair.Value)
}
}

s := NewSet()

if parsePEM {
src = bytes.TrimSpace(src)
for len(src) > 0 {
Expand All @@ -456,6 +486,16 @@ func Parse(src []byte, options ...ParseOption) (Set, error) {
return s, nil
}

if localReg != nil {
dcKs, ok := s.(KeyWithDecodeCtx)
if !ok {
return nil, errors.Errorf(`typed field was requested, but the key set (%T) does not support DecodeCtx`, s)
}
dc := json.NewDecodeCtx(localReg)
dcKs.SetDecodeCtx(dc)
defer func() { dcKs.SetDecodeCtx(nil) }()
}

if err := json.Unmarshal(src, s); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal JWK set")
}
Expand Down
Loading

0 comments on commit 63a21f9

Please sign in to comment.