Skip to content

Commit bd7d515

Browse files
committed
add EnrichToken to the Keys structure
1 parent 7bb8739 commit bd7d515

File tree

4 files changed

+143
-37
lines changed

4 files changed

+143
-37
lines changed

enrich.go

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -95,35 +95,8 @@ import (
9595
func Enrich(key PrivateKey, accessToken []byte, extraClaims any) ([]byte, error) {
9696
decodedToken, err := Decode(accessToken)
9797
if err != nil {
98-
return nil, fmt.Errorf("failed to parse original token header: %w", err)
98+
return nil, fmt.Errorf("failed to parse original token: %w", err)
9999
}
100100

101-
alg, err := decodedToken.Alg()
102-
if err != nil {
103-
return nil, fmt.Errorf("failed to determine algorithm from original token: %w", err)
104-
}
105-
106-
// Merge the original claims with extra claims.
107-
// No extra validation is needed since we assume the original token is valid.
108-
payload, err := Merge(decodedToken.Payload, extraClaims)
109-
if err != nil {
110-
return nil, fmt.Errorf("failed to merge claims: %w", err)
111-
}
112-
payload = Base64Encode(payload)
113-
114-
// Use the existing header from the original token.
115-
// This ensures the new token has the same header structure.
116-
existingHeader := Base64Encode(decodedToken.Header)
117-
headerPayload := joinParts(existingHeader, payload)
118-
119-
// The signature should be created using the same algorithm and key.
120-
// This ensures the new token is properly signed and can be verified with the new claims.
121-
signature, err := createSignature(alg, key, headerPayload)
122-
if err != nil {
123-
return nil, fmt.Errorf("encodeToken: signature: %w", err)
124-
}
125-
126-
// header.payload.signature
127-
token := joinParts(headerPayload, signature)
128-
return token, nil
101+
return decodedToken.Enrich(key, extraClaims)
129102
}

enrich_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ func TestEnrichErrors(t *testing.T) {
199199
accessToken: []byte("aW52YWxpZC1oZWFkZXI.eyJzdWIiOiJ1c2VyMTIzIn0.signature"),
200200
extraClaims: map[string]any{"role": "admin"},
201201
expectError: true,
202-
errorContains: "failed to parse original token header",
202+
errorContains: "decode token: signature",
203203
},
204204
{
205205
name: "invalid extra claims",

kid_keys.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,45 @@ func (keys Keys) SignToken(kid string, claims any, opts ...SignOption) ([]byte,
768768
}, opts...)
769769
}
770770

771+
// EnrichToken creates a new JWT token by merging the original token's claims with additional claims.
772+
//
773+
// This method allows you to extend the original token's payload with new claims
774+
// while preserving the original token's header and signature structure.
775+
//
776+
// It uses the same algorithm and key as the original token to ensure the new token is valid.
777+
//
778+
// Parameters:
779+
// - key: PrivateKey used to sign the new token
780+
// - extraClaims: Map of additional claims to merge with the original token's payload
781+
//
782+
// Returns:
783+
// - []byte: New JWT token with merged claims
784+
// - error: Error if the original token's algorithm cannot be determined or if merging fails.
785+
//
786+
// Note: this only enrich plain tokens and not encrypted tokens.
787+
func (keys Keys) EnrichToken(plainToken []byte, extraClaims any) ([]byte, error) {
788+
decodedToken, err := Decode(plainToken)
789+
if err != nil {
790+
return nil, fmt.Errorf("failed to parse original token: %w", err)
791+
}
792+
793+
kid, err := decodedToken.Kid()
794+
if err != nil {
795+
return nil, fmt.Errorf("failed to get kid from token: %w", err)
796+
}
797+
798+
k, ok := keys.Get(kid)
799+
if !ok {
800+
return nil, ErrUnknownKid
801+
}
802+
803+
if k.Encrypt != nil || k.Decrypt != nil {
804+
return nil, fmt.Errorf("jwt: cannot enrich encrypted tokens")
805+
}
806+
807+
return decodedToken.Enrich(k.Private, extraClaims)
808+
}
809+
771810
// VerifyToken verifies a JWT token using automatic key selection and extracts claims.
772811
//
773812
// This method provides a high-level interface for JWT verification with multi-key support.

