diff --git a/internals/plan/extensions_test.go b/internals/plan/extensions_test.go index ea6a2136..2cbd8b86 100644 --- a/internals/plan/extensions_test.go +++ b/internals/plan/extensions_test.go @@ -44,7 +44,7 @@ type planResult struct { type extension struct { field string - ext plan.LayerSectionExtension + ext plan.SectionExtension } var extensionTests = []struct { @@ -349,16 +349,30 @@ var extensionTests = []struct { }} func (s *S) TestPlanExtensions(c *C) { + registeredExtensions := []string{} + defer func() { + // Remove remaining registered extensions. + for _, field := range registeredExtensions { + plan.UnregisterExtension(field) + } + }() + nexttest: for testIndex, testData := range extensionTests { c.Logf("TestPlanExtensions :: %s (data index %v)", testData.summary, testIndex) + // Unregister extensions from previous test iteraton. + for _, field := range registeredExtensions { + plan.UnregisterExtension(field) + } + registeredExtensions = []string{} + // Write layers to test directory. layersDir := filepath.Join(c.MkDir(), "layers") s.writeLayerFiles(c, layersDir, testData.layers) var p *plan.Plan - // Register test extensions. + // Register extensions for this test iteration. for _, e := range testData.extensions { err := func() (err error) { defer func() { @@ -367,6 +381,7 @@ nexttest: } }() plan.RegisterExtension(e.field, e.ext) + registeredExtensions = append(registeredExtensions, e.field) return nil }() if err != nil { @@ -380,37 +395,33 @@ nexttest: if testData.error != "" || err != nil { // Expected error. c.Assert(err, ErrorMatches, testData.error) - } else { - if slices.ContainsFunc(testData.extensions, func(n extension) bool { - return n.field == xField - }) { - // Verify "x-field" data. - var x *xSection - x = p.Section(xField).(*xSection) - c.Assert(err, IsNil) - c.Assert(x.Entries, DeepEquals, testData.result.x.Entries) - } - - if slices.ContainsFunc(testData.extensions, func(n extension) bool { - return n.field == yField - }) { - // Verify "y-field" data. - var y *ySection - y = p.Section(yField).(*ySection) - c.Assert(err, IsNil) - c.Assert(y.Entries, DeepEquals, testData.result.y.Entries) - } + continue nexttest + } - // Verify combined plan YAML. - planYAML, err := yaml.Marshal(p) + if slices.ContainsFunc(testData.extensions, func(n extension) bool { + return n.field == xField + }) { + // Verify "x-field" data. + var x *xSection + x = p.Sections[xField].(*xSection) c.Assert(err, IsNil) - c.Assert(string(planYAML), Equals, testData.resultYaml) + c.Assert(x.Entries, DeepEquals, testData.result.x.Entries) } - // Unregister test extensions. - for _, e := range testData.extensions { - plan.UnregisterExtension(e.field) + if slices.ContainsFunc(testData.extensions, func(n extension) bool { + return n.field == yField + }) { + // Verify "y-field" data. + var y *ySection + y = p.Sections[yField].(*ySection) + c.Assert(err, IsNil) + c.Assert(y.Entries, DeepEquals, testData.result.y.Entries) } + + // Verify combined plan YAML. + planYAML, err := yaml.Marshal(p) + c.Assert(err, IsNil) + c.Assert(string(planYAML), Equals, testData.resultYaml) } } @@ -506,10 +517,10 @@ func (s *S) writeLayerFiles(c *C, layersDir string, inputs []*inputLayer) { const xField string = "x-field" -// xExtension implements the LayerSectionExtension interface. +// xExtension implements the SectionExtension interface. type xExtension struct{} -func (x xExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) { +func (x xExtension) ParseSection(data yaml.Node) (plan.Section, error) { xs := &xSection{} err := data.Decode(xs) if err != nil { @@ -524,7 +535,7 @@ func (x xExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) { return xs, nil } -func (x xExtension) CombineSections(sections ...plan.LayerSection) (plan.LayerSection, error) { +func (x xExtension) CombineSections(sections ...plan.Section) (plan.Section, error) { xs := &xSection{} for _, section := range sections { err := xs.Combine(section) @@ -537,10 +548,10 @@ func (x xExtension) CombineSections(sections ...plan.LayerSection) (plan.LayerSe func (x xExtension) ValidatePlan(p *plan.Plan) error { var xs *xSection - xs = p.Section(xField).(*xSection) + xs = p.Sections[xField].(*xSection) if xs != nil { var ys *ySection - ys = p.Section(yField).(*ySection) + ys = p.Sections[yField].(*ySection) // Test dependency: Make sure every Y field in X refer to an existing Y entry. for xEntryField, xEntryValue := range xs.Entries { @@ -583,7 +594,7 @@ func (xs *xSection) IsZero() bool { return xs.Entries == nil } -func (xs *xSection) Combine(other plan.LayerSection) error { +func (xs *xSection) Combine(other plan.Section) error { otherxSection, ok := other.(*xSection) if !ok { return fmt.Errorf("cannot combine incompatible section type") @@ -645,10 +656,10 @@ func (x *X) Merge(other *X) { const yField string = "y-field" -// yExtension implements the LayerSectionExtension interface. +// yExtension implements the SectionExtension interface. type yExtension struct{} -func (y yExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) { +func (y yExtension) ParseSection(data yaml.Node) (plan.Section, error) { ys := &ySection{} err := data.Decode(ys) if err != nil { @@ -663,7 +674,7 @@ func (y yExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) { return ys, nil } -func (y yExtension) CombineSections(sections ...plan.LayerSection) (plan.LayerSection, error) { +func (y yExtension) CombineSections(sections ...plan.Section) (plan.Section, error) { ys := &ySection{} for _, section := range sections { err := ys.Combine(section) @@ -701,7 +712,7 @@ func (ys *ySection) IsZero() bool { return ys.Entries == nil } -func (ys *ySection) Combine(other plan.LayerSection) error { +func (ys *ySection) Combine(other plan.Section) error { otherySection, ok := other.(*ySection) if !ok { return fmt.Errorf("cannot combine incompatible section type") diff --git a/internals/plan/plan.go b/internals/plan/plan.go index 9c10e87a..107f1d45 100644 --- a/internals/plan/plan.go +++ b/internals/plan/plan.go @@ -33,16 +33,16 @@ import ( "github.com/canonical/pebble/internals/osutil" ) -// LayerSectionExtension allows the plan layer schema to be extended without +// SectionExtension allows the plan layer schema to be extended without // adding centralised schema knowledge to the plan library. -type LayerSectionExtension interface { +type SectionExtension interface { // ParseSection returns a newly allocated concrete type containing the // unmarshalled section content. - ParseSection(data yaml.Node) (LayerSection, error) + ParseSection(data yaml.Node) (Section, error) // CombineSections returns a newly allocated concrete type containing the // result of combining the supplied sections in order. - CombineSections(sections ...LayerSection) (LayerSection, error) + CombineSections(sections ...Section) (Section, error) // ValidatePlan takes the complete plan as input, and allows the // extension to validate the plan. This can be used for cross section @@ -50,7 +50,7 @@ type LayerSectionExtension interface { ValidatePlan(plan *Plan) error } -type LayerSection interface { +type Section interface { // Validate checks whether the section is valid, returning an error if not. Validate() error @@ -70,7 +70,7 @@ const ( var ( // layerExtensions keeps a map of registered extensions. - layerExtensions = map[string]LayerSectionExtension{} + layerExtensions = map[string]SectionExtension{} // layerExtensionsOrder records the order in which the extensions were registered. layerExtensionsOrder = []string{} @@ -85,7 +85,7 @@ var layerBuiltins = []string{"summary", "description", "services", "checks", "lo // done before the plan library is used. The order in which extensions are // registered determines the order in which the sections are marshalled. // Extension sections are marshalled after the built-in sections. -func RegisterExtension(field string, ext LayerSectionExtension) { +func RegisterExtension(field string, ext SectionExtension) { if slices.Contains(layerBuiltins, field) { panic(fmt.Sprintf("internal error: extension %q already used as built-in field", field)) } @@ -111,14 +111,7 @@ type Plan struct { Checks map[string]*Check `yaml:"checks,omitempty"` LogTargets map[string]*LogTarget `yaml:"log-targets,omitempty"` - Sections map[string]LayerSection `yaml:",inline"` -} - -// Section retrieves a section from the plan. If Section is called -// before the plan is loaded, or with an unregistered field, this method -// will return nil. -func (p *Plan) Section(field string) LayerSection { - return p.Sections[field] + Sections map[string]Section `yaml:",inline"` } // MarshalYAML implements an override for top level omitempty tags handling. @@ -170,7 +163,7 @@ type Layer struct { Checks map[string]*Check `yaml:"checks,omitempty"` LogTargets map[string]*LogTarget `yaml:"log-targets,omitempty"` - Sections map[string]LayerSection `yaml:",inline"` + Sections map[string]Section `yaml:",inline"` } type Service struct { @@ -673,7 +666,7 @@ func CombineLayers(layers ...*Layer) (*Layer, error) { Services: make(map[string]*Service), Checks: make(map[string]*Check), LogTargets: make(map[string]*LogTarget), - Sections: make(map[string]LayerSection), + Sections: make(map[string]Section), } // Combine the same sections from each layer. Note that we do this before @@ -681,7 +674,7 @@ func CombineLayers(layers ...*Layer) (*Layer, error) { // a zero value section, even if no layers are supplied (similar to the // allocations taking place above for the built-in types). for field, extension := range layerExtensions { - var sections []LayerSection + var sections []Section for _, layer := range layers { if section := layer.Sections[field]; section != nil { sections = append(sections, section) @@ -1241,7 +1234,7 @@ func ParseLayer(order int, label string, data []byte) (*Layer, error) { Services: make(map[string]*Service), Checks: make(map[string]*Check), LogTargets: make(map[string]*LogTarget), - Sections: make(map[string]LayerSection), + Sections: make(map[string]Section), } // The following manual approach is required because: diff --git a/internals/plan/plan_test.go b/internals/plan/plan_test.go index 3829301f..34692c93 100644 --- a/internals/plan/plan_test.go +++ b/internals/plan/plan_test.go @@ -206,7 +206,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, { Order: 1, Label: "layer-1", @@ -257,7 +257,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }}, result: &plan.Layer{ Summary: "Simple override layer.", @@ -337,7 +337,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, start: map[string][]string{ "srv1": {"srv2", "srv1", "srv3"}, @@ -400,7 +400,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }}, }, { summary: "Unknown keys are not accepted", @@ -551,7 +551,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }}, }, { summary: `Invalid service command: cannot have any arguments after [ ... ] group`, @@ -660,7 +660,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Checks override replace works correctly", @@ -738,7 +738,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Checks override merge works correctly", @@ -822,7 +822,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Timeout is capped at period", @@ -852,7 +852,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Unset timeout is capped at period", @@ -881,7 +881,7 @@ var planTests = []planTest{{ }, }, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "One of http, tcp, or exec must be present for check", @@ -1002,7 +1002,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Overriding log targets", @@ -1099,7 +1099,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, { Label: "layer-1", Order: 1, @@ -1138,7 +1138,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }}, result: &plan.Layer{ Services: map[string]*plan.Service{ @@ -1184,7 +1184,7 @@ var planTests = []planTest{{ Override: plan.MergeOverride, }, }, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Log target requires type field", @@ -1294,7 +1294,7 @@ var planTests = []planTest{{ }, }, }, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, { Order: 1, Label: "layer-1", @@ -1320,7 +1320,7 @@ var planTests = []planTest{{ }, }, }, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }}, result: &plan.Layer{ Services: map[string]*plan.Service{}, @@ -1348,7 +1348,7 @@ var planTests = []planTest{{ }, }, }, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Reserved log target labels", @@ -1399,7 +1399,7 @@ var planTests = []planTest{{ }, Checks: map[string]*plan.Check{}, LogTargets: map[string]*plan.LogTarget{}, - Sections: map[string]plan.LayerSection{}, + Sections: map[string]plan.Section{}, }, }, { summary: "Three layers missing command",