From 204faa871b1ed52751b51d7f0c77b00edd940c8e Mon Sep 17 00:00:00 2001 From: stuart-warren Date: Wed, 26 Jun 2024 13:02:27 +0100 Subject: [PATCH] feat: add initial support for fleet provisioning fixes #171 --- examples/fleetprovisioning/.gitignore | 6 + examples/fleetprovisioning/main.go | 163 +++++++++ fleetprovisioning/fleetprovisioning.go | 358 ++++++++++++++++++++ fleetprovisioning/fleetprovisioning_test.go | 211 ++++++++++++ fleetprovisioning/state.go | 59 ++++ fleetprovisioning/state_test.go | 34 ++ 6 files changed, 831 insertions(+) create mode 100644 examples/fleetprovisioning/.gitignore create mode 100644 examples/fleetprovisioning/main.go create mode 100644 fleetprovisioning/fleetprovisioning.go create mode 100644 fleetprovisioning/fleetprovisioning_test.go create mode 100644 fleetprovisioning/state.go create mode 100644 fleetprovisioning/state_test.go diff --git a/examples/fleetprovisioning/.gitignore b/examples/fleetprovisioning/.gitignore new file mode 100644 index 00000000..2f19d61c --- /dev/null +++ b/examples/fleetprovisioning/.gitignore @@ -0,0 +1,6 @@ +fleetprovisioning +root-CA.crt +certificate.pem.crt +certificate-run.pem.crt +private.pem.key +private-run.pem.key diff --git a/examples/fleetprovisioning/main.go b/examples/fleetprovisioning/main.go new file mode 100644 index 00000000..54c71180 --- /dev/null +++ b/examples/fleetprovisioning/main.go @@ -0,0 +1,163 @@ +// Copyright 2020 SEQSENSE, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/at-wat/mqtt-go" + "github.com/seqsense/aws-iot-device-sdk-go/v6" + "github.com/seqsense/aws-iot-device-sdk-go/v6/fleetprovisioning" +) + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if len(os.Args) != 4 { + println("usage: fleetprovisioning AWS_IOT_ENDPOINT CLIENT_ID TEMPLATE_NAME") + println("") + println("This example registers a new thing in AWS IoT.") + println("TEMPLATE_NAME must be created to your account of AWS IoT beforehand.") + println("") + println("Following files must be placed under the current working directory:") + println(" root-CA.crt: root CA certificate") + println(" certificate.pem.crt: client certificate with assigned policy to provision devices") + println(" - see https://docs.aws.amazon.com/iot/latest/developerguide/provision-wo-cert.html") + println(" private.pem.key: private key associated to above certificate") + os.Exit(1) + } + host := os.Args[1] + thingName := os.Args[2] + templateName := os.Args[3] + + for _, file := range []string{ + "root-CA.crt", + "certificate.pem.crt", + "private.pem.key", + } { + _, err := os.Stat(file) + if os.IsNotExist(err) { + println(file, "not found") + os.Exit(1) + } + if err != nil { + println(file, err) + } + } + + cli, err := awsiotdev.New( + thingName, + &mqtt.URLDialer{ + URL: fmt.Sprintf("mqtts://%s:8883", host), + Options: []mqtt.DialOption{ + mqtt.WithTLSCertFiles( + host, + "root-CA.crt", + "certificate.pem.crt", + "private.pem.key", + ), + mqtt.WithConnStateHandler(func(s mqtt.ConnState, err error) { + fmt.Printf("%s: %v\n", s, err) + }), + }, + }, + mqtt.WithReconnectWait(500*time.Millisecond, 2*time.Second), + ) + if err != nil { + panic(err) + } + + // Multiplex message handler to route messages to multiple features. + var mux mqtt.ServeMux + cli.Handle(&mux) + + f, err := fleetprovisioning.New(ctx, cli, templateName) + if err != nil { + panic(err) + } + f.OnError(func(err error) { + fmt.Printf("async error: %v\n", err) + }) + mux.Handle("#", f) // Handle messages. + + if _, err := cli.Connect(ctx, + thingName, + mqtt.WithKeepAlive(30), + ); err != nil { + panic(err) + } + println("Connected to", host) + + _, _, _, certToken, err := f.CreateKeysAndCertificate(ctx) + if err != nil { + panic(err) + } + println("Created keys and certificate") + + cert, err := os.OpenFile("certificate-run.pem.crt", os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + panic(err) + } + key, err := os.OpenFile("private-run.pem.key", os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + panic(err) + } + f.WriteCertificate(cert) + f.WritePrivateKey(key) + cert.Close() + key.Close() + thingName, _, err = f.RegisterThing(ctx, certToken, map[string]string{"SerialNumber": "124"}) + if err != nil { + panic(err) + } + println("Registered thing:", thingName) + + cli.Disconnect(ctx) + + cli, err = awsiotdev.New( + thingName, + &mqtt.URLDialer{ + URL: fmt.Sprintf("mqtts://%s:8883", host), + Options: []mqtt.DialOption{ + mqtt.WithTLSCertFiles( + host, + "root-CA.crt", + "certificate-run.pem.crt", + "private-run.pem.key", + ), + mqtt.WithConnStateHandler(func(s mqtt.ConnState, err error) { + fmt.Printf("%s: %v\n", s, err) + }), + }, + }, + mqtt.WithReconnectWait(500*time.Millisecond, 2*time.Second), + ) + + cli.Handle(&mux) + + if _, err := cli.Connect(ctx, + thingName, + mqtt.WithKeepAlive(30), + ); err != nil { + panic(err) + } + println("Connected to", host) + + select {} +} diff --git a/fleetprovisioning/fleetprovisioning.go b/fleetprovisioning/fleetprovisioning.go new file mode 100644 index 00000000..3e548b97 --- /dev/null +++ b/fleetprovisioning/fleetprovisioning.go @@ -0,0 +1,358 @@ +// Copyright 2020 SEQSENSE, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fleetprovisioning + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "sync" + + "github.com/at-wat/mqtt-go" + awsiotdev "github.com/seqsense/aws-iot-device-sdk-go/v6" + "github.com/seqsense/aws-iot-device-sdk-go/v6/internal/ioterr" +) + +type FleetProvisioning interface { + mqtt.Handler + // CreateKeysAndCertificate creates keys and certificate. + CreateKeysAndCertificate(ctx context.Context) (string, *x509.Certificate, *rsa.PrivateKey, string, error) + // CreateCertificateFromCSR creates certificate from CSR. + CreateCertificateFromCSR(ctx context.Context, csr *x509.CertificateRequest) (string, *x509.Certificate, string, error) + // RegisterThing registers a thing. + RegisterThing(ctx context.Context, certToken string, parameters map[string]string) (string, map[string]string, error) + // WriteCertificate writes PEM certificate to writer. + WriteCertificate(writer io.Writer) (int, error) + // WritePrivateKey writes PEM private key to writer. + WritePrivateKey(writer io.Writer) (int, error) + // OnError sets handler of asyncronous errors. + OnError(func(err error)) +} + +type fleetProvisioning struct { + mqtt.ServeMux + cli mqtt.Client + templateName string + pemCert []byte + pemKey []byte + + mu sync.Mutex + onError func(err error) + + chResps map[string]chan interface{} +} + +const ( + // Tokens for response handling. + tokenCreateCertificate = "certificates/create" + tokenCreateFromCSR = "certificates/create-from-csr" + tokenRegisterThing = "provisioning-templates/provision" +) + +func New(ctx context.Context, cli awsiotdev.Device, templateName string) (FleetProvisioning, error) { + f := &fleetProvisioning{ + cli: cli, + chResps: make(map[string]chan interface{}), + templateName: templateName, + } + for _, sub := range []struct { + topic string + hand mqtt.Handler + }{ + {"$aws/certificates/create/json/accepted", mqtt.HandlerFunc(f.certificateCreateAccepted)}, + {"$aws/certificates/create/json/rejected", mqtt.HandlerFunc(f.rejectedCertificateCreate)}, + {"$aws/certificates/create-from-csr/json/accepted", mqtt.HandlerFunc(f.certificateCreateFromCSRAccepted)}, + {"$aws/certificates/create-from-csr/json/rejected", mqtt.HandlerFunc(f.rejectedCertificatesCreateFromCSR)}, + {"$aws/provisioning-templates/" + templateName + "/provision/json/accepted", mqtt.HandlerFunc(f.registerThingAccepted)}, + {"$aws/provisioning-templates/" + templateName + "/provision/json/rejected", mqtt.HandlerFunc(f.rejectedProvisioningTemplatesProvision)}, + } { + if err := f.ServeMux.Handle(sub.topic, sub.hand); err != nil { + return nil, ioterr.New(err, "registering message handlers") + } + } + + _, err := cli.Subscribe(ctx, + mqtt.Subscription{Topic: "$aws/certificates/create/json/accepted", QoS: mqtt.QoS1}, + mqtt.Subscription{Topic: "$aws/certificates/create/json/rejected", QoS: mqtt.QoS1}, + mqtt.Subscription{Topic: "$aws/certificates/create-from-csr/json/accepted", QoS: mqtt.QoS1}, + mqtt.Subscription{Topic: "$aws/certificates/create-from-csr/json/rejected", QoS: mqtt.QoS1}, + mqtt.Subscription{Topic: "$aws/provisioning-templates/" + templateName + "/provision/json/accepted", QoS: mqtt.QoS1}, + mqtt.Subscription{Topic: "$aws/provisioning-templates/" + templateName + "/provision/json/rejected", QoS: mqtt.QoS1}, + ) + if err != nil { + return nil, ioterr.New(err, "subscribing to topics") + } + return f, nil +} + +func parseCertificate(certPem []byte) (*x509.Certificate, error) { + block, _ := pem.Decode(certPem) + if block == nil || block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("failed to decode PEM block") + } + return x509.ParseCertificate(block.Bytes) +} + +func parsePrivateKey(keyPem []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(keyPem) + if block == nil || (block.Type != "PRIVATE KEY" && block.Type != "RSA PRIVATE KEY") { + return nil, fmt.Errorf("failed to decode PEM block") + } + key, _ := x509.ParsePKCS8PrivateKey(block.Bytes) + pkcs8, ok := key.(*rsa.PrivateKey) + if ok { + return pkcs8, nil + } + return x509.ParsePKCS1PrivateKey(block.Bytes) +} + +// CreateKeysAndCertificate creates keys and certificate. +func (f *fleetProvisioning) CreateKeysAndCertificate(ctx context.Context) (certID string, cert *x509.Certificate, privateKey *rsa.PrivateKey, certToken string, err error) { + errReturn := func(err error) (string, *x509.Certificate, *rsa.PrivateKey, string, error) { + return "", &x509.Certificate{}, &rsa.PrivateKey{}, "", err + } + token := tokenCreateCertificate + ch := make(chan interface{}, 1) + f.mu.Lock() + f.chResps[token] = ch + f.mu.Unlock() + defer func() { + f.mu.Lock() + delete(f.chResps, token) + f.mu.Unlock() + }() + if err := f.cli.Publish(ctx, &mqtt.Message{ + Topic: "$aws/certificates/create/json", + QoS: mqtt.QoS1, + Payload: json.RawMessage(fmt.Sprintf(`{}`)), + }); err != nil { + return errReturn(ioterr.New(err, "publishing create certificate request")) + } + + select { + case <-ctx.Done(): + return errReturn(ioterr.New(ctx.Err(), "waiting for response")) + case resp := <-ch: + switch r := resp.(type) { + case *CertificateCreateResponse: + cert, err := parseCertificate([]byte(r.CertificatePEM)) + if err != nil { + return errReturn(ioterr.New(err, "parsing certificate")) + } + key, err := parsePrivateKey([]byte(r.PrivateKey)) + if err != nil { + return errReturn(ioterr.New(err, "parsing private key")) + } + f.mu.Lock() + f.pemCert = []byte(r.CertificatePEM) + f.pemKey = []byte(r.PrivateKey) + f.mu.Unlock() + return r.CertificateID, cert, key, r.CertificateOwnershipToken, nil + case *ErrorResponse: + return errReturn(ioterr.New(fmt.Errorf("%d: %s", r.Code, r.Message), "creating certificate")) + default: + return errReturn(ioterr.New(fmt.Errorf("unexpected response: %T", r), "parsing response")) + } + } +} + +func (f *fleetProvisioning) CreateCertificateFromCSR(ctx context.Context, csr *x509.CertificateRequest) (certID string, cert *x509.Certificate, certToken string, err error) { + errReturn := func(err error) (string, *x509.Certificate, string, error) { + return "", &x509.Certificate{}, "", err + } + token := tokenCreateFromCSR + ch := make(chan interface{}, 1) + f.mu.Lock() + f.chResps[token] = ch + f.mu.Unlock() + defer func() { + f.mu.Lock() + delete(f.chResps, token) + f.mu.Unlock() + }() + pemCSR := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr.Raw}) + req := &CertificateCreateFromCSRRequest{ + CertificateSigningRequest: string(pemCSR), + } + jsonReq, err := json.Marshal(req) + if err != nil { + return errReturn(ioterr.New(err, "marshaling create certificate from CSR request")) + } + if err := f.cli.Publish(ctx, &mqtt.Message{ + Topic: "$aws/certificates/create-from-csr/json", + QoS: mqtt.QoS1, + Payload: jsonReq, + }); err != nil { + return errReturn(ioterr.New(err, "publishing create certificate request")) + } + + select { + case <-ctx.Done(): + return errReturn(ioterr.New(ctx.Err(), "waiting for response")) + case resp := <-ch: + switch r := resp.(type) { + case *CertificateCreateFromCSRResponse: + cert, err := parseCertificate([]byte(r.CertificatePEM)) + if err != nil { + return errReturn(ioterr.New(err, "parsing certificate")) + } + f.mu.Lock() + f.pemCert = []byte(r.CertificatePEM) + f.mu.Unlock() + return r.CertificateID, cert, r.CertificateOwnershipToken, nil + case *ErrorResponse: + return errReturn(ioterr.New(fmt.Errorf("%d: %s", r.Code, r.Message), "creating certificate")) + default: + return errReturn(ioterr.New(fmt.Errorf("unexpected response: %T", r), "parsing response")) + } + } +} + +func (f *fleetProvisioning) RegisterThing(ctx context.Context, certToken string, parameters map[string]string) (thingName string, deviceConfig map[string]string, err error) { + errReturn := func(err error) (string, map[string]string, error) { + return "", map[string]string{}, err + } + token := tokenRegisterThing + ch := make(chan interface{}, 1) + f.mu.Lock() + f.chResps[token] = ch + f.mu.Unlock() + defer func() { + f.mu.Lock() + delete(f.chResps, token) + f.mu.Unlock() + }() + jsonReq, err := json.Marshal(&RegisterThingRequest{ + CertificateOwnershipToken: certToken, + Parameters: parameters, + }) + if err != nil { + return errReturn(ioterr.New(err, "marshaling register thing request")) + } + if err := f.cli.Publish(ctx, &mqtt.Message{ + Topic: "$aws/provisioning-templates/" + f.templateName + "/provision/json", + QoS: mqtt.QoS1, + Payload: jsonReq, + }); err != nil { + return errReturn(ioterr.New(err, "publishing register thing request")) + } + + select { + case <-ctx.Done(): + return errReturn(ioterr.New(ctx.Err(), "waiting for response")) + case resp := <-ch: + switch r := resp.(type) { + case *RegisterThingResponse: + return r.ThingName, r.DeviceConfiguration, nil + case *ErrorResponse: + return errReturn(ioterr.New(fmt.Errorf("%d: %s", r.Code, r.Message), "register thing failed")) + default: + return errReturn(ioterr.New(fmt.Errorf("unexpected response: %T", r), "parsing response")) + } + } +} + +func (p *fleetProvisioning) WriteCertificate(writer io.Writer) (int, error) { + if p.pemCert == nil { + return 0, fmt.Errorf("certificate not created") + } + return writer.Write(p.pemCert) +} + +func (p *fleetProvisioning) WritePrivateKey(writer io.Writer) (int, error) { + if p.pemKey == nil { + return 0, fmt.Errorf("private key not created") + } + return writer.Write(p.pemKey) +} + +func (f *fleetProvisioning) certificateCreateAccepted(msg *mqtt.Message) { + r := &CertificateCreateResponse{} + if err := json.Unmarshal(msg.Payload, r); err != nil { + fmt.Printf(string(msg.Payload)) + f.handleError(ioterr.New(err, "unmarshaling certificate create response")) + return + } + f.handleResponse(r, tokenCreateCertificate) +} + +func (f *fleetProvisioning) certificateCreateFromCSRAccepted(msg *mqtt.Message) { + r := &CertificateCreateFromCSRResponse{} + if err := json.Unmarshal(msg.Payload, r); err != nil { + f.handleError(ioterr.New(err, "unmarshaling certificate create from CSR response")) + return + } + f.handleResponse(r, tokenCreateFromCSR) +} + +func (f *fleetProvisioning) registerThingAccepted(msg *mqtt.Message) { + r := &RegisterThingResponse{} + if err := json.Unmarshal(msg.Payload, r); err != nil { + f.handleError(ioterr.New(err, "unmarshaling register thing response")) + return + } + f.handleResponse(r, tokenRegisterThing) +} + +func (f *fleetProvisioning) rejectedCertificateCreate(msg *mqtt.Message) { + f.rejected(msg, tokenCreateCertificate) +} +func (f *fleetProvisioning) rejectedCertificatesCreateFromCSR(msg *mqtt.Message) { + f.rejected(msg, tokenCreateFromCSR) +} +func (f *fleetProvisioning) rejectedProvisioningTemplatesProvision(msg *mqtt.Message) { + f.rejected(msg, tokenRegisterThing) +} + +func (f *fleetProvisioning) rejected(msg *mqtt.Message, token string) { + e := &ErrorResponse{} + if err := json.Unmarshal(msg.Payload, e); err != nil { + f.handleError(ioterr.New(err, "unmarshaling error response")) + return + } + f.handleResponse(e, token) +} + +func (f *fleetProvisioning) handleResponse(r interface{}, token string) { + f.mu.Lock() + ch, ok := f.chResps[token] + f.mu.Unlock() + if !ok { + return + } + select { + case ch <- r: + default: + } +} + +func (f *fleetProvisioning) OnError(cb func(err error)) { + f.mu.Lock() + f.onError = cb + f.mu.Unlock() +} + +func (f *fleetProvisioning) handleError(err error) { + f.mu.Lock() + cb := f.onError + f.mu.Unlock() + if cb != nil { + cb(err) + } +} diff --git a/fleetprovisioning/fleetprovisioning_test.go b/fleetprovisioning/fleetprovisioning_test.go new file mode 100644 index 00000000..1ac52d26 --- /dev/null +++ b/fleetprovisioning/fleetprovisioning_test.go @@ -0,0 +1,211 @@ +// Copyright 2020 SEQSENSE, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fleetprovisioning + +import ( + "bytes" + "context" + "crypto/x509" + "errors" + "strings" + "testing" + "time" + + "github.com/at-wat/mqtt-go" + mockmqtt "github.com/at-wat/mqtt-go/mock" + + "github.com/seqsense/aws-iot-device-sdk-go/v6/internal/ioterr" +) + +type mockClient interface { + mqtt.Client + mqtt.Handler +} + +type mockDevice struct { + mockClient + mqtt.Retryer +} + +func (d *mockDevice) ThingName() string { + return "test" +} + +func TestNew(t *testing.T) { + errDummy := errors.New("dummy error") + + t.Run("SubscribeError", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + cli := &mockDevice{mockClient: &mockmqtt.Client{ + SubscribeFn: func(ctx context.Context, subs ...mqtt.Subscription) ([]mqtt.Subscription, error) { + return nil, errDummy + }, + }} + _, err := New(ctx, cli, "template") + var ie *ioterr.Error + if !errors.As(err, &ie) { + t.Errorf("Expected error type: %T, got: %T", ie, err) + } + if !errors.Is(err, errDummy) { + t.Errorf("Expected error: %v, got: %v", errDummy, err) + } + }) +} + +const ( + TEST_KEY = `-----BEGIN PRIVATE KEY-----\nMIIBVAIBADANBgkqhkiG9w0BAQEFAASCAT4wggE6AgEAAkEA+bT+piDo8Um/LUAe\nDrNBLx7qihAsThuvCn//SPXsoICofAgjAxtu2n1nVEb5ZnxKh8P72KXC4wE7G97u\n0u1tvQIDAQABAkAgiusJAY76Ky9EGXARYGElX/UXCyaLA2abirTdcFdnTzs+19nX\n4OI/jBbiEd76yjfW6RkdAN6aPezDkRnlspSBAiEA/oqwVbNHJc8D5uzARMENt34j\nYVGimAl+I3VwjaxBkd0CIQD7IzccW+L35RGpdnXMooTD+tKRaZv7XnTx5jfI3Sm9\nYQIhAK8idZlBtN5KxYCJvPCRdAKgg29eX+UEAwoar8qKjsLxAiBdv/KlyoN7CO9D\n9K3a+1xWkL6ke+k3uDYty0RN3onjYQIgbGhXwTcEzLoh3OCuuF2KJVJ5gmBANKrP\n+6AHJFrGjHY=\n-----END PRIVATE KEY-----` + TEST_PUBKEY = `-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAPm0/qYg6PFJvy1AHg6zQS8e6ooQLE4b\nrwp//0j17KCAqHwIIwMbbtp9Z1RG+WZ8SofD+9ilwuMBOxve7tLtbb0CAwEAAQ==\n-----END PUBLIC KEY-----` + TEST_CSR = `-----BEGIN CERTIFICATE REQUEST-----\nMIIBDjCBuQIBADBUMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEh\nMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMQ0wCwYDVQQDDAR0ZXN0\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAPm0/qYg6PFJvy1AHg6zQS8e6ooQLE4b\nrwp//0j17KCAqHwIIwMbbtp9Z1RG+WZ8SofD+9ilwuMBOxve7tLtbb0CAwEAAaAA\nMA0GCSqGSIb3DQEBCwUAA0EAwM1MsaPCJIadGpsnIbqtNSSvg+F331yvja3kMr0R\nCdtdQ2uymQ5hxv/Qtg30WLdyXQ3XRWfRh1Fb/mkJMG1DGQ==\n-----END CERTIFICATE REQUEST-----` + TEST_CERT = `-----BEGIN CERTIFICATE-----\nMIIBpzCCAVECFBcYnnOAmVrDMGVQQNTbHDiDuNinMA0GCSqGSIb3DQEBCwUAMFQx\nCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl\ncm5ldCBXaWRnaXRzIFB0eSBMdGQxDTALBgNVBAMMBHRlc3QwIBcNMjQwNjI0MTQz\nOTEyWhgPMjA1MTExMDkxNDM5MTJaMFQxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApT\nb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDTAL\nBgNVBAMMBHRlc3QwXDANBgkqhkiG9w0BAQEFAANLADBIAkEA+bT+piDo8Um/LUAe\nDrNBLx7qihAsThuvCn//SPXsoICofAgjAxtu2n1nVEb5ZnxKh8P72KXC4wE7G97u\n0u1tvQIDAQABMA0GCSqGSIb3DQEBCwUAA0EAZ+GS/x322f/APuFi1WQHrp5Ebe+S\nTwVHzOhhLLF5xYOp/KlXAk2ObJEov88McOYJG7A3Oc+qI739EX+oGmM3MQ==\n-----END CERTIFICATE-----` + TEST_CERT_ACCEPTED = `{"certificateId":"test","certificatePem":"` + TEST_CERT + `","privateKey":"` + TEST_KEY + `","certificateOwnershipToken":"token"}` + TEST_CSR_ACCEPTED = `{"certificateId":"test","certificatePem":"` + TEST_CERT + `","certificateOwnershipToken":"token"}` + TEST_THING_NAME = "test" + TEST_REGISTER_ACCEPTED = `{"deviceConfiguration":{},"thingName":"` + TEST_THING_NAME + `"}` +) + +func TestHandlers(t *testing.T) { + t.Run("CreateCertificateAccepted", func(t *testing.T) { + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var f FleetProvisioning + cli := &mockDevice{mockClient: &mockmqtt.Client{PublishFn: func(ctx context.Context, msg *mqtt.Message) error { + f.Serve(&mqtt.Message{ + Topic: "$aws/certificates/create/json/accepted", + Payload: []byte(TEST_CERT_ACCEPTED), + }) + return nil + }}} + f, err := New(ctx, cli, "test") + if err != nil { + t.Fatal(err) + } + cli.Handle(f) + _, _, _, _, err = f.CreateKeysAndCertificate(ctx) + if err != nil { + t.Fatal(err) + } + certificate := bytes.NewBuffer(nil) + f.WriteCertificate(certificate) + cert := strings.ReplaceAll(certificate.String(), "\n", "\\n") + if cert != TEST_CERT { + t.Errorf("Expected certificate: %s, got: %s", TEST_CERT, certificate.String()) + } + + }) + t.Run("CreateCertificateFromCSRAccepted", func(t *testing.T) { + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var f FleetProvisioning + cli := &mockDevice{mockClient: &mockmqtt.Client{PublishFn: func(ctx context.Context, msg *mqtt.Message) error { + f.Serve(&mqtt.Message{ + Topic: "$aws/certificates/create-from-csr/json/accepted", + Payload: []byte(TEST_CSR_ACCEPTED), + }) + return nil + }}} + f, err := New(ctx, cli, "test") + if err != nil { + t.Fatal(err) + } + cli.Handle(f) + + _, _, _, err = f.CreateCertificateFromCSR(ctx, &x509.CertificateRequest{}) + if err != nil { + t.Fatal(err) + } + certificate := bytes.NewBuffer(nil) + f.WriteCertificate(certificate) + cert := strings.ReplaceAll(certificate.String(), "\n", "\\n") + if cert != TEST_CERT { + t.Errorf("Expected certificate: %s, got: %s", TEST_CERT, certificate.String()) + } + }) + t.Run("RegisterThingAccepted", func(t *testing.T) { + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + var f FleetProvisioning + cli := &mockDevice{mockClient: &mockmqtt.Client{PublishFn: func(ctx context.Context, msg *mqtt.Message) error { + f.Serve(&mqtt.Message{ + Topic: "$aws/provisioning-templates/test/provision/json/accepted", + Payload: []byte(TEST_REGISTER_ACCEPTED), + }) + return nil + }}} + f, err := New(ctx, cli, "test") + if err != nil { + t.Fatal(err) + } + cli.Handle(f) + thingName, _, err := f.RegisterThing(ctx, "", map[string]string{}) + if err != nil { + t.Fatal(err) + } + if thingName != TEST_THING_NAME { + t.Errorf("Expected thingName: %s, got: %s", TEST_THING_NAME, thingName) + } + + }) +} + +func TestHandlers_InvalidResponse(t *testing.T) { + for _, topic := range []string{ + "$aws/certificates/create/json/accepted", + "$aws/certificates/create/json/rejected", + "$aws/certificates/create-from-csr/json/accepted", + "$aws/certificates/create-from-csr/json/rejected", + "$aws/provisioning-templates/test/provision/json/accepted", + "$aws/provisioning-templates/test/provision/json/rejected", + } { + topic := topic + t.Run(topic, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + var cli *mockDevice + cli = &mockDevice{mockClient: &mockmqtt.Client{}} + + f, err := New(ctx, cli, "test") + if err != nil { + t.Fatal(err) + } + chErr := make(chan error, 1) + f.OnError(func(err error) { chErr <- err }) + cli.Handle(f) + + cli.Serve(&mqtt.Message{ + Topic: topic, + Payload: []byte{0xff, 0xff, 0xff}, + }) + + select { + case err := <-chErr: + var ie *ioterr.Error + if !errors.As(err, &ie) { + t.Errorf("Expected error type: %T, got: %T", ie, err) + } + case <-ctx.Done(): + t.Fatal("Timeout") + } + }) + } +} diff --git a/fleetprovisioning/state.go b/fleetprovisioning/state.go new file mode 100644 index 00000000..ee96a292 --- /dev/null +++ b/fleetprovisioning/state.go @@ -0,0 +1,59 @@ +// Copyright 2020 SEQSENSE, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fleetprovisioning + +import ( + "fmt" +) + +// ErrorResponse represents error response from AWS IoT. +type ErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` + ClientToken string `json:"clientToken"` +} + +// Error implements error interface. +func (e *ErrorResponse) Error() string { + return fmt.Sprintf("%d (%s): %s", e.Code, e.ClientToken, e.Message) +} + +type CertificateCreateResponse struct { + CertificateID string `json:"certificateId"` + CertificatePEM string `json:"certificatePem"` + PrivateKey string `json:"privateKey"` + CertificateOwnershipToken string `json:"certificateOwnershipToken"` +} + +type CertificateCreateFromCSRRequest struct { + CertificateSigningRequest string `json:"certificateSigningRequest"` +} + +type CertificateCreateFromCSRResponse struct { + CertificateID string `json:"certificateId"` + CertificatePEM string `json:"certificatePem"` + CertificateOwnershipToken string `json:"certificateOwnershipToken"` +} + +type RegisterThingRequest struct { + CertificateOwnershipToken string `json:"certificateOwnershipToken"` + Parameters map[string]string `json:"parameters"` +} + +type RegisterThingResponse struct { + DeviceConfiguration map[string]string `json:"deviceConfiguration"` + ThingName string `json:"thingName"` +} diff --git a/fleetprovisioning/state_test.go b/fleetprovisioning/state_test.go new file mode 100644 index 00000000..359929f7 --- /dev/null +++ b/fleetprovisioning/state_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 SEQSENSE, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fleetprovisioning + +import ( + "strings" + "testing" +) + +func TestErrorResponse(t *testing.T) { + err := &ErrorResponse{ + Code: 100, + Message: "error message", + } + errStr := err.Error() + if !strings.Contains(errStr, "100") { + t.Error("Error string should contain error code") + } + if !strings.Contains(errStr, "error message") { + t.Error("Error string should contain error message") + } +}