Skip to content

Commit

Permalink
supporting path mirror lookup with precedence during repo policy look…
Browse files Browse the repository at this point in the history
…up (ee) (#1591)

* add project prefix directory lookups when loading policies from repo [ee]
  • Loading branch information
motatoes committed Jun 26, 2024
1 parent 77ee0b1 commit 3dd12d5
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 47 deletions.
8 changes: 4 additions & 4 deletions cli/pkg/core/policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ import (
)

type Provider interface {
GetAccessPolicy(organisation string, repository string, projectname string) (string, error)
GetPlanPolicy(organisation string, repository string, projectname string) (string, error)
GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error)
GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error)
GetDriftPolicy() (string, error)
GetOrganisation() string // TODO: remove this method from here since out of place
}

type Checker interface {
// TODO refactor arguments - use AccessPolicyContext
CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error)
CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, planOutput string) (bool, []string, error)
CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error)
CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error)
CheckDriftPolicy(SCMOrganisation string, SCMrepository string, projectname string) (bool, error)
}

Expand Down
14 changes: 7 additions & 7 deletions cli/pkg/digger/digger.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func RunJobs(jobs []orchestrator.Job, prService orchestrator.PullRequestService,
SCMrepository := splits[1]

for _, command := range job.Commands {
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, command, job.PullRequestNumber, job.RequestedBy, []string{})
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, job.PullRequestNumber, job.RequestedBy, []string{})

