Skip to content

Commit ddeb4b7

Browse files
authored
Merge pull request #21 from vejed/feature-transit-sign-verify
Feature transit sign verify
2 parents 9813995 + 6bfbd29 commit ddeb4b7

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed

transit.go

+167
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ package vault
33
import (
44
"encoding/base64"
55
"errors"
6+
"fmt"
67
"net/http"
78
"net/url"
9+
"regexp"
10+
"strconv"
811

912
"github.com/hashicorp/vault/api"
1013
)
@@ -308,6 +311,170 @@ func (t *Transit) DecryptBatch(key string, opts TransitDecryptOptionsBatch) (*Tr
308311
return res, nil
309312
}
310313

314+
type TransitSignOptions struct {
315+
Input string `json:"input"`
316+
KeyVersion *int `json:"key_version,omitempty"`
317+
HashAlgorithm string `json:"hash_algorithm,omitempty"`
318+
Context string `json:"context,omitempty"`
319+
Prehashed bool `json:"prehashed,omitempty"`
320+
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
321+
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
322+
SaltLength string `json:"salt_length,omitempty"`
323+
}
324+
325+
type TransitSignResponse struct {
326+
Data struct {
327+
Signature string `json:"signature"`
328+
KeyVersion int `json:"key_version,omitempty"`
329+
} `json:"data"`
330+
}
331+
332+
func (t *Transit) Sign(key string, opts *TransitSignOptions) (*TransitSignResponse, error) {
333+
res := &TransitSignResponse{}
334+
335+
opts.Input = base64.StdEncoding.EncodeToString([]byte(opts.Input))
336+
337+
err := t.client.Write([]string{"v1", t.MountPoint, "sign", url.PathEscape(key)}, opts, res, nil)
338+
if err != nil {
339+
return nil, err
340+
}
341+
342+
return res, nil
343+
}
344+
345+
type TransitBatchSignInput struct {
346+
Input string `json:"input"`
347+
Context string `json:"context,omitempty"`
348+
}
349+
350+
type TransitBatchSignature struct {
351+
Signature string `json:"signature"`
352+
KeyVersion int `json:"key_version,omitempty"`
353+
}
354+
355+
type TransitSignOptionsBatch struct {
356+
BatchInput []TransitBatchSignInput `json:"batch_input"`
357+
KeyVersion *int `json:"key_version,omitempty"`
358+
HashAlgorithm string `json:"hash_algorithm,omitempty"`
359+
Prehashed bool `json:"prehashed,omitempty"`
360+
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
361+
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
362+
SaltLength string `json:"salt_length,omitempty"`
363+
}
364+
365+
type TransitSignResponseBatch struct {
366+
Data struct {
367+
BatchResults []TransitBatchSignature `json:"batch_results"`
368+
} `json:"data"`
369+
}
370+
371+
func (t *Transit) SignBatch(key string, opts *TransitSignOptionsBatch) (*TransitSignResponseBatch, error) {
372+
res := &TransitSignResponseBatch{}
373+
374+
for i := range opts.BatchInput {
375+
opts.BatchInput[i].Input = base64.StdEncoding.EncodeToString([]byte(opts.BatchInput[i].Input))
376+
}
377+
378+
err := t.client.Write([]string{"v1", t.MountPoint, "sign", url.PathEscape(key)}, opts, res, nil)
379+
if err != nil {
380+
return nil, err
381+
}
382+
383+
return res, nil
384+
}
385+
386+
type TransitVerifyOptions struct {
387+
Input string `json:"input"`
388+
Signature string `json:"signature"`
389+
HashAlgorithm string `json:"hash_algorithm,omitempty"`
390+
Context string `json:"context,omitempty"`
391+
Prehashed bool `json:"prehashed,omitempty"`
392+
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
393+
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
394+
SaltLength string `json:"salt_length,omitempty"`
395+
}
396+
397+
type TransitVerifyResponse struct {
398+
Data struct {
399+
Valid bool `json:"valid"`
400+
} `json:"data"`
401+
}
402+
403+
func (t *Transit) Verify(key string, opts *TransitVerifyOptions) (*TransitVerifyResponse, error) {
404+
res := &TransitVerifyResponse{}
405+
406+
opts.Input = base64.StdEncoding.EncodeToString([]byte(opts.Input))
407+
408+
err := t.client.Write([]string{"v1", t.MountPoint, "verify", url.PathEscape(key)}, opts, res, nil)
409+
if err != nil {
410+
return nil, err
411+
}
412+
413+
return res, nil
414+
}
415+
416+
type TransitBatchVerifyInput struct {
417+
Input string `json:"input"`
418+
Signature string `json:"signature"`
419+
Context string `json:"context,omitempty"`
420+
}
421+
422+
type TransitBatchVerifyData struct {
423+
Valid bool `json:"valid"`
424+
}
425+
426+
type TransitVerifyOptionsBatch struct {
427+
BatchInput []TransitBatchVerifyInput `json:"batch_input"`
428+
HashAlgorithm string `json:"hash_algorithm,omitempty"`
429+
Context string `json:"context,omitempty"`
430+
Prehashed bool `json:"prehashed,omitempty"`
431+
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
432+
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
433+
SaltLength string `json:"salt_length,omitempty"`
434+
}
435+
436+
type TransitVerifyResponseBatch struct {
437+
Data struct {
438+
BatchResults []TransitBatchVerifyData `json:"batch_results"`
439+
} `json:"data"`
440+
}
441+
442+
func (t *Transit) VerifyBatch(key string, opts *TransitVerifyOptionsBatch) (*TransitVerifyResponseBatch, error) {
443+
res := &TransitVerifyResponseBatch{}
444+
445+
for i := range opts.BatchInput {
446+
opts.BatchInput[i].Input = base64.StdEncoding.EncodeToString([]byte(opts.BatchInput[i].Input))
447+
}
448+
449+
err := t.client.Write([]string{"v1", t.MountPoint, "verify", url.PathEscape(key)}, opts, res, nil)
450+
if err != nil {
451+
return nil, err
452+
}
453+
454+
return res, nil
455+
}
456+
457+
// DecodeCipherText gets payload from vault ciphertext format (removes "vault:v<ver>:" prefix)
458+
func DecodeCipherText(vaultCipherText string) (string, int, error) {
459+
regex := regexp.MustCompile(`^vault:v(\d+):(.+)$`)
460+
matches := regex.FindStringSubmatch(vaultCipherText)
461+
if len(matches) != 3 {
462+
return "", 0, errors.New("invalid vault ciphertext format")
463+
}
464+
465+
keyVersion, err := strconv.Atoi(matches[1])
466+
if err != nil {
467+
return "", 0, errors.New("can't parse key version")
468+
}
469+
470+
return matches[2], keyVersion, nil
471+
}
472+
473+
// EncodeCipherText encodes payload to vault ciphertext format (adda "vault:v<ver>:" prefix)
474+
func EncodeCipherText(cipherText string, keyVersion int) string {
475+
return fmt.Sprintf("vault:v%d:%s", keyVersion, cipherText)
476+
}
477+
311478
func (t *Transit) mapError(err error) error {
312479
resErr := &api.ResponseError{}
313480
if errors.As(err, &resErr) {

transit_test.go

+64
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,67 @@ func (s *TransitTestSuite) TestCreateKeyThatDoesAlreadyExist() {
238238
err = s.client.Create("testCeateKeyThatDoesAlreadyExist", &TransitCreateOptions{})
239239
require.NoError(s.T(), err)
240240
}
241+
242+
func (s *TransitTestSuite) TestSignVerify() {
243+
err := s.client.Create("testSignVerify", &TransitCreateOptions{Type: "rsa-2048"})
244+
require.NoError(s.T(), err)
245+
246+
text := "test"
247+
248+
signRes, err := s.client.Sign("testSignVerify", &TransitSignOptions{
249+
Input: text,
250+
})
251+
require.NoError(s.T(), err)
252+
253+
verifyRes, err := s.client.Verify("testSignVerify", &TransitVerifyOptions{
254+
Input: text,
255+
Signature: signRes.Data.Signature,
256+
})
257+
require.NoError(s.T(), err)
258+
259+
s.True(verifyRes.Data.Valid)
260+
}
261+
262+
func (s *TransitTestSuite) TestSignVerifyBatch() {
263+
err := s.client.Create("testSignVerify", &TransitCreateOptions{Type: "rsa-2048"})
264+
require.NoError(s.T(), err)
265+
266+
text1 := "test1"
267+
text2 := "test2"
268+
269+
signRes, err := s.client.SignBatch("testSignVerify", &TransitSignOptionsBatch{
270+
BatchInput: []TransitBatchSignInput{
271+
{Input: text1},
272+
{Input: text2},
273+
},
274+
})
275+
require.NoError(s.T(), err)
276+
277+
verifyRes, err := s.client.VerifyBatch("testSignVerify", &TransitVerifyOptionsBatch{
278+
BatchInput: []TransitBatchVerifyInput{
279+
{Input: text1, Signature: signRes.Data.BatchResults[0].Signature},
280+
{Input: text2, Signature: signRes.Data.BatchResults[1].Signature},
281+
},
282+
})
283+
require.NoError(s.T(), err)
284+
285+
s.True(verifyRes.Data.BatchResults[0].Valid)
286+
s.True(verifyRes.Data.BatchResults[1].Valid)
287+
}
288+
289+
func (s *TransitTestSuite) TestDecodeCipherText() {
290+
dec, ver, err := DecodeCipherText("vault:v123:SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
291+
require.NoError(s.T(), err)
292+
s.Equal("SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", dec)
293+
s.Equal(123, ver)
294+
}
295+
296+
func (s *TransitTestSuite) TestDecodeCipherTextError() {
297+
_, _, err := DecodeCipherText("vault:SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
298+
s.NotNil(err)
299+
}
300+
301+
func (s *TransitTestSuite) TestEncodeCipherText() {
302+
enc := EncodeCipherText("SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", 123)
303+
s.Equal("vault:v123:SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", enc)
304+
}

0 commit comments

Comments
 (0)