diff --git a/internal/appsec/appsec.go b/internal/appsec/appsec.go index 98be47e8df..6ac60bb215 100644 --- a/internal/appsec/appsec.go +++ b/internal/appsec/appsec.go @@ -14,9 +14,8 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" - "gopkg.in/DataDog/dd-trace-go.v1/internal/remoteconfig" - "github.com/DataDog/go-libddwaf" + waf "github.com/DataDog/go-libddwaf" ) // Enabled returns true when AppSec is up and running. Meaning that the appsec build tag is enabled, the env var @@ -66,7 +65,9 @@ func Start(opts ...StartOption) { // Start the remote configuration client log.Debug("appsec: starting the remote configuration client") - appsec.startRC() + if err := appsec.startRC(); err != nil { + log.Error("appsec: Remote config: disabled due to an instanciation error: %v", err) + } if !set { // AppSec is not enforced by the env var and can be enabled through remote config @@ -114,23 +115,13 @@ func setActiveAppSec(a *appsec) { type appsec struct { cfg *Config limiter *TokenTicker - rc *remoteconfig.Client wafHandle *wafHandle started bool } func newAppSec(cfg *Config) *appsec { - var client *remoteconfig.Client - var err error - if cfg.rc != nil { - client, err = remoteconfig.NewClient(*cfg.rc) - } - if err != nil { - log.Error("appsec: Remote config: disabled due to a client creation error: %v", err) - } return &appsec{ cfg: cfg, - rc: client, } } diff --git a/internal/appsec/remoteconfig.go b/internal/appsec/remoteconfig.go index 23a116299f..9f433e0330 100644 --- a/internal/appsec/remoteconfig.go +++ b/internal/appsec/remoteconfig.go @@ -287,95 +287,109 @@ func mergeRulesDataEntries(entries1, entries2 []rc.ASMDataRuleDataEntry) []rc.AS return entries } -func (a *appsec) startRC() { - if a.rc != nil { - a.rc.Start() +func (a *appsec) startRC() error { + if a.cfg.rc != nil { + return remoteconfig.Start(*a.cfg.rc) } + return nil } func (a *appsec) stopRC() { - if a.rc != nil { - a.rc.Stop() + if a.cfg.rc != nil { + remoteconfig.Stop() } } func (a *appsec) registerRCProduct(p string) error { - if a.rc == nil { + if a.cfg.rc == nil { return fmt.Errorf("no valid remote configuration client") } - a.cfg.rc.Products[p] = struct{}{} - a.rc.RegisterProduct(p) - return nil -} - -func (a *appsec) unregisterRCProduct(p string) error { - if a.rc == nil { - return fmt.Errorf("no valid remote configuration client") - } - delete(a.cfg.rc.Products, p) - a.rc.UnregisterProduct(p) - return nil + return remoteconfig.RegisterProduct(p) } func (a *appsec) registerRCCapability(c remoteconfig.Capability) error { - a.cfg.rc.Capabilities[c] = struct{}{} - if a.rc == nil { + if a.cfg.rc == nil { return fmt.Errorf("no valid remote configuration client") } - a.rc.RegisterCapability(c) - return nil + return remoteconfig.RegisterCapability(c) } -func (a *appsec) unregisterRCCapability(c remoteconfig.Capability) { - if a.rc == nil { +func (a *appsec) unregisterRCCapability(c remoteconfig.Capability) error { + if a.cfg.rc == nil { log.Debug("appsec: Remote config: no valid remote configuration client") - return + return nil } - delete(a.cfg.rc.Capabilities, c) - a.rc.UnregisterCapability(c) + return remoteconfig.UnregisterCapability(c) } func (a *appsec) enableRemoteActivation() error { - if a.rc == nil { + if a.cfg.rc == nil { return fmt.Errorf("no valid remote configuration client") } - a.registerRCProduct(rc.ProductASMFeatures) - a.registerRCCapability(remoteconfig.ASMActivation) - a.rc.RegisterCallback(a.onRemoteActivation) - return nil + err := a.registerRCProduct(rc.ProductASMFeatures) + if err != nil { + return err + } + err = a.registerRCCapability(remoteconfig.ASMActivation) + if err != nil { + return err + } + return remoteconfig.RegisterCallback(a.onRemoteActivation) } func (a *appsec) enableRCBlocking() { - if a.rc == nil { + if a.cfg.rc == nil { log.Debug("appsec: Remote config: no valid remote configuration client") return } - a.registerRCProduct(rc.ProductASM) - a.registerRCProduct(rc.ProductASMDD) - a.registerRCProduct(rc.ProductASMData) - a.rc.RegisterCallback(a.onRCRulesUpdate) + products := []string{rc.ProductASM, rc.ProductASMDD, rc.ProductASMData} + for _, p := range products { + if err := a.registerRCProduct(p); err != nil { + log.Debug("appsec: Remote config: couldn't register product %s: %v", p, err) + } + } + + if err := remoteconfig.RegisterCallback(a.onRCRulesUpdate); err != nil { + log.Debug("appsec: Remote config: couldn't register callback: %v", err) + } if _, isSet := os.LookupEnv(rulesEnvVar); !isSet { - a.registerRCCapability(remoteconfig.ASMUserBlocking) - a.registerRCCapability(remoteconfig.ASMRequestBlocking) - a.registerRCCapability(remoteconfig.ASMIPBlocking) - a.registerRCCapability(remoteconfig.ASMDDRules) - a.registerRCCapability(remoteconfig.ASMExclusions) - a.registerRCCapability(remoteconfig.ASMCustomRules) - a.registerRCCapability(remoteconfig.ASMCustomBlockingResponse) + caps := []remoteconfig.Capability{ + remoteconfig.ASMUserBlocking, + remoteconfig.ASMRequestBlocking, + remoteconfig.ASMIPBlocking, + remoteconfig.ASMDDRules, + remoteconfig.ASMExclusions, + remoteconfig.ASMCustomRules, + remoteconfig.ASMCustomBlockingResponse, + } + for _, c := range caps { + if err := a.registerRCCapability(c); err != nil { + log.Debug("appsec: Remote config: couldn't register capability %v: %v", c, err) + } + } } } func (a *appsec) disableRCBlocking() { - if a.rc == nil { + if a.cfg.rc == nil { return } - a.unregisterRCCapability(remoteconfig.ASMDDRules) - a.unregisterRCCapability(remoteconfig.ASMExclusions) - a.unregisterRCCapability(remoteconfig.ASMIPBlocking) - a.unregisterRCCapability(remoteconfig.ASMRequestBlocking) - a.unregisterRCCapability(remoteconfig.ASMUserBlocking) - a.unregisterRCCapability(remoteconfig.ASMCustomRules) - a.rc.UnregisterCallback(a.onRCRulesUpdate) + caps := []remoteconfig.Capability{ + remoteconfig.ASMDDRules, + remoteconfig.ASMExclusions, + remoteconfig.ASMIPBlocking, + remoteconfig.ASMRequestBlocking, + remoteconfig.ASMUserBlocking, + remoteconfig.ASMCustomRules, + } + for _, c := range caps { + if err := a.unregisterRCCapability(c); err != nil { + log.Debug("appsec: Remote config: couldn't unregister capability %v: %v", c, err) + } + } + if err := remoteconfig.UnregisterCallback(a.onRCRulesUpdate); err != nil { + log.Debug("appsec: Remote config: couldn't unregister callback: %v", err) + } } diff --git a/internal/appsec/remoteconfig_test.go b/internal/appsec/remoteconfig_test.go index cd5af7f630..3da664430a 100644 --- a/internal/appsec/remoteconfig_test.go +++ b/internal/appsec/remoteconfig_test.go @@ -33,6 +33,8 @@ func TestASMFeaturesCallback(t *testing.T) { cfg, err := newConfig() require.NoError(t, err) a := newAppSec(cfg) + err = a.startRC() + require.NoError(t, err) t.Setenv(enabledEnvVar, "") os.Unsetenv(enabledEnvVar) @@ -333,22 +335,27 @@ func TestRemoteActivationScenarios(t *testing.T) { require.NotNil(t, activeAppSec) require.False(t, Enabled()) - client := activeAppSec.rc - require.NotNil(t, client) - require.Contains(t, client.Capabilities, remoteconfig.ASMActivation) - require.Contains(t, client.Products, rc.ProductASMFeatures) + found, err := remoteconfig.HasCapability(remoteconfig.ASMActivation) + require.NoError(t, err) + require.True(t, found) + found, err = remoteconfig.HasProduct(rc.ProductASMFeatures) + require.NoError(t, err) + require.True(t, found) }) t.Run("DD_APPSEC_ENABLED=true", func(t *testing.T) { t.Setenv(enabledEnvVar, "true") + remoteconfig.Reset() Start(WithRCConfig(remoteconfig.DefaultClientConfig())) defer Stop() require.True(t, Enabled()) - client := activeAppSec.rc - require.NotNil(t, client) - require.NotContains(t, client.Capabilities, remoteconfig.ASMActivation) - require.NotContains(t, client.Products, rc.ProductASMFeatures) + found, err := remoteconfig.HasCapability(remoteconfig.ASMActivation) + require.NoError(t, err) + require.False(t, found) + found, err = remoteconfig.HasProduct(rc.ProductASMFeatures) + require.NoError(t, err) + require.False(t, found) }) t.Run("DD_APPSEC_ENABLED=false", func(t *testing.T) { @@ -397,11 +404,10 @@ func TestCapabilities(t *testing.T) { if !Enabled() && activeAppSec == nil { t.Skip() } - require.NotNil(t, activeAppSec.rc) - require.Len(t, activeAppSec.rc.Capabilities, len(tc.expected)) for _, cap := range tc.expected { - _, contained := activeAppSec.rc.Capabilities[cap] - require.True(t, contained) + found, err := remoteconfig.HasCapability(cap) + require.NoError(t, err) + require.True(t, found) } }) } diff --git a/internal/remoteconfig/config.go b/internal/remoteconfig/config.go index a80c57b067..3d2b8ed631 100644 --- a/internal/remoteconfig/config.go +++ b/internal/remoteconfig/config.go @@ -30,8 +30,6 @@ type ClientConfig struct { Env string // The time interval between two client polls to the agent for updates PollInterval time.Duration - // The products this client is interested in - Products map[string]struct{} // The tracer's runtime id RuntimeID string // The name of the user's application @@ -40,8 +38,6 @@ type ClientConfig struct { TracerVersion string // The base TUF root metadata file TUFRoot string - // The capabilities of the client - Capabilities map[Capability]struct{} // HTTP is the HTTP client used to receive config updates HTTP *http.Client } @@ -49,8 +45,6 @@ type ClientConfig struct { // DefaultClientConfig returns the default remote config client configuration func DefaultClientConfig() ClientConfig { return ClientConfig{ - Capabilities: map[Capability]struct{}{}, - Products: map[string]struct{}{}, Env: os.Getenv("DD_ENV"), HTTP: &http.Client{Timeout: 10 * time.Second}, PollInterval: pollIntervalFromEnv(), diff --git a/internal/remoteconfig/remoteconfig.go b/internal/remoteconfig/remoteconfig.go index 8998c50e93..f79a0a3786 100644 --- a/internal/remoteconfig/remoteconfig.go +++ b/internal/remoteconfig/remoteconfig.go @@ -10,12 +10,14 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "math/big" "net/http" "reflect" "strings" + "sync" "time" "gopkg.in/DataDog/dd-trace-go.v1/internal/log" @@ -55,6 +57,9 @@ const ( ASMCustomBlockingResponse ) +// ErrClientNotStarted is returned when the remote config client is not started. +var ErrClientNotStarted = errors.New("remote config client not started") + // ProductUpdate represents an update for a specific product. // It is a map of file path to raw file content type ProductUpdate map[string][]byte @@ -62,6 +67,7 @@ type ProductUpdate map[string][]byte // A Client interacts with an Agent to update and track the state of remote // configuration type Client struct { + sync.RWMutex ClientConfig clientID string @@ -69,13 +75,24 @@ type Client struct { repository *rc.Repository stop chan struct{} - callbacks []Callback + callbacks []Callback + products map[string]struct{} + capabilities map[Capability]struct{} lastError error } -// NewClient creates a new remoteconfig Client -func NewClient(config ClientConfig) (*Client, error) { +// client is a RC client singleton that can be accessed by multiple products (tracing, ASM, profiling etc.). +// Using a single RC client instance in the tracer is a requirement for remote configuration. +var client *Client + +var ( + startOnce sync.Once + stopOnce sync.Once +) + +// newClient creates a new remoteconfig Client +func newClient(config ClientConfig) (*Client, error) { repo, err := rc.NewUnverifiedRepository() if err != nil { return nil, err @@ -92,37 +109,63 @@ func NewClient(config ClientConfig) (*Client, error) { stop: make(chan struct{}), lastError: nil, callbacks: []Callback{}, + capabilities: map[Capability]struct{}{}, + products: map[string]struct{}{}, }, nil } -// Start starts the client's update poll loop in a fresh goroutine -func (c *Client) Start() { - go func() { - ticker := time.NewTicker(c.PollInterval) - defer ticker.Stop() - - for { - select { - case <-c.stop: - close(c.stop) - return - case <-ticker.C: - c.updateState() +// Start starts the client's update poll loop in a fresh goroutine. +// Noop if the client has already started. +func Start(config ClientConfig) error { + var err error + startOnce.Do(func() { + client, err = newClient(config) + if err != nil { + return + } + go func() { + ticker := time.NewTicker(client.PollInterval) + defer ticker.Stop() + + for { + select { + case <-client.stop: + close(client.stop) + return + case <-ticker.C: + client.Lock() + client.updateState() + client.Unlock() + } } + }() + }) + return err +} + +// Stop stops the client's update poll loop. +// Noop if the client has already been stopped. +// The remote config client is supposed to have the same lifecycle as the tracer. +// It can't be restarted after a call to Stop() unless explicitly calling Reset(). +func Stop() { + stopOnce.Do(func() { + log.Debug("remoteconfig: gracefully stopping the client") + client.stop <- struct{}{} + select { + case <-client.stop: + log.Debug("remoteconfig: client stopped successfully") + case <-time.After(time.Second): + log.Debug("remoteconfig: client stopping timeout") } - }() + }) } -// Stop stops the client's update poll loop -func (c *Client) Stop() { - log.Debug("remoteconfig: gracefully stopping the client") - c.stop <- struct{}{} - select { - case <-c.stop: - log.Debug("remoteconfig: client stopped successfully") - case <-time.After(time.Second): - log.Debug("remoteconfig: client stopping timeout") - } +// Reset destroys the client instance. +// To be used only in tests to reset the state of the client. +func Reset() { + client = nil + startOnce = sync.Once{} + stopOnce = sync.Once{} } func (c *Client) updateState() { @@ -176,52 +219,110 @@ func (c *Client) updateState() { // RegisterCallback allows registering a callback that will be invoked when the client // receives configuration updates. It is up to that callback to then decide what to do // depending on the product related to the configuration update. -func (c *Client) RegisterCallback(f Callback) { - c.callbacks = append(c.callbacks, f) +func RegisterCallback(f Callback) error { + if client == nil { + return ErrClientNotStarted + } + client.Lock() + defer client.Unlock() + client.callbacks = append(client.callbacks, f) + return nil } // UnregisterCallback removes a previously registered callback from the active callbacks list // This remove operation preserves ordering -func (c *Client) UnregisterCallback(f Callback) { +func UnregisterCallback(f Callback) error { + if client == nil { + return ErrClientNotStarted + } + client.Lock() + defer client.Unlock() fValue := reflect.ValueOf(f) - for i, callback := range c.callbacks { + for i, callback := range client.callbacks { if reflect.ValueOf(callback) == fValue { - c.callbacks = append(c.callbacks[:i], c.callbacks[i+1:]...) + client.callbacks = append(client.callbacks[:i], client.callbacks[i+1:]...) } } + return nil } // RegisterProduct adds a product to the list of products listened by the client -func (c *Client) RegisterProduct(p string) { - c.Products[p] = struct{}{} +func RegisterProduct(p string) error { + if client == nil { + return ErrClientNotStarted + } + client.Lock() + defer client.Unlock() + client.products[p] = struct{}{} + return nil } // UnregisterProduct removes a product from the list of products listened by the client -func (c *Client) UnregisterProduct(p string) { - delete(c.Products, p) +func UnregisterProduct(p string) error { + if client == nil { + return ErrClientNotStarted + } + client.Lock() + defer client.Unlock() + delete(client.products, p) + return nil +} + +// HasProduct returns whether a given product was registered +func HasProduct(p string) (bool, error) { + if client == nil { + return false, ErrClientNotStarted + } + client.RLock() + defer client.RUnlock() + _, found := client.products[p] + return found, nil } // RegisterCapability adds a capability to the list of capabilities exposed by the client when requesting // configuration updates -func (c *Client) RegisterCapability(cap Capability) { - c.Capabilities[cap] = struct{}{} +func RegisterCapability(cap Capability) error { + if client == nil { + return ErrClientNotStarted + } + client.Lock() + defer client.Unlock() + client.capabilities[cap] = struct{}{} + return nil } // UnregisterCapability removes a capability from the list of capabilities exposed by the client when requesting // configuration updates -func (c *Client) UnregisterCapability(cap Capability) { - delete(c.Capabilities, cap) +func UnregisterCapability(cap Capability) error { + if client == nil { + return ErrClientNotStarted + } + client.Lock() + defer client.Unlock() + delete(client.capabilities, cap) + return nil +} + +// HasCapability returns whether a given capability was registered +func HasCapability(cap Capability) (bool, error) { + if client == nil { + return false, ErrClientNotStarted + } + client.RLock() + defer client.RUnlock() + _, found := client.capabilities[cap] + return found, nil } func (c *Client) applyUpdate(pbUpdate *clientGetConfigsResponse) error { fileMap := make(map[string][]byte, len(pbUpdate.TargetFiles)) - productUpdates := make(map[string]ProductUpdate, len(c.Products)) - for p := range c.Products { + productUpdates := make(map[string]ProductUpdate, len(c.products)) + for p := range c.products { productUpdates[p] = make(ProductUpdate) } for _, f := range pbUpdate.TargetFiles { fileMap[f.Path] = f.Raw - for p := range c.Products { + for p := range c.products { // Check the config file path to make sure it belongs to the right product if strings.Contains(f.Path, "/"+p+"/") { productUpdates[p][f.Path] = f.Raw @@ -353,11 +454,11 @@ func (c *Client) newUpdateRequest() (bytes.Buffer, error) { } capa := big.NewInt(0) - for i := range c.Capabilities { + for i := range c.capabilities { capa.SetBit(capa, int(i), 1) } - products := make([]string, 0, len(c.Products)) - for p := range c.Products { + products := make([]string, 0, len(c.products)) + for p := range c.products { products = append(products, p) } req := clientGetConfigsRequest{ diff --git a/internal/remoteconfig/remoteconfig_test.go b/internal/remoteconfig/remoteconfig_test.go index 80769ae5cf..b46d40733f 100644 --- a/internal/remoteconfig/remoteconfig_test.go +++ b/internal/remoteconfig/remoteconfig_test.go @@ -28,7 +28,8 @@ import ( func TestRCClient(t *testing.T) { cfg := DefaultClientConfig() cfg.ServiceName = "test" - client, err := NewClient(cfg) + var err error + client, err = newClient(cfg) require.NoError(t, err) t.Run("registerCallback", func(t *testing.T) { @@ -36,18 +37,21 @@ func TestRCClient(t *testing.T) { nilCallback := func(map[string]ProductUpdate) map[string]rc.ApplyStatus { return nil } defer func() { client.callbacks = []Callback{} }() require.Equal(t, 0, len(client.callbacks)) - client.RegisterCallback(nilCallback) + err = RegisterCallback(nilCallback) + require.NoError(t, err) require.Equal(t, 1, len(client.callbacks)) require.Equal(t, 1, len(client.callbacks)) - client.RegisterCallback(nilCallback) + err = RegisterCallback(nilCallback) + require.NoError(t, err) require.Equal(t, 2, len(client.callbacks)) }) t.Run("apply-update", func(t *testing.T) { client.callbacks = []Callback{} cfgPath := "datadog/2/ASM_FEATURES/asm_features_activation/config" - client.RegisterProduct(rc.ProductASMFeatures) - client.RegisterCallback(func(updates map[string]ProductUpdate) map[string]rc.ApplyStatus { + err = RegisterProduct(rc.ProductASMFeatures) + require.NoError(t, err) + err = RegisterCallback(func(updates map[string]ProductUpdate) map[string]rc.ApplyStatus { statuses := map[string]rc.ApplyStatus{} for p, u := range updates { if p == rc.ProductASMFeatures { @@ -59,6 +63,7 @@ func TestRCClient(t *testing.T) { } return statuses }) + require.NoError(t, err) resp := genUpdateResponse([]byte("test"), cfgPath) err := client.applyUpdate(resp) @@ -248,27 +253,36 @@ func dummyCallback4(map[string]ProductUpdate) map[string]rc.ApplyStatus { func TestRegistration(t *testing.T) { t.Run("callbacks", func(t *testing.T) { - client, err := NewClient(DefaultClientConfig()) + var err error + client, err = newClient(DefaultClientConfig()) require.NoError(t, err) - client.RegisterCallback(dummyCallback1) + err = RegisterCallback(dummyCallback1) + require.NoError(t, err) require.Len(t, client.callbacks, 1) - client.UnregisterCallback(dummyCallback1) + err = UnregisterCallback(dummyCallback1) + require.NoError(t, err) require.Empty(t, client.callbacks) - client.RegisterCallback(dummyCallback2) - client.RegisterCallback(dummyCallback3) - client.RegisterCallback(dummyCallback1) - client.RegisterCallback(dummyCallback4) + err = RegisterCallback(dummyCallback2) + require.NoError(t, err) + err = RegisterCallback(dummyCallback3) + require.NoError(t, err) + err = RegisterCallback(dummyCallback1) + require.NoError(t, err) + err = RegisterCallback(dummyCallback4) + require.NoError(t, err) require.Len(t, client.callbacks, 4) - client.UnregisterCallback(dummyCallback1) + err = UnregisterCallback(dummyCallback1) + require.NoError(t, err) require.Len(t, client.callbacks, 3) for _, c := range client.callbacks { require.NotEqual(t, reflect.ValueOf(dummyCallback1), reflect.ValueOf(c)) } - client.UnregisterCallback(dummyCallback3) + err = UnregisterCallback(dummyCallback3) + require.NoError(t, err) require.Len(t, client.callbacks, 2) for _, c := range client.callbacks { require.NotEqual(t, reflect.ValueOf(dummyCallback3), reflect.ValueOf(c))