diff --git a/database/migrations/000111_data_sources_subscription_id.down.sql b/database/migrations/000111_data_sources_subscription_id.down.sql new file mode 100644 index 0000000000..eab1995143 --- /dev/null +++ b/database/migrations/000111_data_sources_subscription_id.down.sql @@ -0,0 +1,8 @@ +-- SPDX-FileCopyrightText: Copyright 2025 The Minder Authors +-- SPDX-License-Identifier: Apache-2.0 + +BEGIN; + +ALTER TABLE rule_type DROP COLUMN subscription_id; + +COMMIT; diff --git a/database/migrations/000111_data_sources_subscription_id.up.sql b/database/migrations/000111_data_sources_subscription_id.up.sql new file mode 100644 index 0000000000..e88c452290 --- /dev/null +++ b/database/migrations/000111_data_sources_subscription_id.up.sql @@ -0,0 +1,10 @@ +-- SPDX-FileCopyrightText: Copyright 2025 The Minder Authors +-- SPDX-License-Identifier: Apache-2.0 + +BEGIN; + +ALTER TABLE data_sources + ADD COLUMN subscription_id UUID DEFAULT NULL + REFERENCES subscriptions(id); + +COMMIT; \ No newline at end of file diff --git a/database/query/datasources.sql b/database/query/datasources.sql index c03c85dfbc..c1480faaa2 100644 --- a/database/query/datasources.sql +++ b/database/query/datasources.sql @@ -1,8 +1,8 @@ -- CreateDataSource creates a new datasource in a given project. -- name: CreateDataSource :one -INSERT INTO data_sources (project_id, name, display_name) -VALUES ($1, $2, $3) RETURNING *; +INSERT INTO data_sources (project_id, name, display_name, subscription_id) +VALUES ($1, $2, $3, sqlc.narg(subscription_id)) RETURNING *; -- AddDataSourceFunction adds a function to a datasource. diff --git a/internal/controlplane/handlers_datasource.go b/internal/controlplane/handlers_datasource.go index 5a1837abf0..d7d65bc87c 100644 --- a/internal/controlplane/handlers_datasource.go +++ b/internal/controlplane/handlers_datasource.go @@ -36,7 +36,7 @@ func (s *Server) CreateDataSource(ctx context.Context, } // Process the request - ret, err := s.dataSourcesService.Create(ctx, dsReq, nil) + ret, err := s.dataSourcesService.Create(ctx, uuid.Nil, dsReq, nil) if err != nil { return nil, err } @@ -167,7 +167,7 @@ func (s *Server) UpdateDataSource(ctx context.Context, } // Process the request - ret, err := s.dataSourcesService.Update(ctx, dsReq, nil) + ret, err := s.dataSourcesService.Update(ctx, uuid.Nil, dsReq, nil) if err != nil { return nil, err } diff --git a/internal/controlplane/handlers_datasource_test.go b/internal/controlplane/handlers_datasource_test.go index 1593da1f93..9fcc7cccd0 100644 --- a/internal/controlplane/handlers_datasource_test.go +++ b/internal/controlplane/handlers_datasource_test.go @@ -37,7 +37,7 @@ func TestCreateDataSource(t *testing.T) { setupMocks: func(dsService *mock_service.MockDataSourcesService, featureClient *flags.FakeClient) { featureClient.Data = map[string]interface{}{"data_sources": true} dsService.EXPECT(). - Create(gomock.Any(), gomock.Any(), gomock.Any()). + Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(&minderv1.DataSource{Name: "test-ds"}, nil) }, request: &minderv1.CreateDataSourceRequest{ @@ -417,7 +417,7 @@ func TestUpdateDataSource(t *testing.T) { setupMocks: func(dsService *mock_service.MockDataSourcesService, featureClient *flags.FakeClient, _ *mockdb.MockStore) { featureClient.Data = map[string]interface{}{"data_sources": true} dsService.EXPECT(). - Update(gomock.Any(), gomock.Any(), gomock.Any()). + Update(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(&minderv1.DataSource{Id: dsIDStr, Name: "updated-ds"}, nil) }, request: &minderv1.UpdateDataSourceRequest{ diff --git a/internal/datasources/service/helpers.go b/internal/datasources/service/helpers.go index 58111e86e5..e6c7156eee 100644 --- a/internal/datasources/service/helpers.go +++ b/internal/datasources/service/helpers.go @@ -5,6 +5,7 @@ package service import ( "context" + "database/sql" "errors" "fmt" @@ -17,6 +18,39 @@ import ( minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1" ) +var ( + getByNameQuery = func(ctx context.Context, tx db.ExtendQuerier, projs []uuid.UUID, name string) (db.DataSource, error) { + ds, err := tx.GetDataSourceByName(ctx, db.GetDataSourceByNameParams{ + Name: name, + Projects: projs, + }) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return db.DataSource{}, util.UserVisibleError(codes.NotFound, + "data source of name %s not found", name) + } + return db.DataSource{}, fmt.Errorf("failed to get data source by name: %w", err) + } + + return ds, nil + } + getByIDQuery = func(ctx context.Context, tx db.ExtendQuerier, projs []uuid.UUID, id uuid.UUID) (db.DataSource, error) { + ds, err := tx.GetDataSource(ctx, db.GetDataSourceParams{ + ID: id, + Projects: projs, + }) + if errors.Is(err, sql.ErrNoRows) { + return db.DataSource{}, util.UserVisibleError(codes.NotFound, + "data source of id %s not found", id.String()) + } + if err != nil { + return db.DataSource{}, fmt.Errorf("failed to get data source by name: %w", err) + } + + return ds, nil + } +) + func (d *dataSourceService) getDataSourceSomehow( ctx context.Context, project uuid.UUID, @@ -33,11 +67,35 @@ func (d *dataSourceService) getDataSourceSomehow( tx := stx.Q() + ds, err := getDataSourceFromDb(ctx, project, opts, tx, theSomehow) + if err != nil { + return nil, fmt.Errorf("failed to get data source from DB: %w", err) + } + + dsfuncs, err := getDataSourceFunctions(ctx, tx, ds) + if err != nil { + return nil, fmt.Errorf("failed to get data source functions: %w", err) + } + + if err := stx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit transaction: %w", err) + } + + return dataSourceDBToProtobuf(*ds, dsfuncs) +} + +func getDataSourceFromDb( + ctx context.Context, + project uuid.UUID, + opts *ReadOptions, + qtx db.ExtendQuerier, + dbQuery func(ctx context.Context, qtx db.ExtendQuerier, projs []uuid.UUID) (db.DataSource, error), +) (*db.DataSource, error) { var projs []uuid.UUID if len(opts.hierarchy) > 0 { projs = opts.hierarchy } else { - prjs, err := listRelevantProjects(ctx, tx, project, opts.canSearchHierarchical()) + prjs, err := listRelevantProjects(ctx, qtx, project, opts.canSearchHierarchical()) if err != nil { return nil, fmt.Errorf("failed to list relevant projects: %w", err) } @@ -45,11 +103,19 @@ func (d *dataSourceService) getDataSourceSomehow( projs = prjs } - ds, err := theSomehow(ctx, tx, projs) + ds, err := dbQuery(ctx, qtx, projs) if err != nil { - return nil, fmt.Errorf("failed to get data source by name: %w", err) + return nil, fmt.Errorf("failed to get data source from DB: %w", err) } + return &ds, nil +} + +func getDataSourceFunctions( + ctx context.Context, + tx db.ExtendQuerier, + ds *db.DataSource, +) ([]db.DataSourcesFunction, error) { dsfuncs, err := tx.ListDataSourceFunctions(ctx, db.ListDataSourceFunctionsParams{ DataSourceID: ds.ID, ProjectID: ds.ProjectID, @@ -66,11 +132,7 @@ func (d *dataSourceService) getDataSourceSomehow( return nil, errors.New("data source has no functions") } - if err := stx.Commit(); err != nil { - return nil, fmt.Errorf("failed to commit transaction: %w", err) - } - - return dataSourceDBToProtobuf(ds, dsfuncs) + return dsfuncs, nil } func (d *dataSourceService) instantiateDataSource( @@ -110,9 +172,15 @@ func listRelevantProjects( } func validateDataSourceFunctionsUpdate( - existingDS, newDS *minderv1.DataSource, + existingDS *db.DataSource, existingFunctions []db.DataSourcesFunction, newDS *minderv1.DataSource, ) error { - existingImpl, err := datasources.BuildFromProtobuf(existingDS) + existingDsProto, err := dataSourceDBToProtobuf(*existingDS, existingFunctions) + if err != nil { + // If we got here, it means the existing data source is invalid. + return fmt.Errorf("failed to convert data source to protobuf: %w", err) + } + + existingImpl, err := datasources.BuildFromProtobuf(existingDsProto) if err != nil { // If we got here, it means the existing data source is invalid. return fmt.Errorf("failed to build data source from protobuf: %w", err) diff --git a/internal/datasources/service/mock/service.go b/internal/datasources/service/mock/service.go index 73e38dacf2..9b8d142531 100644 --- a/internal/datasources/service/mock/service.go +++ b/internal/datasources/service/mock/service.go @@ -60,18 +60,18 @@ func (mr *MockDataSourcesServiceMockRecorder) BuildDataSourceRegistry(ctx, rt, o } // Create mocks base method. -func (m *MockDataSourcesService) Create(ctx context.Context, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { +func (m *MockDataSourcesService) Create(ctx context.Context, subscriptionID uuid.UUID, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Create", ctx, ds, opts) + ret := m.ctrl.Call(m, "Create", ctx, subscriptionID, ds, opts) ret0, _ := ret[0].(*v1.DataSource) ret1, _ := ret[1].(error) return ret0, ret1 } // Create indicates an expected call of Create. -func (mr *MockDataSourcesServiceMockRecorder) Create(ctx, ds, opts any) *gomock.Call { +func (mr *MockDataSourcesServiceMockRecorder) Create(ctx, subscriptionID, ds, opts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockDataSourcesService)(nil).Create), ctx, ds, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockDataSourcesService)(nil).Create), ctx, subscriptionID, ds, opts) } // Delete mocks base method. @@ -134,16 +134,16 @@ func (mr *MockDataSourcesServiceMockRecorder) List(ctx, project, opts any) *gomo } // Update mocks base method. -func (m *MockDataSourcesService) Update(ctx context.Context, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { +func (m *MockDataSourcesService) Update(ctx context.Context, subscriptionID uuid.UUID, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, ds, opts) + ret := m.ctrl.Call(m, "Update", ctx, subscriptionID, ds, opts) ret0, _ := ret[0].(*v1.DataSource) ret1, _ := ret[1].(error) return ret0, ret1 } // Update indicates an expected call of Update. -func (mr *MockDataSourcesServiceMockRecorder) Update(ctx, ds, opts any) *gomock.Call { +func (mr *MockDataSourcesServiceMockRecorder) Update(ctx, subscriptionID, ds, opts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDataSourcesService)(nil).Update), ctx, ds, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDataSourcesService)(nil).Update), ctx, subscriptionID, ds, opts) } diff --git a/internal/datasources/service/service.go b/internal/datasources/service/service.go index 375f0eadc9..53787e5fcd 100644 --- a/internal/datasources/service/service.go +++ b/internal/datasources/service/service.go @@ -17,6 +17,7 @@ import ( "github.com/mindersec/minder/internal/datasources" "github.com/mindersec/minder/internal/db" + "github.com/mindersec/minder/internal/marketplaces/namespaces" "github.com/mindersec/minder/internal/util" minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1" v1datasources "github.com/mindersec/minder/pkg/datasources/v1" @@ -36,10 +37,10 @@ type DataSourcesService interface { List(ctx context.Context, project uuid.UUID, opts *ReadOptions) ([]*minderv1.DataSource, error) // Create creates a new data source. - Create(ctx context.Context, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) + Create(ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) // Update updates an existing data source. - Update(ctx context.Context, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) + Update(ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) // Delete deletes a data source in the given project. // @@ -82,19 +83,7 @@ func (d *dataSourceService) GetByName( return d.getDataSourceSomehow( ctx, project, opts, func(ctx context.Context, tx db.ExtendQuerier, projs []uuid.UUID, ) (db.DataSource, error) { - ds, err := tx.GetDataSourceByName(ctx, db.GetDataSourceByNameParams{ - Name: name, - Projects: projs, - }) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return db.DataSource{}, util.UserVisibleError(codes.NotFound, - "data source of name %s not found", name) - } - return db.DataSource{}, fmt.Errorf("failed to get data source by name: %w", err) - } - - return ds, nil + return getByNameQuery(ctx, tx, projs, name) }) } @@ -103,19 +92,7 @@ func (d *dataSourceService) GetByID( return d.getDataSourceSomehow( ctx, project, opts, func(ctx context.Context, tx db.ExtendQuerier, projs []uuid.UUID, ) (db.DataSource, error) { - ds, err := tx.GetDataSource(ctx, db.GetDataSourceParams{ - ID: id, - Projects: projs, - }) - if errors.Is(err, sql.ErrNoRows) { - return db.DataSource{}, util.UserVisibleError(codes.NotFound, - "data source of id %s not found", id.String()) - } - if err != nil { - return db.DataSource{}, fmt.Errorf("failed to get data source by name: %w", err) - } - - return ds, nil + return getByIDQuery(ctx, tx, projs, id) }) } @@ -173,11 +150,15 @@ func (d *dataSourceService) List( // We first validate the data source name uniqueness, then create the data source record. // Finally, we create function records based on the driver type. func (d *dataSourceService) Create( - ctx context.Context, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) { + ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) { if err := ds.Validate(); err != nil { return nil, fmt.Errorf("data source validation failed: %w", err) } + if err := namespaces.ValidateNamespacedNameRules(ds.GetName(), subscriptionID); err != nil { + return nil, fmt.Errorf("data source validation failed: %w", err) + } + stx, err := d.txBuilder(d, opts) if err != nil { return nil, fmt.Errorf("failed to start transaction: %w", err) @@ -216,9 +197,10 @@ func (d *dataSourceService) Create( // Create data source record dsRecord, err := tx.CreateDataSource(ctx, db.CreateDataSourceParams{ - ProjectID: projectID, - Name: ds.GetName(), - DisplayName: ds.GetName(), + ProjectID: projectID, + Name: ds.GetName(), + DisplayName: ds.GetName(), + SubscriptionID: uuid.NullUUID{UUID: subscriptionID, Valid: subscriptionID != uuid.Nil}, }) if err != nil { return nil, fmt.Errorf("failed to create data source: %w", err) @@ -244,7 +226,7 @@ func (d *dataSourceService) Create( // because it's simpler and safer - it ensures consistency and avoids partial updates. // All functions must use the same driver type to maintain data source integrity. func (d *dataSourceService) Update( - ctx context.Context, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) { + ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) { if err := ds.Validate(); err != nil { return nil, fmt.Errorf("data source validation failed: %w", err) } @@ -268,24 +250,29 @@ func (d *dataSourceService) Update( return nil, fmt.Errorf("invalid project ID: %w", err) } - // Build existing data source - existingDS, err := d.GetByName(ctx, ds.GetName(), projectID, ReadBuilder().WithTransaction(tx)) + // Validate the subscription ID if present + existingDS, err := getDataSourceFromDb(ctx, projectID, ReadBuilder().WithTransaction(tx), tx, + func(ctx context.Context, tx db.ExtendQuerier, projs []uuid.UUID) (db.DataSource, error) { + return getByNameQuery(ctx, tx, projs, ds.GetName()) + }) if err != nil { - return nil, fmt.Errorf("failed to get existing data source: %w", err) + return nil, fmt.Errorf("failed to get existing data source from DB: %w", err) + } + if err = namespaces.DoesSubscriptionIDMatch(subscriptionID, existingDS.SubscriptionID); err != nil { + return nil, fmt.Errorf("failed to update data source: %w", err) } - existingDSID, err := uuid.Parse(existingDS.Id) + // Validate the data source functions update + existingFunctions, err := getDataSourceFunctions(ctx, tx, existingDS) if err != nil { - // This should not happen - return nil, fmt.Errorf("invalid data source ID: %w", err) + return nil, fmt.Errorf("failed to get existing data source functions: %w", err) } - - if err := validateDataSourceFunctionsUpdate(existingDS, ds); err != nil { + if err := validateDataSourceFunctionsUpdate(existingDS, existingFunctions, ds); err != nil { return nil, err } if _, err := tx.UpdateDataSource(ctx, db.UpdateDataSourceParams{ - ID: existingDSID, + ID: existingDS.ID, ProjectID: projectID, DisplayName: ds.GetName(), }); err != nil { @@ -293,13 +280,13 @@ func (d *dataSourceService) Update( } if _, err := tx.DeleteDataSourceFunctions(ctx, db.DeleteDataSourceFunctionsParams{ - DataSourceID: existingDSID, + DataSourceID: existingDS.ID, ProjectID: projectID, }); err != nil { return nil, fmt.Errorf("failed to delete existing functions: %w", err) } - if err := addDataSourceFunctions(ctx, tx, ds, existingDSID, projectID); err != nil { + if err := addDataSourceFunctions(ctx, tx, ds, existingDS.ID, projectID); err != nil { return nil, fmt.Errorf("failed to create data source functions: %w", err) } @@ -308,7 +295,7 @@ func (d *dataSourceService) Update( } if ds.Id == "" { - ds.Id = existingDSID.String() + ds.Id = existingDS.ID.String() } return ds, nil @@ -348,6 +335,19 @@ func (d *dataSourceService) Delete( "data source %s is in use by the following rule types: %v", id, existingRefs) } + // We don't support the deletion of bundle data sources + existingDS, err := getDataSourceFromDb(ctx, project, ReadBuilder().WithTransaction(tx), tx, + func(ctx context.Context, tx db.ExtendQuerier, projs []uuid.UUID) (db.DataSource, error) { + return getByIDQuery(ctx, tx, projs, id) + }) + if err != nil { + return fmt.Errorf("failed to get data source with id %s: %w", id, err) + } + if existingDS.SubscriptionID.Valid { + return util.UserVisibleError(codes.FailedPrecondition, + "data source %s cannot be deleted as it is part of a bundle", id) + } + // Delete the data source record _, err = tx.DeleteDataSource(ctx, db.DeleteDataSourceParams{ ID: id, diff --git a/internal/datasources/service/service_test.go b/internal/datasources/service/service_test.go index 69cc0b887a..297a0b0930 100644 --- a/internal/datasources/service/service_test.go +++ b/internal/datasources/service/service_test.go @@ -24,26 +24,29 @@ import ( v1 "github.com/mindersec/minder/pkg/datasources/v1" ) -var validRESTDriverFixture = &minderv1.DataSource_Rest{ - Rest: &minderv1.RestDataSource{ - Def: map[string]*minderv1.RestDataSource_Def{ - "test_function": { - Endpoint: "http://example.com", - InputSchema: func() *structpb.Struct { - s, _ := structpb.NewStruct(map[string]any{ - "type": "object", - "properties": map[string]any{ - "test": map[string]any{ - "type": "string", +var ( + subscriptionID = uuid.New() + validRESTDriverFixture = &minderv1.DataSource_Rest{ + Rest: &minderv1.RestDataSource{ + Def: map[string]*minderv1.RestDataSource_Def{ + "test_function": { + Endpoint: "http://example.com", + InputSchema: func() *structpb.Struct { + s, _ := structpb.NewStruct(map[string]any{ + "type": "object", + "properties": map[string]any{ + "test": map[string]any{ + "type": "string", + }, }, - }, - }) - return s - }(), + }) + return s + }(), + }, }, }, - }, -} + } +) func TestGetByName(t *testing.T) { t.Parallel() @@ -413,8 +416,9 @@ func TestCreate(t *testing.T) { t.Parallel() type args struct { - ds *minderv1.DataSource - opts *Options + ds *minderv1.DataSource + opts *Options + subscriptionId uuid.UUID } tests := []struct { name string @@ -456,6 +460,22 @@ func TestCreate(t *testing.T) { }, wantErr: false, }, + { + name: "Invalid namespace name", + args: args{ + ds: &minderv1.DataSource{ + Name: "name-with-no-namespace", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: validRESTDriverFixture, + }, + subscriptionId: subscriptionID, + opts: &Options{}, + }, + setup: func(_ *mockdb.MockStore) {}, + wantErr: true, + }, { name: "Nil data source", args: args{ @@ -534,7 +554,7 @@ func TestCreate(t *testing.T) { } tt.setup(mockStore) - got, err := svc.Create(context.Background(), tt.args.ds, tt.args.opts) + got, err := svc.Create(context.Background(), tt.args.subscriptionId, tt.args.ds, tt.args.opts) if tt.wantErr { assert.Error(t, err) return @@ -903,6 +923,13 @@ func TestDelete(t *testing.T) { ListRuleTypesReferencesByDataSource(gomock.Any(), args.id). Return([]db.RuleTypeDataSource{}, nil) + mockDB.EXPECT(). + GetDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: args.id, + SubscriptionID: uuid.NullUUID{Valid: false}, + }, nil) + // Mock DeleteDataSource to succeed mockDB.EXPECT(). DeleteDataSource(gomock.Any(), gomock.Eq(db.DeleteDataSourceParams{ @@ -926,6 +953,13 @@ func TestDelete(t *testing.T) { ListRuleTypesReferencesByDataSource(gomock.Any(), args.id). Return([]db.RuleTypeDataSource{}, nil) + mockDB.EXPECT(). + GetDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: args.id, + SubscriptionID: uuid.NullUUID{Valid: false}, + }, nil) + // Mock DeleteDataSource to return sql.ErrNoRows mockDB.EXPECT(). DeleteDataSource(gomock.Any(), gomock.Eq(db.DeleteDataSourceParams{ @@ -953,6 +987,26 @@ func TestDelete(t *testing.T) { }, wantErr: true, }, + { + name: "Data source is part of a bundle", + args: args{ + id: uuid.New(), + project: uuid.New(), + opts: &Options{}, + }, + setup: func(args args, mockDB *mockdb.MockStore) { + mockDB.EXPECT(). + ListRuleTypesReferencesByDataSource(gomock.Any(), args.id). + Return([]db.RuleTypeDataSource{}, nil) + mockDB.EXPECT(). + GetDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: args.id, + SubscriptionID: uuid.NullUUID{Valid: true, UUID: subscriptionID}, + }, nil) + }, + wantErr: true, + }, { name: "Database error when listing references", args: args{ @@ -980,6 +1034,12 @@ func TestDelete(t *testing.T) { mockDB.EXPECT(). ListRuleTypesReferencesByDataSource(gomock.Any(), args.id). Return([]db.RuleTypeDataSource{}, nil) + mockDB.EXPECT(). + GetDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: args.id, + SubscriptionID: uuid.NullUUID{Valid: false}, + }, nil) // Mock DeleteDataSource to return an error mockDB.EXPECT(). @@ -1028,8 +1088,9 @@ func TestUpdate(t *testing.T) { t.Parallel() type args struct { - ds *minderv1.DataSource - opts *Options + ds *minderv1.DataSource + opts *Options + subscriptionId uuid.UUID } tests := []struct { name string @@ -1120,6 +1181,103 @@ func TestUpdate(t *testing.T) { }, wantErr: false, }, + { + name: "Successfully update REST data source with matching subscription ID", + args: args{ + ds: &minderv1.DataSource{ + Id: uuid.New().String(), + Name: "updated_ds", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: &minderv1.DataSource_Rest{ + Rest: &minderv1.RestDataSource{ + Def: map[string]*minderv1.RestDataSource_Def{ + "test_function": { + Endpoint: "http://example.com/updated", + InputSchema: func() *structpb.Struct { + s, _ := structpb.NewStruct(map[string]any{}) + return s + }(), + }, + }, + }, + }, + }, + subscriptionId: subscriptionID, + opts: &Options{}, + }, + setup: func(mockDB *mockdb.MockStore) { + mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: uuid.MustParse(uuid.New().String()), + Name: "test_ds", + SubscriptionID: uuid.NullUUID{Valid: true, UUID: subscriptionID}, + }, nil) + mockDB.EXPECT().ListDataSourceFunctions(gomock.Any(), gomock.Any()). + Return([]db.DataSourcesFunction{ + { + ID: uuid.New(), + DataSourceID: uuid.New(), + Name: "test_function", + Type: v1.DataSourceDriverRest, + Definition: restDriverToJson(t, &minderv1.RestDataSource_Def{}), + }, + }, nil) + + mockDB.EXPECT().UpdateDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: uuid.New(), + Name: "updated_ds", + }, nil) + + mockDB.EXPECT().DeleteDataSourceFunctions(gomock.Any(), gomock.Any()). + Return(nil, nil) + + mockDB.EXPECT().AddDataSourceFunction(gomock.Any(), gomock.Any()). + Return(db.DataSourcesFunction{}, nil) + }, + want: &minderv1.DataSource{ + Name: "updated_ds", + }, + wantErr: false, + }, + { + name: "Non-matching subscription ID", + args: args{ + ds: &minderv1.DataSource{ + Id: uuid.New().String(), + Name: "updated_ds", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: &minderv1.DataSource_Rest{ + Rest: &minderv1.RestDataSource{ + Def: map[string]*minderv1.RestDataSource_Def{ + "test_function": { + Endpoint: "http://example.com/updated", + InputSchema: func() *structpb.Struct { + s, _ := structpb.NewStruct(map[string]any{}) + return s + }(), + }, + }, + }, + }, + }, + subscriptionId: uuid.New(), + opts: &Options{}, + }, + setup: func(mockDB *mockdb.MockStore) { + mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: uuid.MustParse(uuid.New().String()), + Name: "test_ds", + SubscriptionID: uuid.NullUUID{Valid: true, UUID: subscriptionID}, + }, nil) + }, + wantErr: true, + }, { name: "Nil data source", args: args{ @@ -1377,7 +1535,7 @@ func TestUpdate(t *testing.T) { tt.setup(mockStore) - got, err := svc.Update(context.Background(), tt.args.ds, tt.args.opts) + got, err := svc.Update(context.Background(), tt.args.subscriptionId, tt.args.ds, tt.args.opts) if tt.wantErr { assert.Error(t, err) return diff --git a/internal/db/datasources.sql.go b/internal/db/datasources.sql.go index fc9d4e92ff..2b0656510f 100644 --- a/internal/db/datasources.sql.go +++ b/internal/db/datasources.sql.go @@ -73,19 +73,25 @@ func (q *Queries) AddRuleTypeDataSourceReference(ctx context.Context, arg AddRul const createDataSource = `-- name: CreateDataSource :one -INSERT INTO data_sources (project_id, name, display_name) -VALUES ($1, $2, $3) RETURNING id, name, display_name, project_id, created_at, updated_at +INSERT INTO data_sources (project_id, name, display_name, subscription_id) +VALUES ($1, $2, $3, $4) RETURNING id, name, display_name, project_id, created_at, updated_at, subscription_id ` type CreateDataSourceParams struct { - ProjectID uuid.UUID `json:"project_id"` - Name string `json:"name"` - DisplayName string `json:"display_name"` + ProjectID uuid.UUID `json:"project_id"` + Name string `json:"name"` + DisplayName string `json:"display_name"` + SubscriptionID uuid.NullUUID `json:"subscription_id"` } // CreateDataSource creates a new datasource in a given project. func (q *Queries) CreateDataSource(ctx context.Context, arg CreateDataSourceParams) (DataSource, error) { - row := q.db.QueryRowContext(ctx, createDataSource, arg.ProjectID, arg.Name, arg.DisplayName) + row := q.db.QueryRowContext(ctx, createDataSource, + arg.ProjectID, + arg.Name, + arg.DisplayName, + arg.SubscriptionID, + ) var i DataSource err := row.Scan( &i.ID, @@ -94,6 +100,7 @@ func (q *Queries) CreateDataSource(ctx context.Context, arg CreateDataSourcePara &i.ProjectID, &i.CreatedAt, &i.UpdatedAt, + &i.SubscriptionID, ) return i, err } @@ -101,7 +108,7 @@ func (q *Queries) CreateDataSource(ctx context.Context, arg CreateDataSourcePara const deleteDataSource = `-- name: DeleteDataSource :one DELETE FROM data_sources WHERE id = $1 AND project_id = $2 -RETURNING id, name, display_name, project_id, created_at, updated_at +RETURNING id, name, display_name, project_id, created_at, updated_at, subscription_id ` type DeleteDataSourceParams struct { @@ -119,6 +126,7 @@ func (q *Queries) DeleteDataSource(ctx context.Context, arg DeleteDataSourcePara &i.ProjectID, &i.CreatedAt, &i.UpdatedAt, + &i.SubscriptionID, ) return i, err } @@ -215,7 +223,7 @@ func (q *Queries) DeleteRuleTypeDataSource(ctx context.Context, arg DeleteRuleTy const getDataSource = `-- name: GetDataSource :one -SELECT id, name, display_name, project_id, created_at, updated_at FROM data_sources +SELECT id, name, display_name, project_id, created_at, updated_at, subscription_id FROM data_sources WHERE id = $1 AND project_id = ANY($2::uuid[]) ` @@ -238,13 +246,14 @@ func (q *Queries) GetDataSource(ctx context.Context, arg GetDataSourceParams) (D &i.ProjectID, &i.CreatedAt, &i.UpdatedAt, + &i.SubscriptionID, ) return i, err } const getDataSourceByName = `-- name: GetDataSourceByName :one -SELECT id, name, display_name, project_id, created_at, updated_at FROM data_sources +SELECT id, name, display_name, project_id, created_at, updated_at, subscription_id FROM data_sources WHERE name = $1 AND project_id = ANY($2::uuid[]) ` @@ -268,6 +277,7 @@ func (q *Queries) GetDataSourceByName(ctx context.Context, arg GetDataSourceByNa &i.ProjectID, &i.CreatedAt, &i.UpdatedAt, + &i.SubscriptionID, ) return i, err } @@ -318,7 +328,7 @@ func (q *Queries) ListDataSourceFunctions(ctx context.Context, arg ListDataSourc const listDataSources = `-- name: ListDataSources :many -SELECT id, name, display_name, project_id, created_at, updated_at FROM data_sources +SELECT id, name, display_name, project_id, created_at, updated_at, subscription_id FROM data_sources WHERE project_id = ANY($1::uuid[]) ` @@ -342,6 +352,7 @@ func (q *Queries) ListDataSources(ctx context.Context, projects []uuid.UUID) ([] &i.ProjectID, &i.CreatedAt, &i.UpdatedAt, + &i.SubscriptionID, ); err != nil { return nil, err } @@ -391,7 +402,7 @@ const updateDataSource = `-- name: UpdateDataSource :one UPDATE data_sources SET display_name = $3 WHERE id = $1 AND project_id = $2 -RETURNING id, name, display_name, project_id, created_at, updated_at +RETURNING id, name, display_name, project_id, created_at, updated_at, subscription_id ` type UpdateDataSourceParams struct { @@ -411,6 +422,7 @@ func (q *Queries) UpdateDataSource(ctx context.Context, arg UpdateDataSourcePara &i.ProjectID, &i.CreatedAt, &i.UpdatedAt, + &i.SubscriptionID, ) return i, err } diff --git a/internal/db/models.go b/internal/db/models.go index 8011451821..847c31306b 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -496,12 +496,13 @@ type Bundle struct { } type DataSource struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - DisplayName string `json:"display_name"` - ProjectID uuid.UUID `json:"project_id"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + DisplayName string `json:"display_name"` + ProjectID uuid.UUID `json:"project_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + SubscriptionID uuid.NullUUID `json:"subscription_id"` } type DataSourcesFunction struct { diff --git a/internal/marketplaces/namespaces/validation.go b/internal/marketplaces/namespaces/validation.go index ac8d316aef..130524b12d 100644 --- a/internal/marketplaces/namespaces/validation.go +++ b/internal/marketplaces/namespaces/validation.go @@ -15,18 +15,18 @@ import ( // these functions are tested through the tests for RuleTypeService -// ValidateNamespacedNameRules takes a name for a new profile or rule type and +// ValidateNamespacedNameRules takes a name for a new profile, rule type or data source and // asserts that: // A) If the subscriptionID is empty, there name should not be namespaced // B) If subscriptionID is not empty, the name must be namespaced // This assumes the name has already been validated against the other -// validation rules for profile and rule type names. +// validation rules for profile, rule type and data source names. func ValidateNamespacedNameRules(name string, subscriptionID uuid.UUID) error { hasNamespace := strings.Contains(name, "/") if hasNamespace && subscriptionID == uuid.Nil { - return errors.New("cannot create a rule type or profile with a namespace through the API") + return errors.New("cannot create a rule type, data source or profile with a namespace through the API") } else if !hasNamespace && subscriptionID != uuid.Nil { - return errors.New("rule types and profiles from subscriptions must have namespaced names") + return errors.New("rule types, data sources and profiles from subscriptions must have namespaced names") } // in future, we may want to check that the namespace in the profile/rule diff --git a/pkg/ruletypes/service_test.go b/pkg/ruletypes/service_test.go index 8d19518a83..3080aa7eac 100644 --- a/pkg/ruletypes/service_test.go +++ b/pkg/ruletypes/service_test.go @@ -69,14 +69,14 @@ func TestRuleTypeService(t *testing.T) { { Name: "CreateRuleType rejects attempt to create a namespaced rule when no subscription ID is passed", RuleType: newRuleType(withBasicStructure, withRuleName(namespacedRuleName)), - ExpectedError: "cannot create a rule type or profile with a namespace through the API", + ExpectedError: "cannot create a rule type, data source or profile with a namespace through the API", DBSetup: dbf.NewDBMock(), TestMethod: create, }, { Name: "CreateRuleType rejects attempt to create a non-namespaced rule when no subscription ID is passed", RuleType: newRuleType(withBasicStructure), - ExpectedError: "rule types and profiles from subscriptions must have namespaced names", + ExpectedError: "rule types, data sources and profiles from subscriptions must have namespaced names", DBSetup: dbf.NewDBMock(), SubscriptionID: subscriptionID, TestMethod: create,