Skip to content

Commit bffba1f

Browse files
committed
WIP
1 parent e551cee commit bffba1f

15 files changed

+865
-294
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ require (
2222
github.com/sourcegraph/jsonrpc2 v0.2.0
2323
github.com/spf13/cobra v1.8.1
2424
github.com/spf13/pflag v1.0.5
25+
gopkg.in/yaml.v2 v2.4.0
2526
gopkg.in/yaml.v3 v3.0.1
2627
)
2728

@@ -89,6 +90,5 @@ require (
8990
google.golang.org/genproto/googleapis/rpc v0.0.0-20240820151423-278611b39280 // indirect
9091
google.golang.org/protobuf v1.34.2 // indirect
9192
gopkg.in/warnings.v0 v0.1.2 // indirect
92-
gopkg.in/yaml.v2 v2.4.0 // indirect
9393
sigs.k8s.io/yaml v1.4.0 // indirect
9494
)

internal/lsp/cache/cache.go

+22-27
Original file line numberDiff line numberDiff line change
@@ -197,33 +197,6 @@ func (c *Cache) SetAggregates(data map[string][]report.Aggregate) {
197197
}
198198
}
199199

200-
// GetFileComplimentAggregates returns all aggregate data other than for the
201-
// provided fileURIs. This is used when running file diagnostics while also
202-
// requiring the previous aggregate state to provide aggregate rule linting.
203-
func (c *Cache) GetFileComplimentAggregates(fileURIs ...string) map[string][]report.Aggregate {
204-
c.aggregateDataMu.Lock()
205-
defer c.aggregateDataMu.Unlock()
206-
207-
excludedFiles := make(map[string]struct{}, len(fileURIs))
208-
for _, fileURI := range fileURIs {
209-
excludedFiles[fileURI] = struct{}{}
210-
}
211-
212-
allAggregates := make(map[string][]report.Aggregate)
213-
214-
for sourceFile, aggregates := range c.aggregateData {
215-
if _, excluded := excludedFiles[sourceFile]; excluded {
216-
continue
217-
}
218-
219-
for _, aggregate := range aggregates {
220-
allAggregates[aggregate.IndexKey()] = append(allAggregates[aggregate.IndexKey()], aggregate)
221-
}
222-
}
223-
224-
return allAggregates
225-
}
226-
227200
// GetFileAggregates is used to get aggregate data for a given list of files.
228201
// This is only used in tests to validate the cache state.
229202
func (c *Cache) GetFileAggregates(fileURIs ...string) map[string][]report.Aggregate {
@@ -268,6 +241,28 @@ func (c *Cache) SetFileDiagnostics(fileURI string, diags []types.Diagnostic) {
268241
c.diagnosticsFile[fileURI] = diags
269242
}
270243

244+
// SetFileDiagnosticsForRules will perform a partial update of the diagnostics
245+
// for a file given a list of evaluated rules.
246+
func (c *Cache) SetFileDiagnosticsForRules(fileURI string, rules []string, diags []types.Diagnostic) {
247+
c.diagnosticsFileMu.Lock()
248+
defer c.diagnosticsFileMu.Unlock()
249+
250+
ruleKeys := make(map[string]struct{}, len(rules))
251+
for _, rule := range rules {
252+
ruleKeys[rule] = struct{}{}
253+
}
254+
255+
preservedDiagnostics := make([]types.Diagnostic, 0)
256+
257+
for _, diag := range c.diagnosticsFile[fileURI] {
258+
if _, ok := ruleKeys[diag.Code]; !ok {
259+
preservedDiagnostics = append(preservedDiagnostics, diag)
260+
}
261+
}
262+
263+
c.diagnosticsFile[fileURI] = append(preservedDiagnostics, diags...)
264+
}
265+
271266
func (c *Cache) ClearFileDiagnostics() {
272267
c.diagnosticsFileMu.Lock()
273268
defer c.diagnosticsFileMu.Unlock()

internal/lsp/cache/cache_test.go

+39-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"reflect"
55
"testing"
66

7+
"github.com/styrainc/regal/internal/lsp/types"
78
"github.com/styrainc/regal/pkg/report"
89
)
910

@@ -89,12 +90,6 @@ func TestManageAggregates(t *testing.T) {
8990
t.Fatalf("unexpected number of aggregates for file2.rego: %d", len(aggs2))
9091
}
9192

92-
file1ComplimentAggs := c.GetFileComplimentAggregates("file1.rego")
93-
94-
if !reflect.DeepEqual(file1ComplimentAggs, aggs2) {
95-
t.Fatalf("unexpected compliment aggregates for file1.rego, exp\n%v\ngot\n%v", aggs2, file1ComplimentAggs)
96-
}
97-
9893
allAggs := c.GetFileAggregates()
9994

10095
if len(allAggs) != 2 {
@@ -126,3 +121,41 @@ func TestManageAggregates(t *testing.T) {
126121
t.Fatalf("unexpected number of aggregates: %d", len(allAggs))
127122
}
128123
}
124+
125+
func TestPartialDiagnosticsUpdate(t *testing.T) {
126+
t.Parallel()
127+
128+
c := NewCache()
129+
130+
diag1 := types.Diagnostic{Code: "code1"}
131+
diag2 := types.Diagnostic{Code: "code2"}
132+
diag3 := types.Diagnostic{Code: "code3"}
133+
134+
c.SetFileDiagnostics("foo.rego", []types.Diagnostic{
135+
diag1, diag2,
136+
})
137+
138+
foundDiags, ok := c.GetFileDiagnostics("foo.rego")
139+
if !ok {
140+
t.Fatalf("expected to get diags for foo.rego")
141+
}
142+
143+
if !reflect.DeepEqual(foundDiags, []types.Diagnostic{diag1, diag2}) {
144+
t.Fatalf("unexpected diagnostics: %v", foundDiags)
145+
}
146+
147+
c.SetFileDiagnosticsForRules(
148+
"foo.rego",
149+
[]string{"code2", "code3"},
150+
[]types.Diagnostic{diag3},
151+
)
152+
153+
foundDiags, ok = c.GetFileDiagnostics("foo.rego")
154+
if !ok {
155+
t.Fatalf("expected to get diags for foo.rego")
156+
}
157+
158+
if !reflect.DeepEqual(foundDiags, []types.Diagnostic{diag1, diag3}) {
159+
t.Fatalf("unexpected diagnostics: %v", foundDiags)
160+
}
161+
}

internal/lsp/lint.go

+31-14
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ func updateFileDiagnostics(
149149
cache *cache.Cache,
150150
regalConfig *config.Config,
151151
fileURI string,
152-
workspaceRootURI string,
152+
workspaceRootDir string,
153+
updateDiagnosticsForRules []string,
153154
) error {
154155
module, ok := cache.GetModule(fileURI)
155156
if !ok {
@@ -165,11 +166,12 @@ func updateFileDiagnostics(
165166
input := rules.NewInput(map[string]string{fileURI: contents}, map[string]*ast.Module{fileURI: module})
166167

167168
regalInstance := linter.NewLinter().
168-
WithAggregates(cache.GetFileComplimentAggregates(fileURI)).
169-
WithAlwaysAggregate(true).
169+
// needed to get the aggregateData for this file
170+
WithCollectQuery(true).
171+
// needed to get the aggregateData out so we can update the cache
170172
WithExportAggregates(true).
171173
WithInputModules(&input).
172-
WithRootDir(workspaceRootURI)
174+
WithRootDir(workspaceRootDir)
173175

174176
if regalConfig != nil {
175177
regalInstance = regalInstance.WithUserConfig(*regalConfig)
@@ -180,7 +182,7 @@ func updateFileDiagnostics(
180182
return fmt.Errorf("failed to lint: %w", err)
181183
}
182184

183-
fileDiags := convertReportToDiagnostics(&rpt, workspaceRootURI)
185+
fileDiags := convertReportToDiagnostics(&rpt, workspaceRootDir)
184186

185187
files := cache.GetAllFiles()
186188

@@ -198,7 +200,7 @@ func updateFileDiagnostics(
198200
fd = []types.Diagnostic{}
199201
}
200202

201-
cache.SetFileDiagnostics(uri, fd)
203+
cache.SetFileDiagnosticsForRules(uri, updateDiagnosticsForRules, fd)
202204
}
203205
}
204206

@@ -211,32 +213,40 @@ func updateAllDiagnostics(
211213
ctx context.Context,
212214
cache *cache.Cache,
213215
regalConfig *config.Config,
214-
workspaceRootURI string,
216+
workspaceRootDir string,
215217
overwriteAggregates bool,
218+
aggregatesReportOnly bool,
219+
updateDiagnosticsForRules []string,
216220
) error {
221+
var err error
222+
217223
modules := cache.GetAllModules()
218224
files := cache.GetAllFiles()
219225

220-
input := rules.NewInput(files, modules)
221-
222226
regalInstance := linter.NewLinter().
223-
WithInputModules(&input).
224-
WithRootDir(workspaceRootURI).
227+
WithRootDir(workspaceRootDir).
225228
// aggregates need only be exported if they're to be used to overwrite.
226229
WithExportAggregates(overwriteAggregates)
227230

228231
if regalConfig != nil {
229232
regalInstance = regalInstance.WithUserConfig(*regalConfig)
230233
}
231234

235+
if aggregatesReportOnly {
236+
regalInstance = regalInstance.
237+
WithAggregates(cache.GetFileAggregates())
238+
} else {
239+
input := rules.NewInput(files, modules)
240+
regalInstance = regalInstance.WithInputModules(&input)
241+
}
242+
232243
rpt, err := regalInstance.Lint(ctx)
233244
if err != nil {
234245
return fmt.Errorf("failed to lint: %w", err)
235246
}
236247

237-
fileDiags := convertReportToDiagnostics(&rpt, workspaceRootURI)
248+
fileDiags := convertReportToDiagnostics(&rpt, workspaceRootDir)
238249

239-
// Update diagnostics for all files
240250
for uri := range files {
241251
parseErrs, ok := cache.GetParseErrors(uri)
242252
if ok && len(parseErrs) > 0 {
@@ -248,7 +258,14 @@ func updateAllDiagnostics(
248258
fd = []types.Diagnostic{}
249259
}
250260

251-
cache.SetFileDiagnostics(uri, fd)
261+
// when only an aggregate report was run, then we must make sure to
262+
// only update diagnostics from these rules. So the report is
263+
// authoratative, but for those rules only.
264+
if aggregatesReportOnly {
265+
cache.SetFileDiagnosticsForRules(uri, updateDiagnosticsForRules, fd)
266+
} else {
267+
cache.SetFileDiagnostics(uri, fd)
268+
}
252269
}
253270

254271
if overwriteAggregates {

internal/lsp/race_off.go

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
//go:build !race
2+
// +build !race
3+
4+
package lsp
5+
6+
func isRaceEnabled() bool {
7+
return false
8+
}

internal/lsp/race_on.go

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
//go:build race
2+
// +build race
3+
4+
package lsp
5+
6+
func isRaceEnabled() bool {
7+
return true
8+
}

0 commit comments

Comments
 (0)