Skip to content

Commit 9c9c1b0

Browse files
committed
gopls/internal/golang: add extract interface code action
1 parent 97ea816 commit 9c9c1b0

File tree

5 files changed

+380
-23
lines changed

5 files changed

+380
-23
lines changed

gopls/internal/golang/codeaction.go

+42-23
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,13 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle,
8585
}
8686
}
8787

88+
pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
89+
if err != nil {
90+
return nil, err
91+
}
92+
8893
if want[protocol.RefactorExtract] {
89-
extractions, err := getExtractCodeActions(pgf, rng, snapshot.Options())
94+
extractions, err := getExtractCodeActions(pkg, pgf, rng, snapshot.Options())
9095
if err != nil {
9196
return nil, err
9297
}
@@ -198,20 +203,18 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic)
198203
}
199204

200205
// getExtractCodeActions returns any refactor.extract code actions for the selection.
201-
func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
202-
if rng.Start == rng.End {
203-
return nil, nil
204-
}
205-
206+
func getExtractCodeActions(pkg *cache.Package, pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
206207
start, end, err := pgf.RangePos(rng)
207208
if err != nil {
208209
return nil, err
209210
}
211+
210212
puri := pgf.URI
211213
var commands []protocol.Command
212-
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
213-
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
214-
Fix: fixExtractFunction,
214+
215+
if _, _, ok, _ := CanExtractInterface(pkg, start, end, pgf.File); ok {
216+
cmd, err := command.NewApplyFixCommand("Extract interface", command.ApplyFixArgs{
217+
Fix: fixExtractInterface,
215218
URI: puri,
216219
Range: rng,
217220
ResolveEdits: supportsResolveEdits(options),
@@ -220,9 +223,12 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
220223
return nil, err
221224
}
222225
commands = append(commands, cmd)
223-
if methodOk {
224-
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
225-
Fix: fixExtractMethod,
226+
}
227+
228+
if rng.Start != rng.End {
229+
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
230+
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
231+
Fix: fixExtractFunction,
226232
URI: puri,
227233
Range: rng,
228234
ResolveEdits: supportsResolveEdits(options),
@@ -231,20 +237,33 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
231237
return nil, err
232238
}
233239
commands = append(commands, cmd)
240+
if methodOk {
241+
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
242+
Fix: fixExtractMethod,
243+
URI: puri,
244+
Range: rng,
245+
ResolveEdits: supportsResolveEdits(options),
246+
})
247+
if err != nil {
248+
return nil, err
249+
}
250+
commands = append(commands, cmd)
251+
}
234252
}
235-
}
236-
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
237-
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
238-
Fix: fixExtractVariable,
239-
URI: puri,
240-
Range: rng,
241-
ResolveEdits: supportsResolveEdits(options),
242-
})
243-
if err != nil {
244-
return nil, err
253+
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
254+
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
255+
Fix: fixExtractVariable,
256+
URI: puri,
257+
Range: rng,
258+
ResolveEdits: supportsResolveEdits(options),
259+
})
260+
if err != nil {
261+
return nil, err
262+
}
263+
commands = append(commands, cmd)
245264
}
246-
commands = append(commands, cmd)
247265
}
266+
248267
var actions []protocol.CodeAction
249268
for i := range commands {
250269
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options))

gopls/internal/golang/extract.go

+34
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
"golang.org/x/tools/go/analysis"
2020
"golang.org/x/tools/go/ast/astutil"
21+
"golang.org/x/tools/gopls/internal/cache"
2122
"golang.org/x/tools/gopls/internal/util/bug"
2223
"golang.org/x/tools/gopls/internal/util/safetoken"
2324
"golang.org/x/tools/internal/analysisinternal"
@@ -127,6 +128,39 @@ func CanExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.N
127128
return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
128129
}
129130

