diff --git a/build/build.go b/build/build.go index d0d8291e4c97..f0952971902a 100644 --- a/build/build.go +++ b/build/build.go @@ -18,6 +18,7 @@ import ( "github.com/containerd/containerd/v2/core/images" "github.com/distribution/reference" + noderesolver "github.com/docker/buildx/build/resolver" "github.com/docker/buildx/builder" "github.com/docker/buildx/driver" "github.com/docker/buildx/util/buildflags" @@ -120,7 +121,7 @@ type NamedContext struct { } type reqForNode struct { - *resolvedNode + *noderesolver.ResolvedNode so *client.SolveOpt } @@ -189,7 +190,7 @@ func warnOnNoOutput(ctx context.Context, nodes []builder.Node, opts map[string]O logrus.Warnf("%s. Build result will only remain in the build cache. To push result image into registry use --push or to load image into docker use --load", warnNoOutputBuf.String()) } -func newBuildRequests(ctx context.Context, docker *dockerutil.Client, cfg *confutil.Config, drivers map[string][]*resolvedNode, w progress.Writer, opts map[string]Options) (_ map[string][]*reqForNode, _ func(), retErr error) { +func newBuildRequests(ctx context.Context, docker *dockerutil.Client, cfg *confutil.Config, drivers map[string][]*noderesolver.ResolvedNode, w progress.Writer, opts map[string]Options) (_ map[string][]*reqForNode, _ func(), retErr error) { reqForNodes := make(map[string][]*reqForNode) var releasers []func() @@ -219,7 +220,7 @@ func newBuildRequests(ctx context.Context, docker *dockerutil.Client, cfg *confu if np.Node().Driver.IsMobyDriver() { hasMobyDriver = true } - opt.Platforms = np.platforms + opt.Platforms = np.Platforms() gatewayOpts, err := np.BuildOpts(ctx) if err != nil { return nil, nil, err @@ -236,7 +237,7 @@ func newBuildRequests(ctx context.Context, docker *dockerutil.Client, cfg *confu } addGitAttrs(so) reqn = append(reqn, &reqForNode{ - resolvedNode: np, + ResolvedNode: np, so: so, }) } @@ -267,7 +268,7 @@ func newBuildRequests(ctx context.Context, docker *dockerutil.Client, cfg *confu return reqForNodes, releaseAll, nil } -func validateTargetLinks(reqForNodes map[string][]*reqForNode, drivers map[string][]*resolvedNode, opts map[string]Options) error { +func validateTargetLinks(reqForNodes map[string][]*reqForNode, drivers map[string][]*noderesolver.ResolvedNode, opts map[string]Options) error { for name := range opts { dps := reqForNodes[name] for i, dp := range dps { @@ -282,7 +283,7 @@ func validateTargetLinks(reqForNodes map[string][]*reqForNode, drivers map[strin var found bool for _, dp2 := range dps2 { - if dp2.driverIndex == dp.driverIndex { + if dp2.Key() == dp.Key() { found = true break } @@ -335,7 +336,11 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opts map[ } warnOnNoOutput(ctx, nodes, opts) - drivers, err := resolveDrivers(ctx, nodes, opts, w) + optPlatforms := make(map[string][]ocispecs.Platform, len(opts)) + for k, opt := range opts { + optPlatforms[k] = opt.Platforms + } + drivers, err := noderesolver.ResolveAll(ctx, nodes, optPlatforms, w) if err != nil { return nil, err } @@ -459,7 +464,7 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opts map[ pw = progress.ResetTime(pw) - if err := waitContextDeps(ctx, dp.driverIndex, results, so); err != nil { + if err := waitContextDeps(ctx, dp, results, so); err != nil { return err } @@ -510,7 +515,7 @@ func BuildWithResultHandler(ctx context.Context, nodes []builder.Node, opts map[ callRes = res.Metadata } - rKey := resultKey(dp.driverIndex, k) + rKey := resultKey(dp, k) results.Set(rKey, res) forceEval := false @@ -897,8 +902,8 @@ func remoteDigestWithMoby(ctx context.Context, d *driver.DriverHandle, name stri return remoteImage.Descriptor.Digest.String(), nil } -func resultKey(index int, name string) string { - return fmt.Sprintf("%d-%s", index, name) +func resultKey(node *noderesolver.ResolvedNode, name string) string { + return fmt.Sprintf("%s-%s", node.Key(), name) } // detectSharedMounts looks for same local mounts used by multiple requests to the same node @@ -916,7 +921,7 @@ func detectSharedMounts(ctx context.Context, reqs map[string][]*reqForNode) (_ m m := map[string]map[fsKey]*fsTracker{} for _, reqs := range reqs { for _, req := range reqs { - nodeName := req.resolvedNode.Node().Name + nodeName := req.ResolvedNode.Node().Name if _, ok := m[nodeName]; !ok { m[nodeName] = map[fsKey]*fsTracker{} } @@ -1028,8 +1033,8 @@ func calculateChildTargets(reqs map[string][]*reqForNode, opt map[string]Options so := reqs[name][i].so for k, v := range so.FrontendAttrs { if strings.HasPrefix(k, "context:") && strings.HasPrefix(v, "target:") { - target := resultKey(dp.driverIndex, strings.TrimPrefix(v, "target:")) - out[target] = append(out[target], resultKey(dp.driverIndex, name)) + target := resultKey(dp.ResolvedNode, strings.TrimPrefix(v, "target:")) + out[target] = append(out[target], resultKey(dp.ResolvedNode, name)) } } } @@ -1037,11 +1042,11 @@ func calculateChildTargets(reqs map[string][]*reqForNode, opt map[string]Options return out } -func waitContextDeps(ctx context.Context, index int, results *waitmap.Map, so *client.SolveOpt) error { +func waitContextDeps(ctx context.Context, node *noderesolver.ResolvedNode, results *waitmap.Map, so *client.SolveOpt) error { m := map[string][]string{} for k, v := range so.FrontendAttrs { if strings.HasPrefix(k, "context:") && strings.HasPrefix(v, "target:") { - target := resultKey(index, strings.TrimPrefix(v, "target:")) + target := resultKey(node, strings.TrimPrefix(v, "target:")) m[target] = append(m[target], k) } } diff --git a/build/dial.go b/build/dial.go index 571d3c1e5d39..bb0025cdcb4a 100644 --- a/build/dial.go +++ b/build/dial.go @@ -7,6 +7,7 @@ import ( "slices" "github.com/containerd/platforms" + "github.com/docker/buildx/build/resolver" "github.com/docker/buildx/builder" "github.com/docker/buildx/util/progress" ocispecs "github.com/opencontainers/image-spec/specs-go/v1" @@ -28,27 +29,25 @@ func Dial(ctx context.Context, nodes []builder.Node, pw progress.Writer, platfor pls = []ocispecs.Platform{*platform} } - opts := map[string]Options{"default": {Platforms: pls}} - resolved, err := resolveDrivers(ctx, nodes, opts, pw) + resolved, err := resolver.Resolve(ctx, nodes, pls, pw) if err != nil { return nil, err } var dialError error - for _, ls := range resolved { - for _, rn := range ls { - if platform != nil { - if !slices.ContainsFunc(rn.platforms, platforms.Only(*platform).Match) { - continue - } + for _, rnode := range resolved { + if platform != nil { + if !slices.ContainsFunc(rnode.Platforms(), platforms.Only(*platform).Match) { + continue } + } - conn, err := nodes[rn.driverIndex].Driver.Dial(ctx) - if err == nil { - return conn, nil - } - dialError = stderrors.Join(err) + driver := rnode.Node().Driver + conn, err := driver.Dial(ctx) + if err == nil { + return conn, nil } + dialError = stderrors.Join(err) } return nil, errors.Wrap(dialError, "no nodes available") diff --git a/build/driver.go b/build/resolver/driver.go similarity index 81% rename from build/driver.go rename to build/resolver/driver.go index 6eab6db8d13b..24d4eda34d9a 100644 --- a/build/driver.go +++ b/build/resolver/driver.go @@ -1,9 +1,10 @@ -package build +package resolver import ( "context" "fmt" "slices" + "strconv" "sync" "github.com/containerd/platforms" @@ -20,17 +21,42 @@ import ( "golang.org/x/sync/errgroup" ) -type resolvedNode struct { +func Resolve(ctx context.Context, nodes []builder.Node, platforms []ocispecs.Platform, pw progress.Writer) ([]*ResolvedNode, error) { + result, err := ResolveAll(ctx, nodes, map[string][]ocispecs.Platform{"default": platforms}, pw) + if err != nil { + return nil, err + } + return result["default"], nil +} + +func ResolveAll(ctx context.Context, nodes []builder.Node, optPlatforms map[string][]ocispecs.Platform, pw progress.Writer) (map[string][]*ResolvedNode, error) { + driverRes := newDriverResolver(nodes) + drivers, err := driverRes.Resolve(ctx, optPlatforms, pw) + if err != nil { + return nil, err + } + return drivers, err +} + +type ResolvedNode struct { resolver *nodeResolver driverIndex int platforms []ocispecs.Platform } -func (dp resolvedNode) Node() builder.Node { +func (dp ResolvedNode) Key() string { + return strconv.Itoa(dp.driverIndex) +} + +func (dp ResolvedNode) Node() builder.Node { return dp.resolver.nodes[dp.driverIndex] } -func (dp resolvedNode) Client(ctx context.Context) (*client.Client, error) { +func (dp ResolvedNode) Platforms() []ocispecs.Platform { + return dp.platforms +} + +func (dp ResolvedNode) Client(ctx context.Context) (*client.Client, error) { clients, err := dp.resolver.boot(ctx, []int{dp.driverIndex}, nil) if err != nil { return nil, err @@ -38,7 +64,7 @@ func (dp resolvedNode) Client(ctx context.Context) (*client.Client, error) { return clients[0], nil } -func (dp resolvedNode) BuildOpts(ctx context.Context) (gateway.BuildOpts, error) { +func (dp ResolvedNode) BuildOpts(ctx context.Context) (gateway.BuildOpts, error) { opts, err := dp.resolver.opts(ctx, []int{dp.driverIndex}, nil) if err != nil { return gateway.BuildOpts{}, err @@ -66,15 +92,6 @@ type nodeResolver struct { buildOpts cachedGroup[gateway.BuildOpts] } -func resolveDrivers(ctx context.Context, nodes []builder.Node, opt map[string]Options, pw progress.Writer) (map[string][]*resolvedNode, error) { - driverRes := newDriverResolver(nodes) - drivers, err := driverRes.Resolve(ctx, opt, pw) - if err != nil { - return nil, err - } - return drivers, err -} - func newDriverResolver(nodes []builder.Node) *nodeResolver { r := &nodeResolver{ nodes: nodes, @@ -84,14 +101,14 @@ func newDriverResolver(nodes []builder.Node) *nodeResolver { return r } -func (r *nodeResolver) Resolve(ctx context.Context, opt map[string]Options, pw progress.Writer) (map[string][]*resolvedNode, error) { +func (r *nodeResolver) Resolve(ctx context.Context, optPlatforms map[string][]ocispecs.Platform, pw progress.Writer) (map[string][]*ResolvedNode, error) { if len(r.nodes) == 0 { return nil, nil } - nodes := map[string][]*resolvedNode{} - for k, opt := range opt { - node, perfect, err := r.resolve(ctx, opt.Platforms, pw, platforms.OnlyStrict, nil) + nodes := map[string][]*ResolvedNode{} + for k, optPlatforms := range optPlatforms { + node, perfect, err := r.resolve(ctx, optPlatforms, pw, platforms.OnlyStrict, nil) if err != nil { return nil, err } @@ -100,7 +117,7 @@ func (r *nodeResolver) Resolve(ctx context.Context, opt map[string]Options, pw p } nodes[k] = node } - if len(nodes) != len(opt) { + if len(nodes) != len(optPlatforms) { // if we didn't get a perfect match, we need to boot all drivers allIndexes := make([]int, len(r.nodes)) for i := range allIndexes { @@ -143,9 +160,9 @@ func (r *nodeResolver) Resolve(ctx context.Context, opt map[string]Options, pw p // then we can attempt to match against all the available platforms // (this time we don't care about imperfect matches) - nodes = map[string][]*resolvedNode{} - for k, opt := range opt { - node, _, err := r.resolve(ctx, opt.Platforms, pw, platforms.Only, func(idx int, n builder.Node) []ocispecs.Platform { + nodes = map[string][]*ResolvedNode{} + for k, optPlatforms := range optPlatforms { + node, _, err := r.resolve(ctx, optPlatforms, pw, platforms.Only, func(idx int, n builder.Node) []ocispecs.Platform { return workers[idx] }) if err != nil { @@ -173,7 +190,7 @@ func (r *nodeResolver) Resolve(ctx context.Context, opt map[string]Options, pw p return nodes, nil } -func (r *nodeResolver) resolve(ctx context.Context, ps []ocispecs.Platform, pw progress.Writer, matcher matchMaker, additional func(idx int, n builder.Node) []ocispecs.Platform) ([]*resolvedNode, bool, error) { +func (r *nodeResolver) resolve(ctx context.Context, ps []ocispecs.Platform, pw progress.Writer, matcher matchMaker, additional func(idx int, n builder.Node) []ocispecs.Platform) ([]*ResolvedNode, bool, error) { if len(r.nodes) == 0 { return nil, true, nil } @@ -189,16 +206,16 @@ func (r *nodeResolver) resolve(ctx context.Context, ps []ocispecs.Platform, pw p nodeIdxs = append(nodeIdxs, idx) } - var nodes []*resolvedNode + var nodes []*ResolvedNode if len(nodeIdxs) == 0 { - nodes = append(nodes, &resolvedNode{ + nodes = append(nodes, &ResolvedNode{ resolver: r, driverIndex: 0, }) nodeIdxs = append(nodeIdxs, 0) } else { for i, idx := range nodeIdxs { - node := &resolvedNode{ + node := &ResolvedNode{ resolver: r, driverIndex: idx, } @@ -338,8 +355,8 @@ func (r *nodeResolver) opts(ctx context.Context, idxs []int, pw progress.Writer) // recombineDriverPairs recombines resolved nodes that are on the same driver // back together into a single node. -func recombineNodes(nodes []*resolvedNode) []*resolvedNode { - result := make([]*resolvedNode, 0, len(nodes)) +func recombineNodes(nodes []*ResolvedNode) []*ResolvedNode { + result := make([]*ResolvedNode, 0, len(nodes)) lookup := map[int]int{} for _, node := range nodes { if idx, ok := lookup[node.driverIndex]; ok { diff --git a/build/driver_test.go b/build/resolver/driver_test.go similarity index 99% rename from build/driver_test.go rename to build/resolver/driver_test.go index 1fbfeb42310d..f5d365764a81 100644 --- a/build/driver_test.go +++ b/build/resolver/driver_test.go @@ -1,4 +1,4 @@ -package build +package resolver import ( "context" @@ -22,7 +22,7 @@ func TestFindDriverSanity(t *testing.T) { require.Len(t, res, 1) require.Equal(t, 0, res[0].driverIndex) require.Equal(t, "aaa", res[0].Node().Builder) - require.Equal(t, []ocispecs.Platform{platforms.DefaultSpec()}, res[0].platforms) + require.Equal(t, []ocispecs.Platform{platforms.DefaultSpec()}, res[0].Platforms()) } func TestFindDriverEmpty(t *testing.T) { @@ -228,7 +228,7 @@ func TestSelectNodeNoPlatform(t *testing.T) { require.True(t, perfect) require.Len(t, res, 1) require.Equal(t, "aaa", res[0].Node().Builder) - require.Empty(t, res[0].platforms) + require.Empty(t, res[0].Platforms()) } func TestSelectNodeAdditionalPlatforms(t *testing.T) {