if err != nil {
return false, false, fmt.Errorf("error checking policy: %v", err)
Expand Down Expand Up @@ -187,7 +187,7 @@ func reportPolicyError(projectName string, command string, requestedBy string, r
func run(command string, job orchestrator.Job, policyChecker policy.Checker, orgService orchestrator.OrgService, SCMOrganisation string, SCMrepository string, PRNumber *int, requestedBy string, reporter reporting.Reporter, lock locking2.Lock, prService orchestrator.PullRequestService, projectNamespace string, workingDir string, planStorage storage.PlanStorage, appliesPerProject map[string]bool) (*execution.DiggerExecutorResult, string, error) {
log.Printf("Running '%s' for project '%s' (workflow: %s)\n", command, job.ProjectName, job.ProjectWorkflow)

allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, command, job.PullRequestNumber, requestedBy, []string{})
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, job.PullRequestNumber, requestedBy, []string{})

if err != nil {
return nil, "error checking policy", fmt.Errorf("error checking policy: %v", err)
Expand Down Expand Up @@ -278,7 +278,7 @@ func run(command string, job orchestrator.Job, policyChecker policy.Checker, org
} else if planPerformed {
if isNonEmptyPlan {
reportTerraformPlanOutput(reporter, projectLock.LockId(), plan)
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, planJsonOutput)
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, job.ProjectDir, planJsonOutput)
if err != nil {
msg := fmt.Sprintf("Failed to validate plan. %v", err)
log.Printf(msg)
Expand Down Expand Up @@ -381,7 +381,7 @@ func run(command string, job orchestrator.Job, policyChecker policy.Checker, org
return nil, msg, fmt.Errorf(msg)
}

_, violations, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, terraformPlanJsonStr)
_, violations, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, job.ProjectDir, terraformPlanJsonStr)
if err != nil {
msg := fmt.Sprintf("Failed to check plan policy. %v", err)
log.Printf(msg)
Expand All @@ -393,7 +393,7 @@ func run(command string, job orchestrator.Job, policyChecker policy.Checker, org
planPolicyViolations = []string{}
}

allowedToApply, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, command, job.PullRequestNumber, requestedBy, planPolicyViolations)
allowedToApply, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, job.PullRequestNumber, requestedBy, planPolicyViolations)
if err != nil {
msg := fmt.Sprintf("Failed to run plan policy check before apply. %v", err)
log.Printf(msg)
Expand Down Expand Up @@ -544,7 +544,7 @@ func RunJob(

for _, command := range job.Commands {

allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, nil, SCMOrganisation, SCMrepository, job.ProjectName, command, nil, requestedBy, []string{})
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, nil, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, nil, requestedBy, []string{})

if err != nil {
return fmt.Errorf("error checking policy: %v", err)
Expand Down Expand Up @@ -619,7 +619,7 @@ func RunJob(
}
return fmt.Errorf(msg)
}
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, planJsonOutput)
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, job.ProjectDir, planJsonOutput)
log.Printf(strings.Join(messages, "\n"))
if err != nil {
msg := fmt.Sprintf("Failed to validate plan %v", err)
Expand Down
16 changes: 8 additions & 8 deletions cli/pkg/policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ type DiggerHttpPolicyProvider struct {
type NoOpPolicyChecker struct {
}

func (p NoOpPolicyChecker) CheckAccessPolicy(_ orchestrator.OrgService, _ *orchestrator.PullRequestService, _ string, _ string, _ string, _ string, _ *int, _ string, _ []string) (bool, error) {
func (p NoOpPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {
return true, nil
}

func (p NoOpPolicyChecker) CheckPlanPolicy(_ string, _ string, _ string, _ string) (bool, []string, error) {
func (p NoOpPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error) {
return true, nil, nil
}

Expand Down Expand Up @@ -181,7 +181,7 @@ func getPlanPolicyForNamespace(p *DiggerHttpPolicyProvider, namespace string, pr
}

// GetPolicy fetches policy for particular project, if not found then it will fallback to org level policy
func (p DiggerHttpPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string) (string, error) {
func (p DiggerHttpPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string, projectDir string) (string, error) {
namespace := fmt.Sprintf("%v-%v", organisation, repo)
content, resp, err := getAccessPolicyForNamespace(&p, namespace, projectName)
if err != nil {
Expand Down Expand Up @@ -211,7 +211,7 @@ func (p DiggerHttpPolicyProvider) GetAccessPolicy(organisation string, repo stri
}
}

func (p DiggerHttpPolicyProvider) GetPlanPolicy(organisation string, repo string, projectName string) (string, error) {
func (p DiggerHttpPolicyProvider) GetPlanPolicy(organisation string, repo string, projectName string, projectDir string) (string, error) {
namespace := fmt.Sprintf("%v-%v", organisation, repo)
content, resp, err := getPlanPolicyForNamespace(&p, namespace, projectName)
if err != nil {
Expand Down Expand Up @@ -264,9 +264,9 @@ type DiggerPolicyChecker struct {
}

// TODO refactor to use AccessPolicyContext - too many arguments
func (p DiggerPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {
func (p DiggerPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {

policy, err := p.PolicyProvider.GetAccessPolicy(SCMOrganisation, SCMrepository, projectName)
policy, err := p.PolicyProvider.GetAccessPolicy(SCMOrganisation, SCMrepository, projectName, projectDir)

if err != nil {
log.Printf("Error while fetching policy: %v", err)
Expand Down Expand Up @@ -331,8 +331,8 @@ func (p DiggerPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService
return true, nil
}

func (p DiggerPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectName string, planOutput string) (bool, []string, error) {
policy, err := p.PolicyProvider.GetPlanPolicy(SCMOrganisation, SCMrepository, projectName)
func (p DiggerPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error) {
policy, err := p.PolicyProvider.GetPlanPolicy(SCMOrganisation, SCMrepository, projectname, projectDir)
if err != nil {
return false, nil, fmt.Errorf("failed get plan policy: %v", err)
}
Expand Down
16 changes: 8 additions & 8 deletions cli/pkg/policy/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ func (s *OpaExamplePolicyProvider) GetOrganisation() string {
type DiggerDefaultPolicyProvider struct {
}

func (s *DiggerDefaultPolicyProvider) GetAccessPolicy(_ string, _ string, _ string) (string, error) {
func (s *DiggerDefaultPolicyProvider) GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
return DefaultAccessPolicy, nil
}

func (s *DiggerDefaultPolicyProvider) GetPlanPolicy(_ string, _ string, _ string) (string, error) {
func (s *DiggerDefaultPolicyProvider) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
return "package digger\n", nil
}

Expand All @@ -69,7 +69,7 @@ func (s *DiggerDefaultPolicyProvider) GetOrganisation() string {
type DiggerExamplePolicyProvider struct {
}

func (s *DiggerExamplePolicyProvider) GetAccessPolicy(_ string, _ string, _ string) (string, error) {
func (s *DiggerExamplePolicyProvider) GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
return "package digger\n" +
"\n" +
"user_permissions := {\n" +
Expand All @@ -85,7 +85,7 @@ func (s *DiggerExamplePolicyProvider) GetAccessPolicy(_ string, _ string, _ stri
"", nil
}

func (s *DiggerExamplePolicyProvider) GetPlanPolicy(_ string, _ string, _ string) (string, error) {
func (s *DiggerExamplePolicyProvider) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
return "package digger\n", nil
}

Expand All @@ -100,7 +100,7 @@ func (s *DiggerExamplePolicyProvider) GetOrganisation() string {
type DiggerExamplePolicyProvider2 struct {
}

func (s *DiggerExamplePolicyProvider2) GetAccessPolicy(_ string, _ string, _ string) (string, error) {
func (s *DiggerExamplePolicyProvider2) GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
return "package digger\n" +
"\n" +
"user_permissions := {\n" +
Expand All @@ -119,7 +119,7 @@ func (s *DiggerExamplePolicyProvider2) GetAccessPolicy(_ string, _ string, _ str
"", nil
}

func (s *DiggerExamplePolicyProvider2) GetPlanPolicy(_ string, _ string, _ string) (string, error) {
func (s *DiggerExamplePolicyProvider2) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
return "package digger\n\ndeny[sprintf(message, [resource.address])] {\n message := \"Cannot create EC2 instances!\"\n resource := input.terraform.resource_changes[_]\n resource.change.actions[_] == \"create\"\n resource[type] == \"aws_instance\"\n}\n", nil
}

Expand Down Expand Up @@ -223,7 +223,7 @@ func TestDiggerAccessPolicyChecker_Check(t *testing.T) {
PolicyProvider: tt.fields.PolicyProvider,
}
ciService := utils.MockPullRequestManager{Teams: []string{"engineering"}}
got, err := p.CheckAccessPolicy(ciService, nil, tt.organisation, tt.name, tt.name, tt.command, nil, tt.requestedBy, tt.planPolicyViolations)
got, err := p.CheckAccessPolicy(ciService, nil, tt.organisation, tt.name, tt.name, "", tt.command, nil, tt.requestedBy, tt.planPolicyViolations)
if (err != nil) != tt.wantErr {
t.Errorf("DiggerPolicyChecker.CheckAccessPolicy() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -275,7 +275,7 @@ func TestDiggerPlanPolicyChecker_Check(t *testing.T) {
var p = &DiggerPolicyChecker{
PolicyProvider: tt.fields.PolicyProvider,
}
got, _, err := p.CheckPlanPolicy("", "", "", tt.planJsonOutput)
got, _, err := p.CheckPlanPolicy("", "", "", "", tt.planJsonOutput)
if (err != nil) != tt.wantErr {
t.Errorf("DiggerPolicyChecker.CheckPlanPolicy() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
4 changes: 2 additions & 2 deletions cli/pkg/utils/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ func (tf *MockTerraform) Plan() (bool, string, string, error) {
type MockPolicyChecker struct {
}

func (t MockPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, command string, ptr *int, requestedBy string, planPolicyViolations []string) (bool, error) {
func (t MockPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {
return false, nil
}

func (t MockPolicyChecker) CheckPlanPolicy(projectName string, SCMOrganisation string, command string, requestedBy string) (bool, []string, error) {
func (t MockPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error) {
return false, nil, nil
}

Expand Down
69 changes: 51 additions & 18 deletions ee/cli/pkg/policy/policy.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package policy

import (
"fmt"
"github.com/diggerhq/digger/ee/cli/pkg/utils"
"github.com/samber/lo"
"log"
"os"
"path"
"path/filepath"
"slices"
"strings"
)

const DefaultAccessPolicy = `
Expand All @@ -31,31 +34,61 @@ func getContents(filePath string) (string, error) {
return string(contents), nil
}

func (p DiggerRepoPolicyProvider) getPolicyFileContents(repo string, projectName string, fileName string) (string, error) {
// GetPrefixesForPath
// @path is the total path example /dev/vpc/subnets
// @filename is the name of the file to search for example access.rego
// returns the list of prefixes in priority order example:
// /dev/vpc/subnets/access.rego
// /dev/vpc/access.rego
// /dev/access.rego
func GetPrefixesForPath(path string, fileName string) []string {
var prefixes []string
parts := strings.Split(filepath.Clean(path), string(filepath.Separator))
for i := range parts {
prefixes = append(prefixes, filepath.Join(parts[:i+1]...))
}

slices.Reverse(prefixes)
prefixes = lo.FilterMap(prefixes, func(item string, index int) (string, bool) {
// if input path was absolute then result should be absolute and ignore last item ""
if parts[0] == "" {
return string(filepath.Separator) + item + string(filepath.Separator) + fileName, index < len(prefixes)-1
} else {
return item + string(filepath.Separator) + fileName, index < len(prefixes)
}
})

return prefixes
}

func (p DiggerRepoPolicyProvider) getPolicyFileContents(repo string, projectName string, projectDir string, fileName string) (string, error) {
var contents string
err := utils.CloneGitRepoAndDoAction(p.ManagementRepoUrl, "main", p.GitToken, func(basePath string) error {
// we start with the project directory path prefixes as the highest priority
prefixes := GetPrefixesForPath(path.Join(basePath, projectDir), fileName)

// we also add a known location as a least priority item
orgAccesspath := path.Join(basePath, "policies", fileName)
repoAccesspath := path.Join(basePath, "policies", repo, fileName)
projectAccessPath := path.Join(basePath, "policies", repo, projectName, fileName)
prefixes = append(prefixes, projectAccessPath)
prefixes = append(prefixes, repoAccesspath)
prefixes = append(prefixes, orgAccesspath)

log.Printf("loading repo orgAccess %v repoAccess %v projectAcces %v", orgAccesspath, repoAccesspath, projectAccessPath)
var err error
contents, err = getContents(projectAccessPath)
if os.IsNotExist(err) {
contents, err = getContents(repoAccesspath)
for _, pathPrefix := range prefixes {
var err error
contents, err = getContents(pathPrefix)
log.Printf("path: %v contents: %v, err: %v", pathPrefix, contents, err)
if err == nil {
return nil
}
if os.IsNotExist(err) {
contents, err = getContents(orgAccesspath)
if os.IsNotExist(err) {
return nil
} else {
fmt.Errorf("could not find any matching policy for %v,%v", repo, projectName)
}
continue
} else {
return err
}
} else {
return err
}

return nil
})
if err != nil {
Expand All @@ -65,11 +98,11 @@ func (p DiggerRepoPolicyProvider) getPolicyFileContents(repo string, projectName
}

// GetPolicy fetches policy for particular project, if not found then it will fallback to org level policy
func (p DiggerRepoPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string) (string, error) {
return p.getPolicyFileContents(repo, projectName, "access.rego")
func (p DiggerRepoPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string, projectDir string) (string, error) {
return p.getPolicyFileContents(repo, projectName, projectDir, "access.rego")
}

func (p DiggerRepoPolicyProvider) GetPlanPolicy(organisation string, repo string, projectName string) (string, error) {
func (p DiggerRepoPolicyProvider) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
return "", nil
}

Expand Down
25 changes: 25 additions & 0 deletions ee/cli/pkg/policy/policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package policy

import (
"github.com/stretchr/testify/assert"
"log"
"os"
"testing"
)

func init() {
log.SetOutput(os.Stdout)
log.SetFlags(log.Ldate | log.Ltime)
}

func TestGetPrefixesForPath(t *testing.T) {
prefixes := GetPrefixesForPath("dev/vpc/subnets", "access.rego")
assert.Equal(t, []string{"dev/vpc/subnets/access.rego", "dev/vpc/access.rego", "dev/access.rego"}, prefixes)
log.Printf("%v", prefixes)
}

func TestGetPrefixesForPathAbsolute(t *testing.T) {
prefixes := GetPrefixesForPath("/dev/vpc/subnets", "access.rego")
assert.Equal(t, []string{"/dev/vpc/subnets/access.rego", "/dev/vpc/access.rego", "/dev/access.rego"}, prefixes)
log.Printf("%v", prefixes)
}

0 comments on commit 3dd12d5

Please sign in to comment.