131+
// CanExtractInterface reports whether the code in the given position is for a
132+
// type which can be represented as an interface.
133+
func CanExtractInterface(pkg *cache.Package, start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
134+
path, _ := astutil.PathEnclosingInterval(file, start, end)
135+
if len(path) == 0 {
136+
return nil, nil, false, fmt.Errorf("no path enclosing interval")
137+
}
138+
139+
node := path[0]
140+
expr, ok := node.(ast.Expr)
141+
if !ok {
142+
return nil, nil, false, fmt.Errorf("node is not an expression")
143+
}
144+
145+
switch e := expr.(type) {
146+
case *ast.Ident:
147+
o, ok := pkg.TypesInfo().ObjectOf(e).(*types.TypeName)
148+
if !ok {
149+
return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
150+
}
151+
152+
if _, ok := o.Type().(*types.Basic); ok {
153+
return nil, nil, false, fmt.Errorf("cannot extract a basic type to an interface")
154+
}
155+
156+
return expr, path, true, nil
157+
case *ast.StarExpr, *ast.SelectorExpr:
158+
return expr, path, true, nil
159+
default:
160+
return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
161+
}
162+
}
163+
130164
// Calculate indentation for insertion.
131165
// When inserting lines of code, we must ensure that the lines have consistent
132166
// formatting (i.e. the proper indentation). To do so, we observe the indentation on the

gopls/internal/golang/fix.go

+141
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
package golang
66