token.go

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -539,17 +539,17 @@ func Decode(token []byte) (*UnverifiedToken, error) {
539539

540540
headerDecoded, err := Base64Decode(header)
541541
if err != nil {
542-
return nil, err
542+
return nil, fmt.Errorf("decode token: header: %w", err)
543543
}
544544

545545
signatureDecoded, err := Base64Decode(signature)
546546
if err != nil {
547-
return nil, err
547+
return nil, fmt.Errorf("decode token: signature: %w", err)
548548
}
549549

550550
payload, err = Base64Decode(payload)
551551
if err != nil {
552-
return nil, err
552+
return nil, fmt.Errorf("decode token: payload: %w", err)
553553
}
554554

555555
tok := &UnverifiedToken{
@@ -607,6 +607,10 @@ type UnverifiedToken struct {
607607
Header []byte // Decoded JWT header JSON bytes
608608
Payload []byte // Decoded JWT payload JSON bytes
609609
Signature []byte // Decoded signature bytes (unverified)
610+
611+
// cached.
612+
alg Alg
613+
kid string
610614
}
611615

612616
// Claims unmarshals the JWT payload into the provided destination structure.
@@ -673,16 +677,106 @@ type headerWithAlg struct {
673677
// If the header is malformed or the algorithm is unknown, it returns an error.
674678
// This method is useful for determining the algorithm used to sign the token.
675679
func (t *UnverifiedToken) Alg() (Alg, error) {
680+
if t.alg != nil {
681+
return t.alg, nil
682+
}
683+
676684
// Extract algorithm from the original token header.
677-
var headerAlg headerWithAlg
678-
if err := json.Unmarshal(t.Header, &headerAlg); err != nil {
685+
var header headerWithAlg
686+
if err := json.Unmarshal(t.Header, &header); err != nil {
679687
return nil, fmt.Errorf("failed to parse original token header: %w", err)
680688
}
681689

682-
alg := parseAlg(headerAlg.Alg)
690+
alg := parseAlg(header.Alg)
683691
if alg == nil {
684-
return nil, fmt.Errorf("%w: %s", ErrTokenAlg, headerAlg.Alg)
692+
return nil, fmt.Errorf("%w: %s", ErrTokenAlg, header.Alg)
685693
}
686694

695+
t.alg = alg
687696
return alg, nil
688697
}
698+
699+
// Kid returns the Key ID (kid) from the original token header.
700+
// It extracts the "kid" field from the JWT header and returns its value.
701+
// If the header is malformed or the "kid" field is missing, it returns an error.
702+
func (t *UnverifiedToken) Kid() (string, error) {
703+
if t.kid != "" {
704+
return t.kid, nil
705+
}
706+
707+
var header HeaderWithKid
708+
if err := json.Unmarshal(t.Header, &header); err != nil {
709+
return "", fmt.Errorf("failed to parse original token header: %w", err)
710+
}
711+
712+
alg := parseAlg(header.Alg)
713+
if alg == nil {
714+
return "", fmt.Errorf("%w: %s", ErrTokenAlg, header.Alg)
715+
}
716+
717+
// KID catches both "kid" and "kid" fields.
718+
t.alg = alg
719+
t.kid = header.Kid
720+
return t.kid, nil
721+
}
722+
723+
// Enrich creates a new JWT token by merging the original token's claims with additional claims.
724+
//
725+
// This method allows you to extend the original token's payload with new claims
726+
// while preserving the original token's header and signature structure.
727+
//
728+
// It uses the same algorithm and key as the original token to ensure the new token is valid.
729+
//
730+
// Parameters:
731+
// - key: PrivateKey used to sign the new token
732+
// - extraClaims: Map of additional claims to merge with the original token's payload
733+
//
734+
// Returns:
735+
// - []byte: New JWT token with merged claims
736+
// - error: Error if the original token's algorithm cannot be determined or if merging fails
737+
//
738+
// Example usage:
739+
//
740+
// originalToken, _ := jwt.Decode(tokenBytes)
741+
// extraClaims := map[string]any{
742+
// "role": "admin",
743+
// "permissions": []string{"read", "write"},
744+
// }
745+
// newToken, err := originalToken.Enrich(signingKey, extraClaims)
746+
//
747+
// // if err != nil {
748+
// log.Fatalf("Failed to enrich token: %v", err)
749+
// }
750+
//
751+
// This newToken will have the original header and signature, but with the additional claims merged in.
752+
// Note: This method does not verify the original token; it assumes the original token is valid.
753+
func (t *UnverifiedToken) Enrich(key PrivateKey, extraClaims any) ([]byte, error) {
754+
alg, err := t.Alg()
755+
if err != nil {
756+
return nil, fmt.Errorf("failed to determine algorithm from original token: %w", err)
757+
}
758+
759+
// Merge the original claims with extra claims.
760+
// No extra validation is needed since we assume the original token is valid.
761+
payload, err := Merge(t.Payload, extraClaims)
762+
if err != nil {
763+
return nil, fmt.Errorf("failed to merge claims: %w", err)
764+
}
765+
payload = Base64Encode(payload)
766+
767+
// Use the existing header from the original token.
768+
// This ensures the new token has the same header structure.
769+
existingHeader := Base64Encode(t.Header)
770+
headerPayload := joinParts(existingHeader, payload)
771+
772+
// The signature should be created using the same algorithm and key.
773+
// This ensures the new token is properly signed and can be verified with the new claims.
774+
signature, err := createSignature(alg, key, headerPayload)
775+
if err != nil {
776+
return nil, fmt.Errorf("encodeToken: signature: %w", err)
777+
}
778+
779+
// header.payload.signature
780+
token := joinParts(headerPayload, signature)
781+
return token, nil
782+
}

0 commit comments

Comments
 (0)