Skip to content

Commit

Permalink
Review feedback group 4
Browse files Browse the repository at this point in the history
  • Loading branch information
flotter committed Aug 28, 2024
1 parent 854e79b commit 814d9ea
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 75 deletions.
87 changes: 49 additions & 38 deletions internals/plan/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type planResult struct {

type extension struct {
field string
ext plan.LayerSectionExtension
ext plan.SectionExtension
}

var extensionTests = []struct {
Expand Down Expand Up @@ -349,16 +349,30 @@ var extensionTests = []struct {
}}

func (s *S) TestPlanExtensions(c *C) {
registeredExtensions := []string{}
defer func() {
// Remove remaining registered extensions.
for _, field := range registeredExtensions {
plan.UnregisterExtension(field)
}
}()

nexttest:
for testIndex, testData := range extensionTests {
c.Logf("TestPlanExtensions :: %s (data index %v)", testData.summary, testIndex)

// Unregister extensions from previous test iteraton.
for _, field := range registeredExtensions {
plan.UnregisterExtension(field)
}
registeredExtensions = []string{}

// Write layers to test directory.
layersDir := filepath.Join(c.MkDir(), "layers")
s.writeLayerFiles(c, layersDir, testData.layers)
var p *plan.Plan

// Register test extensions.
// Register extensions for this test iteration.
for _, e := range testData.extensions {
err := func() (err error) {
defer func() {
Expand All @@ -367,6 +381,7 @@ nexttest:
}
}()
plan.RegisterExtension(e.field, e.ext)
registeredExtensions = append(registeredExtensions, e.field)
return nil
}()
if err != nil {
Expand All @@ -380,37 +395,33 @@ nexttest:
if testData.error != "" || err != nil {
// Expected error.
c.Assert(err, ErrorMatches, testData.error)
} else {
if slices.ContainsFunc(testData.extensions, func(n extension) bool {
return n.field == xField
}) {
// Verify "x-field" data.
var x *xSection
x = p.Section(xField).(*xSection)
c.Assert(err, IsNil)
c.Assert(x.Entries, DeepEquals, testData.result.x.Entries)
}

if slices.ContainsFunc(testData.extensions, func(n extension) bool {
return n.field == yField
}) {
// Verify "y-field" data.
var y *ySection
y = p.Section(yField).(*ySection)
c.Assert(err, IsNil)
c.Assert(y.Entries, DeepEquals, testData.result.y.Entries)
}
continue nexttest
}

// Verify combined plan YAML.
planYAML, err := yaml.Marshal(p)
if slices.ContainsFunc(testData.extensions, func(n extension) bool {
return n.field == xField
}) {
// Verify "x-field" data.
var x *xSection
x = p.Sections[xField].(*xSection)
c.Assert(err, IsNil)
c.Assert(string(planYAML), Equals, testData.resultYaml)
c.Assert(x.Entries, DeepEquals, testData.result.x.Entries)
}

// Unregister test extensions.
for _, e := range testData.extensions {
plan.UnregisterExtension(e.field)
if slices.ContainsFunc(testData.extensions, func(n extension) bool {
return n.field == yField
}) {
// Verify "y-field" data.
var y *ySection
y = p.Sections[yField].(*ySection)
c.Assert(err, IsNil)
c.Assert(y.Entries, DeepEquals, testData.result.y.Entries)
}

// Verify combined plan YAML.
planYAML, err := yaml.Marshal(p)
c.Assert(err, IsNil)
c.Assert(string(planYAML), Equals, testData.resultYaml)
}
}

Expand Down Expand Up @@ -506,10 +517,10 @@ func (s *S) writeLayerFiles(c *C, layersDir string, inputs []*inputLayer) {

const xField string = "x-field"

// xExtension implements the LayerSectionExtension interface.
// xExtension implements the SectionExtension interface.
type xExtension struct{}

func (x xExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) {
func (x xExtension) ParseSection(data yaml.Node) (plan.Section, error) {
xs := &xSection{}
err := data.Decode(xs)
if err != nil {
Expand All @@ -524,7 +535,7 @@ func (x xExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) {
return xs, nil
}

func (x xExtension) CombineSections(sections ...plan.LayerSection) (plan.LayerSection, error) {
func (x xExtension) CombineSections(sections ...plan.Section) (plan.Section, error) {
xs := &xSection{}
for _, section := range sections {
err := xs.Combine(section)
Expand All @@ -537,10 +548,10 @@ func (x xExtension) CombineSections(sections ...plan.LayerSection) (plan.LayerSe

func (x xExtension) ValidatePlan(p *plan.Plan) error {
var xs *xSection
xs = p.Section(xField).(*xSection)
xs = p.Sections[xField].(*xSection)
if xs != nil {
var ys *ySection
ys = p.Section(yField).(*ySection)
ys = p.Sections[yField].(*ySection)

// Test dependency: Make sure every Y field in X refer to an existing Y entry.
for xEntryField, xEntryValue := range xs.Entries {
Expand Down Expand Up @@ -583,7 +594,7 @@ func (xs *xSection) IsZero() bool {
return xs.Entries == nil
}

func (xs *xSection) Combine(other plan.LayerSection) error {
func (xs *xSection) Combine(other plan.Section) error {
otherxSection, ok := other.(*xSection)
if !ok {
return fmt.Errorf("cannot combine incompatible section type")
Expand Down Expand Up @@ -645,10 +656,10 @@ func (x *X) Merge(other *X) {

const yField string = "y-field"

// yExtension implements the LayerSectionExtension interface.
// yExtension implements the SectionExtension interface.
type yExtension struct{}

func (y yExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) {
func (y yExtension) ParseSection(data yaml.Node) (plan.Section, error) {
ys := &ySection{}
err := data.Decode(ys)
if err != nil {
Expand All @@ -663,7 +674,7 @@ func (y yExtension) ParseSection(data yaml.Node) (plan.LayerSection, error) {
return ys, nil
}

func (y yExtension) CombineSections(sections ...plan.LayerSection) (plan.LayerSection, error) {
func (y yExtension) CombineSections(sections ...plan.Section) (plan.Section, error) {
ys := &ySection{}
for _, section := range sections {
err := ys.Combine(section)
Expand Down Expand Up @@ -701,7 +712,7 @@ func (ys *ySection) IsZero() bool {
return ys.Entries == nil
}

func (ys *ySection) Combine(other plan.LayerSection) error {
func (ys *ySection) Combine(other plan.Section) error {
otherySection, ok := other.(*ySection)
if !ok {
return fmt.Errorf("cannot combine incompatible section type")
Expand Down
31 changes: 12 additions & 19 deletions internals/plan/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,24 @@ import (
"github.com/canonical/pebble/internals/osutil"
)

// LayerSectionExtension allows the plan layer schema to be extended without
// SectionExtension allows the plan layer schema to be extended without
// adding centralised schema knowledge to the plan library.
type LayerSectionExtension interface {
type SectionExtension interface {
// ParseSection returns a newly allocated concrete type containing the
// unmarshalled section content.
ParseSection(data yaml.Node) (LayerSection, error)
ParseSection(data yaml.Node) (Section, error)

// CombineSections returns a newly allocated concrete type containing the
// result of combining the supplied sections in order.
CombineSections(sections ...LayerSection) (LayerSection, error)
CombineSections(sections ...Section) (Section, error)

// ValidatePlan takes the complete plan as input, and allows the
// extension to validate the plan. This can be used for cross section
// dependency validation.
ValidatePlan(plan *Plan) error
}

type LayerSection interface {
type Section interface {
// Validate checks whether the section is valid, returning an error if not.
Validate() error

Expand All @@ -70,7 +70,7 @@ const (

var (
// layerExtensions keeps a map of registered extensions.
layerExtensions = map[string]LayerSectionExtension{}
layerExtensions = map[string]SectionExtension{}

// layerExtensionsOrder records the order in which the extensions were registered.
layerExtensionsOrder = []string{}
Expand All @@ -85,7 +85,7 @@ var layerBuiltins = []string{"summary", "description", "services", "checks", "lo
// done before the plan library is used. The order in which extensions are
// registered determines the order in which the sections are marshalled.
// Extension sections are marshalled after the built-in sections.
func RegisterExtension(field string, ext LayerSectionExtension) {
func RegisterExtension(field string, ext SectionExtension) {
if slices.Contains(layerBuiltins, field) {
panic(fmt.Sprintf("internal error: extension %q already used as built-in field", field))
}
Expand All @@ -111,14 +111,7 @@ type Plan struct {
Checks map[string]*Check `yaml:"checks,omitempty"`
LogTargets map[string]*LogTarget `yaml:"log-targets,omitempty"`

Sections map[string]LayerSection `yaml:",inline"`
}

// Section retrieves a section from the plan. If Section is called
// before the plan is loaded, or with an unregistered field, this method
// will return nil.
func (p *Plan) Section(field string) LayerSection {
return p.Sections[field]
Sections map[string]Section `yaml:",inline"`
}

// MarshalYAML implements an override for top level omitempty tags handling.
Expand Down Expand Up @@ -170,7 +163,7 @@ type Layer struct {
Checks map[string]*Check `yaml:"checks,omitempty"`
LogTargets map[string]*LogTarget `yaml:"log-targets,omitempty"`

Sections map[string]LayerSection `yaml:",inline"`
Sections map[string]Section `yaml:",inline"`
}

type Service struct {
Expand Down Expand Up @@ -673,15 +666,15 @@ func CombineLayers(layers ...*Layer) (*Layer, error) {
Services: make(map[string]*Service),
Checks: make(map[string]*Check),
LogTargets: make(map[string]*LogTarget),
Sections: make(map[string]LayerSection),
Sections: make(map[string]Section),
}

// Combine the same sections from each layer. Note that we do this before
// the layers length check because we need the extension to provide us with
// a zero value section, even if no layers are supplied (similar to the
// allocations taking place above for the built-in types).
for field, extension := range layerExtensions {
var sections []LayerSection
var sections []Section
for _, layer := range layers {
if section := layer.Sections[field]; section != nil {
sections = append(sections, section)
Expand Down Expand Up @@ -1241,7 +1234,7 @@ func ParseLayer(order int, label string, data []byte) (*Layer, error) {
Services: make(map[string]*Service),
Checks: make(map[string]*Check),
LogTargets: make(map[string]*LogTarget),
Sections: make(map[string]LayerSection),
Sections: make(map[string]Section),
}

// The following manual approach is required because:
Expand Down
Loading

0 comments on commit 814d9ea

Please sign in to comment.