From 30c43c0d50fc9254014dff2308b8faf147d34453 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Wed, 23 Feb 2022 18:14:44 +0200 Subject: [PATCH] sql/mysql: mysql implementation for schema.Normalizer --- sql/mysql/normalize.go | 101 ++++++++++++++++++++++++++++++++++++ sql/mysql/normalize_test.go | 62 ++++++++++++++++++++++ sql/schema/inspect.go | 12 +++++ 3 files changed, 175 insertions(+) create mode 100644 sql/mysql/normalize.go create mode 100644 sql/mysql/normalize_test.go diff --git a/sql/mysql/normalize.go b/sql/mysql/normalize.go new file mode 100644 index 00000000000..9b70e748f3e --- /dev/null +++ b/sql/mysql/normalize.go @@ -0,0 +1,101 @@ +// Copyright 2021-present The Atlas Authors. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package mysql + +import ( + "context" + "crypto/md5" + "fmt" + "time" + + "ariga.io/atlas/sql/schema" +) + +// NormalizeRealm returns the normal representation of the given database. +func (d *Driver) NormalizeRealm(ctx context.Context, r *schema.Realm) (nr *schema.Realm, err error) { + for _, s := range r.Schemas { + switch s.Name { + case "mysql", "information_schema", "performance_schema", "sys": + return nil, fmt.Errorf("sql/mysql: normalizing internal schema %q is not supported", s.Name) + } + } + var ( + twins = make(map[string]string) + changes = make([]schema.Change, 0, len(r.Schemas)) + reverse = make([]schema.Change, 0, len(r.Schemas)) + opts = &schema.InspectRealmOption{ + Schemas: make([]string, 0, len(r.Schemas)), + } + ) + for _, s := range r.Schemas { + twin := twinName(s.Name) + twins[twin] = s.Name + s.Name = twin + opts.Schemas = append(opts.Schemas, s.Name) + // Skip adding the schema.IfNotExists clause + // to fail if the schema exists. + st := schema.New(twin).AddAttrs(s.Attrs...) + changes = append(changes, &schema.AddSchema{S: st}) + reverse = append(reverse, &schema.DropSchema{S: st, Extra: []schema.Clause{&schema.IfExists{}}}) + for _, t := range s.Tables { + // If objects are not strongly connected. + if t.Schema != s { + t.Schema = s + } + changes = append(changes, &schema.AddTable{T: t}) + } + } + patch := func(r *schema.Realm) { + for _, s := range r.Schemas { + s.Name = twins[s.Name] + } + } + // Delete the twin resources, and return + // the source realm to its initial state. + defer func() { + patch(r) + uerr := d.ApplyChanges(ctx, reverse) + if err != nil { + err = fmt.Errorf("%w: %v", err, uerr) + } + err = uerr + }() + if err := d.ApplyChanges(ctx, changes); err != nil { + return nil, err + } + if nr, err = d.InspectRealm(ctx, opts); err != nil { + return nil, err + } + patch(nr) + return nr, nil +} + +// NormalizeSchema returns the normal representation of the given database. +func (d *Driver) NormalizeSchema(ctx context.Context, s *schema.Schema) (*schema.Schema, error) { + r := &schema.Realm{} + if s.Realm != nil { + r.Attrs = s.Realm.Attrs + } + r.Schemas = append(r.Schemas, s) + nr, err := d.NormalizeRealm(ctx, r) + if err != nil { + return nil, err + } + ns, ok := nr.Schema(s.Name) + if !ok { + return nil, fmt.Errorf("sql/mysql: missing normalized schema %q", s.Name) + } + return ns, nil +} + +const maxLen = 64 + +func twinName(name string) string { + twin := fmt.Sprintf("atlas_twin_%s_%d", name, time.Now().Unix()) + if len(twin) <= maxLen { + return twin + } + return fmt.Sprintf("%s_%x", twin[:maxLen-33], md5.Sum([]byte(twin))) +} diff --git a/sql/mysql/normalize_test.go b/sql/mysql/normalize_test.go new file mode 100644 index 00000000000..24aae6a45de --- /dev/null +++ b/sql/mysql/normalize_test.go @@ -0,0 +1,62 @@ +// Copyright 2021-present The Atlas Authors. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package mysql + +import ( + "context" + "strings" + "testing" + + "ariga.io/atlas/sql/migrate" + "ariga.io/atlas/sql/schema" + + "github.com/stretchr/testify/require" +) + +func TestDriver_NormalizeRealm(t *testing.T) { + var ( + apply = &mockApply{} + inspect = &mockInspect{ + realm: schema.NewRealm(schema.New("test").SetCharset("utf8mb4")), + } + drv = &Driver{ + Inspector: inspect, + PlanApplier: apply, + } + ) + normal, err := drv.NormalizeRealm(context.Background(), schema.NewRealm(schema.New("test"))) + require.NoError(t, err) + require.Equal(t, normal, inspect.realm) + + require.Len(t, inspect.schemas, 1) + require.True(t, strings.HasPrefix(inspect.schemas[0], "atlas_twin_test_")) + + require.Len(t, apply.changes, 2, "expect 2 calls (create and drop)") + require.Len(t, apply.changes[0], 1) + require.Equal(t, &schema.AddSchema{S: schema.New(inspect.schemas[0])}, apply.changes[0][0]) + require.Len(t, apply.changes[1], 1) + require.Equal(t, &schema.DropSchema{S: schema.New(inspect.schemas[0]), Extra: []schema.Clause{&schema.IfExists{}}}, apply.changes[1][0]) +} + +type mockInspect struct { + schema.Inspector + schemas []string + realm *schema.Realm +} + +func (m *mockInspect) InspectRealm(_ context.Context, opts *schema.InspectRealmOption) (*schema.Realm, error) { + m.schemas = append(m.schemas, opts.Schemas...) + return m.realm, nil +} + +type mockApply struct { + migrate.PlanApplier + changes [][]schema.Change +} + +func (m *mockApply) ApplyChanges(_ context.Context, changes []schema.Change) error { + m.changes = append(m.changes, changes) + return nil +} diff --git a/sql/schema/inspect.go b/sql/schema/inspect.go index 2ad0deefa60..930cb6b7505 100644 --- a/sql/schema/inspect.go +++ b/sql/schema/inspect.go @@ -64,3 +64,15 @@ type ( InspectRealm(ctx context.Context, opts *InspectRealmOption) (*Realm, error) } ) + +// Normalizer is the interface implemented by the different database drivers for +// "normalizing" schema objects. i.e. converting schema objects defined in natural +// form to their representation in the database. Thus, two schema objects are equal +// if their normal forms are equal. +type Normalizer interface { + // NormalizeSchema returns the normal representation of a schema. + NormalizeSchema(context.Context, *Schema) (*Schema, error) + + // NormalizeRealm returns the normal representation of a database. + NormalizeRealm(context.Context, *Realm) (*Realm, error) +}