diff --git a/internals/overlord/planstate/manager.go b/internals/overlord/planstate/manager.go index 49fb202f..a4fb4414 100644 --- a/internals/overlord/planstate/manager.go +++ b/internals/overlord/planstate/manager.go @@ -181,6 +181,7 @@ func (m *PlanManager) updatePlanLayers(layers []*plan.Layer) (*plan.Plan, error) Services: combined.Services, Checks: combined.Checks, LogTargets: combined.LogTargets, + Sections: combined.Sections, } err = p.Validate() if err != nil { diff --git a/internals/overlord/planstate/manager_test.go b/internals/overlord/planstate/manager_test.go index 9d84bef1..300e0411 100644 --- a/internals/overlord/planstate/manager_test.go +++ b/internals/overlord/planstate/manager_test.go @@ -45,6 +45,10 @@ var loadLayers = []string{` summary: Svc1 override: replace command: echo svc1 + test-field: + test1: + override: merge + a: something `, ` summary: Layer 2 description: Layer 2 desc. @@ -53,9 +57,15 @@ var loadLayers = []string{` summary: Svc2 override: replace command: echo svc2 + test-field: + test1: + override: merge + b: something else `} func (ps *planSuite) TestLoadLayers(c *C) { + plan.RegisterSectionExtension(testField, testExtension{}) + defer plan.UnregisterSectionExtension(testField) var err error ps.planMgr, err = planstate.NewManager(ps.layersDir) c.Assert(err, IsNil) @@ -80,10 +90,17 @@ services: summary: Svc2 override: replace command: echo svc2 +test-field: + test1: + override: merge + a: something + b: something else `[1:]) } func (ps *planSuite) TestAppendLayers(c *C) { + plan.RegisterSectionExtension(testField, testExtension{}) + defer plan.UnregisterSectionExtension(testField) var err error ps.planMgr, err = planstate.NewManager(ps.layersDir) c.Assert(err, IsNil) @@ -94,6 +111,10 @@ services: svc1: override: replace command: /bin/sh +test-field: + test1: + override: replace + a: something `) err = ps.planMgr.AppendLayer(layer) c.Assert(err, IsNil) @@ -103,6 +124,10 @@ services: svc1: override: replace command: /bin/sh +test-field: + test1: + override: replace + a: something `[1:]) ps.planLayersHasLen(c, 1) @@ -112,6 +137,10 @@ services: svc1: override: foobar command: /bin/bar +test-field: + test1: + override: foobar + a: something else `) err = ps.planMgr.AppendLayer(layer) c.Assert(err.(*planstate.LabelExists).Label, Equals, "label1") @@ -120,6 +149,10 @@ services: svc1: override: replace command: /bin/sh +test-field: + test1: + override: replace + a: something `[1:]) ps.planLayersHasLen(c, 1) @@ -129,6 +162,10 @@ services: svc1: override: replace command: /bin/bash +test-field: + test1: + override: replace + a: else `) err = ps.planMgr.AppendLayer(layer) c.Assert(err, IsNil) @@ -138,6 +175,10 @@ services: svc1: override: replace command: /bin/bash +test-field: + test1: + override: replace + a: else `[1:]) ps.planLayersHasLen(c, 2) @@ -147,6 +188,10 @@ services: svc2: override: replace command: /bin/foo +test-field: + test2: + override: replace + a: something `) err = ps.planMgr.AppendLayer(layer) c.Assert(err, IsNil) @@ -159,11 +204,20 @@ services: svc2: override: replace command: /bin/foo +test-field: + test1: + override: replace + a: else + test2: + override: replace + a: something `[1:]) ps.planLayersHasLen(c, 3) } func (ps *planSuite) TestCombineLayers(c *C) { + plan.RegisterSectionExtension(testField, testExtension{}) + defer plan.UnregisterSectionExtension(testField) var err error ps.planMgr, err = planstate.NewManager(ps.layersDir) c.Assert(err, IsNil) @@ -174,6 +228,10 @@ services: svc1: override: replace command: /bin/sh +test-field: + test1: + override: replace + a: something `) err = ps.planMgr.CombineLayer(layer) c.Assert(err, IsNil) @@ -183,6 +241,10 @@ services: svc1: override: replace command: /bin/sh +test-field: + test1: + override: replace + a: something `[1:]) ps.planLayersHasLen(c, 1) @@ -192,6 +254,10 @@ services: svc2: override: replace command: /bin/foo +test-field: + test2: + override: replace + a: else `) err = ps.planMgr.CombineLayer(layer) c.Assert(err, IsNil) @@ -204,6 +270,13 @@ services: svc2: override: replace command: /bin/foo +test-field: + test1: + override: replace + a: something + test2: + override: replace + a: else `[1:]) ps.planLayersHasLen(c, 2) @@ -213,6 +286,10 @@ services: svc1: override: replace command: /bin/bash +test-field: + test1: + override: replace + a: else `) err = ps.planMgr.CombineLayer(layer) c.Assert(err, IsNil) @@ -225,6 +302,13 @@ services: svc2: override: replace command: /bin/foo +test-field: + test1: + override: replace + a: else + test2: + override: replace + a: else `[1:]) ps.planLayersHasLen(c, 2) @@ -234,6 +318,10 @@ services: svc2: override: replace command: /bin/bar +test-field: + test2: + override: replace + a: something `) err = ps.planMgr.CombineLayer(layer) c.Assert(err, IsNil) @@ -246,6 +334,13 @@ services: svc2: override: replace command: /bin/bar +test-field: + test1: + override: replace + a: else + test2: + override: replace + a: something `[1:]) ps.planLayersHasLen(c, 2) @@ -258,6 +353,13 @@ services: svc2: override: replace command: /bin/b +test-field: + test1: + override: replace + a: nothing + test2: + override: replace + a: nothing `) err = ps.planMgr.CombineLayer(layer) c.Assert(err, IsNil) @@ -270,6 +372,13 @@ services: svc2: override: replace command: /bin/b +test-field: + test1: + override: replace + a: nothing + test2: + override: replace + a: nothing `[1:]) ps.planLayersHasLen(c, 3) @@ -283,6 +392,15 @@ checks: port: 8080 `)) c.Check(err, ErrorMatches, `(?s).*plan check.*must be "alive" or "ready".*`) + + // Make sure that layer validation is happening for extensions. + layer, err = plan.ParseLayer(0, "label4", []byte(` +test-field: + my1: + override: replace + a: nothing +`)) + c.Check(err, ErrorMatches, `.*entry names must start with.*`) } func (ps *planSuite) TestSetServiceArgs(c *C) { diff --git a/internals/overlord/planstate/package_test.go b/internals/overlord/planstate/package_test.go index 9f2f5e8f..4e79d7e6 100644 --- a/internals/overlord/planstate/package_test.go +++ b/internals/overlord/planstate/package_test.go @@ -42,7 +42,9 @@ type planSuite struct { var _ = Suite(&planSuite{}) func (ps *planSuite) SetUpTest(c *C) { - ps.layersDir = c.MkDir() + ps.layersDir = filepath.Join(c.MkDir(), "layers") + err := os.Mkdir(ps.layersDir, 0755) + c.Assert(err, IsNil) //Reset write layer counter ps.writeLayerCounter = 1 @@ -100,3 +102,119 @@ func reindent(in string) []byte { } return buf.Bytes() } + +const testField string = "test-field" + +// testExtension implements the LayerSectionExtension interface. +type testExtension struct{} + +func (te testExtension) ParseSection(data yaml.Node) (plan.Section, error) { + ts := &testSection{} + err := data.Decode(ts) + if err != nil { + return nil, err + } + // Populate Name. + for name, entry := range ts.Entries { + if entry != nil { + ts.Entries[name].Name = name + } + } + return ts, nil +} + +func (te testExtension) CombineSections(sections ...plan.Section) (plan.Section, error) { + ts := &testSection{} + for _, section := range sections { + err := ts.Combine(section) + if err != nil { + return nil, err + } + } + return ts, nil +} + +func (te testExtension) ValidatePlan(p *plan.Plan) error { + // This extension has no dependencies on the Plan to validate. + return nil +} + +// testSection is the backing type for testExtension. +type testSection struct { + Entries map[string]*T `yaml:",inline"` +} + +func (ts *testSection) Validate() error { + // Fictitious test requirement: fields must start with t + prefix := "t" + for field, _ := range ts.Entries { + if !strings.HasPrefix(field, prefix) { + return fmt.Errorf("%q entry names must start with %q", testField, prefix) + } + } + return nil +} + +func (ts *testSection) IsZero() bool { + return ts.Entries == nil +} + +func (ts *testSection) Combine(other plan.Section) error { + otherTSection, ok := other.(*testSection) + if !ok { + return fmt.Errorf("invalid section type") + } + + for field, entry := range otherTSection.Entries { + ts.Entries = makeMapIfNil(ts.Entries) + switch entry.Override { + case plan.MergeOverride: + if old, ok := ts.Entries[field]; ok { + copied := old.Copy() + copied.Merge(entry) + ts.Entries[field] = copied + break + } + fallthrough + case plan.ReplaceOverride: + ts.Entries[field] = entry.Copy() + case plan.UnknownOverride: + return &plan.FormatError{ + Message: fmt.Sprintf(`invalid "override" value for entry %q`, field), + } + default: + return &plan.FormatError{ + Message: fmt.Sprintf(`unknown "override" value for entry %q`, field), + } + } + } + return nil +} + +type T struct { + Name string `yaml:"-"` + Override plan.Override `yaml:"override,omitempty"` + A string `yaml:"a,omitempty"` + B string `yaml:"b,omitempty"` +} + +func (t *T) Copy() *T { + copied := *t + return &copied +} + +func (t *T) Merge(other *T) { + if other.A != "" { + t.A = other.A + } + if other.B != "" { + t.B = other.B + } +} + +func makeMapIfNil[K comparable, V any](m map[K]V) map[K]V { + if m == nil { + m = make(map[K]V) + } + return m +}