77
import (
8+
"bytes"
89
"context"
10+
"errors"
911
"fmt"
1012
"go/ast"
1113
"go/token"
1214
"go/types"
15+
"slices"
1316

1417
"golang.org/x/tools/go/analysis"
18+
"golang.org/x/tools/go/ast/astutil"
1519
"golang.org/x/tools/gopls/internal/analysis/embeddirective"
1620
"golang.org/x/tools/gopls/internal/analysis/fillstruct"
1721
"golang.org/x/tools/gopls/internal/analysis/stubmethods"
@@ -22,6 +26,7 @@ import (
2226
"golang.org/x/tools/gopls/internal/file"
2327
"golang.org/x/tools/gopls/internal/protocol"
2428
"golang.org/x/tools/gopls/internal/util/bug"
29+
"golang.org/x/tools/gopls/internal/util/safetoken"
2530
"golang.org/x/tools/internal/imports"
2631
)
2732

@@ -61,6 +66,7 @@ func singleFile(fixer1 singleFileFixer) fixer {
6166
const (
6267
fixExtractVariable = "extract_variable"
6368
fixExtractFunction = "extract_function"
69+
fixExtractInterface = "extract_interface"
6470
fixExtractMethod = "extract_method"
6571
fixInlineCall = "inline_call"
6672
fixInvertIfCondition = "invert_if_condition"
@@ -112,6 +118,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
112118

113119
// Ad-hoc fixers: these are used when the command is
114120
// constructed directly by logic in server/code_action.
121+
fixExtractInterface: extractInterface,
115122
fixExtractFunction: singleFile(extractFunction),
116123
fixExtractMethod: singleFile(extractMethod),
117124
fixExtractVariable: singleFile(extractVariable),
@@ -142,6 +149,140 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
142149
return suggestedFixToEdits(ctx, snapshot, fixFset, suggestion)
143150
}
144151

152+
func extractInterface(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
153+
path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
154+
155+
var field *ast.Field
156+
var decl ast.Decl
157+
for _, node := range path {
158+
if f, ok := node.(*ast.Field); ok {
159+
field = f
160+
continue
161+
}
162+
163+
// Record the node that starts the declaration of the type that contains
164+
// the field we are creating the interface for.
165+
if d, ok := node.(ast.Decl); ok {
166+
decl = d
167+
break // we have both the field and the declaration
168+
}
169+
}
170+
171+
if field == nil || decl == nil {
172+
return nil, nil, nil
173+
}
174+
175+
p := safetoken.StartPosition(pkg.FileSet(), field.Pos())
176+
pos := protocol.Position{
177+
Line: uint32(p.Line - 1), // Line is zero-based
178+
Character: uint32(p.Column - 1), // Character is zero-based
179+
}
180+
181+
fh, err := snapshot.ReadFile(ctx, pgf.URI)
182+
if err != nil {
183+
return nil, nil, err
184+
}
185+
186+
refs, err := references(ctx, snapshot, fh, pos, false)
187+
if err != nil {
188+
return nil, nil, err
189+
}
190+
191+
type method struct {
192+
signature *types.Signature
193+
name string
194+
}
195+
196+
var methods []method
197+
for _, ref := range refs {
198+
locPkg, locPgf, err := NarrowestPackageForFile(ctx, snapshot, ref.location.URI)
199+
if err != nil {
200+
return nil, nil, err
201+
}
202+
203+
_, end, err := locPgf.RangePos(ref.location.Range)
204+
if err != nil {
205+
return nil, nil, err
206+
}
207+
208+
// We are interested in the method call, so we need the node after the dot
209+
rangeEnd := end + token.Pos(len("."))
210+
path, _ := astutil.PathEnclosingInterval(locPgf.File, rangeEnd, rangeEnd)
211+
id, ok := path[0].(*ast.Ident)
212+
if !ok {
213+
continue
214+
}
215+
216+
obj := locPkg.TypesInfo().ObjectOf(id)
217+
if obj == nil {
218+
continue
219+
}
220+
221+
sig, ok := obj.Type().(*types.Signature)
222+
if !ok {
223+
return nil, nil, errors.New("cannot extract interface with non-method accesses")
224+
}
225+
226+
fc := method{signature: sig, name: obj.Name()}
227+
if !slices.Contains(methods, fc) {
228+
methods = append(methods, fc)
229+
}
230+
}
231+
232+
interfaceName := "I" + pkg.TypesInfo().ObjectOf(field.Names[0]).Name()
233+
var buf bytes.Buffer
234+
buf.WriteString("\ntype ")
235+
buf.WriteString(interfaceName)
236+
buf.WriteString(" interface {\n")
237+
for _, fc := range methods {
238+
buf.WriteString("\t")
239+
buf.WriteString(fc.name)
240+
types.WriteSignature(&buf, fc.signature, relativeTo(pkg.Types()))
241+
buf.WriteByte('\n')
242+
}
243+
buf.WriteByte('}')
244+
buf.WriteByte('\n')
245+
246+
interfacePos := decl.Pos() - 1
247+
// Move the interface above the documentation comment if the type declaration
248+
// includes one.
249+
switch d := decl.(type) {
250+
case *ast.GenDecl:
251+
if d.Doc != nil {
252+
interfacePos = d.Doc.Pos() - 1
253+
}
254+
case *ast.FuncDecl:
255+
if d.Doc != nil {
256+
interfacePos = d.Doc.Pos() - 1
257+
}
258+
}
259+
260+
return pkg.FileSet(), &analysis.SuggestedFix{
261+
Message: "Extract interface",
262+
TextEdits: []analysis.TextEdit{{
263+
Pos: interfacePos,
264+
End: interfacePos,
265+
NewText: buf.Bytes(),
266+
}, {
267+
Pos: field.Type.Pos(),
268+
End: field.Type.End(),
269+
NewText: []byte(interfaceName),
270+
}},
271+
}, nil
272+
}
273+
274+
func relativeTo(pkg *types.Package) types.Qualifier {
275+
if pkg == nil {
276+
return nil
277+
}
278+
return func(other *types.Package) string {
279+
if pkg == other {
280+
return "" // same package; unqualified
281+
}
282+
return other.Name()
283+
}
284+
}
285+
145286
// suggestedFixToEdits converts the suggestion's edits from analysis form into protocol form.
146287
func suggestedFixToEdits(ctx context.Context, snapshot *cache.Snapshot, fset *token.FileSet, suggestion *analysis.SuggestedFix) ([]protocol.TextDocumentEdit, error) {
147288
editsPerFile := map[protocol.DocumentURI]*protocol.TextDocumentEdit{}

0 commit comments

Comments
 (0)