diff --git a/cwt_test.go b/cwt_test.go index 3ee506f..2af4509 100644 --- a/cwt_test.go +++ b/cwt_test.go @@ -21,7 +21,11 @@ func ExampleCWTClaims() { cose.CWTClaimIssuer: "issuer.example", cose.CWTClaimSubject: "subject.example", } - msgToSign.Headers.Protected.SetCWTClaims(claims) + + claims, err := msgToSign.Headers.Protected.SetCWTClaims(claims) + if err != nil { + panic(err) + } msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1") diff --git a/headers.go b/headers.go index c4e980f..03c5a7e 100644 --- a/headers.go +++ b/headers.go @@ -115,15 +115,53 @@ func (h ProtectedHeader) SetType(typ any) (any, error) { // SetCWTClaims sets the CWT Claims value of the protected header. func (h ProtectedHeader) SetCWTClaims(claims CWTClaims) (CWTClaims, error) { - iss, hasIss := claims[1] - if hasIss && !canTstr(iss) { - return claims, errors.New("cwt claim: iss: require tstr") - } - sub, hasSub := claims[2] - if hasSub && !canTstr(sub) { - return claims, errors.New("cwt claim: sub: require tstr") + for name, value := range claims { + switch name { + case CWTClaimIssuer: + if !canTstr(value) { + return claims, errors.New("cwt claim: iss: require tstr") + } + case CWTClaimSubject: + if !canTstr(value) { + return claims, errors.New("cwt claim: sub: require tstr") + } + case 3: + aud, hasAud := claims[name] + if hasAud && !canTstr(aud) { + return claims, errors.New("cwt claim: aud: require tstr") + } + case 4: + exp, hasExp := claims[name] + if hasExp && !canInt(exp) && !canFloat(exp) { + return claims, errors.New("cwt claim: exp: require int or float") + } + case 5: + nbf, hasNbf := claims[name] + if hasNbf && !canInt(nbf) && !canFloat(nbf) { + return claims, errors.New("cwt claim: nbf: require int or float") + } + case 6: + iat, hasIat := claims[name] + if hasIat && !canInt(iat) && !canFloat(iat) { + return claims, errors.New("cwt claim: iat: require int or float") + } + case 7: + cti, hasCti := claims[name] + if hasCti && !canBstr(cti) { + return claims, errors.New("cwt claim: cti: require tstr") + } + case 8: + cnf, hasCnf := claims[name] + if hasCnf && !canMap(cnf) { + return claims, errors.New("cwt claim: cnf: require map") + } + case 9: + scope, hasScope := claims[name] + if hasScope && !canBstr(scope) && !canTstr(scope) { + return claims, errors.New("cwt claim: scope: require bstr or tstr") + } + } } - // TODO: validate claims, other claims h[HeaderLabelCWTClaims] = claims return claims, nil } @@ -620,6 +658,15 @@ func canInt(v any) bool { return false } +// canFloat reports whether v can be used as a CBOR float type +func canFloat(v any) bool { + switch v.(type) { + case float32, float64: + return true + } + return false +} + // canTstr reports whether v can be used as a CBOR tstr type. func canTstr(v any) bool { _, ok := v.(string) @@ -632,6 +679,12 @@ func canBstr(v any) bool { return ok } +// canMap reports whether v can be used as a CBOR map type. +func canMap(v any) bool { + _, ok := v.(map[any]any) + return ok +} + // normalizeLabel tries to cast label into a int64 or a string. // Returns (nil, false) if the label type is not valid. func normalizeLabel(label any) (any, bool) {