Skip to content

Commit 42c8073

Browse files
committed
Merge remote-tracking branch 'origin/main' into deprecate-and-add-new-apis
2 parents 2d4c502 + 1d07942 commit 42c8073

34 files changed

+616
-514
lines changed

api/handlers/search_handler.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package handlers
22

33
import (
4+
"context"
45
"fmt"
56
"net/http"
67
"net/url"
@@ -35,19 +36,20 @@ func NewSearchHandler(log logrus.FieldLogger, searcher models.RecordSearcher, re
3536
}
3637

3738
func (handler *SearchHandler) Search(w http.ResponseWriter, r *http.Request) {
39+
ctx := r.Context()
3840
cfg, err := handler.buildSearchCfg(r.URL.Query())
3941
if err != nil {
4042
writeJSONError(w, http.StatusBadRequest, err.Error())
4143
return
4244
}
43-
results, err := handler.recordSearcher.Search(cfg)
45+
results, err := handler.recordSearcher.Search(ctx, cfg)
4446
if err != nil {
4547
handler.log.Errorf("error searching records: %w", err)
4648
writeJSONError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
4749
return
4850
}
4951

50-
response, err := handler.toSearchResponse(results)
52+
response, err := handler.toSearchResponse(ctx, results)
5153
if err != nil {
5254
handler.log.Errorf("error mapping search results: %w", err)
5355
writeJSONError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
@@ -76,10 +78,10 @@ func (handler *SearchHandler) buildSearchCfg(params url.Values) (cfg models.Sear
7678
return
7779
}
7880

79-
func (handler *SearchHandler) toSearchResponse(results []models.SearchResult) (response []SearchResponse, err error) {
81+
func (handler *SearchHandler) toSearchResponse(ctx context.Context, results []models.SearchResult) (response []SearchResponse, err error) {
8082
typeRepo := newCachingTypeRepo(handler.typeRepo)
8183
for _, result := range results {
82-
recordType, err := typeRepo.GetByName(result.TypeName)
84+
recordType, err := typeRepo.GetByName(ctx, result.TypeName)
8385
if err != nil {
8486
return nil, fmt.Errorf("typeRepository.GetByName: %q: %v", result.TypeName, err)
8587
}
@@ -155,19 +157,19 @@ type cachingTypeRepo struct {
155157
repo models.TypeRepository
156158
}
157159

158-
func (decorator *cachingTypeRepo) CreateOrReplace(ent models.Type) error {
160+
func (decorator *cachingTypeRepo) CreateOrReplace(ctx context.Context, ent models.Type) error {
159161
panic("not implemented")
160162
}
161163

162-
func (decorator *cachingTypeRepo) GetByName(name string) (models.Type, error) {
164+
func (decorator *cachingTypeRepo) GetByName(ctx context.Context, name string) (models.Type, error) {
163165
ent, exists := decorator.cache[name]
164166
if exists {
165167
return ent, nil
166168
}
167169

168170
decorator.mu.Lock()
169171
defer decorator.mu.Unlock()
170-
ent, err := decorator.repo.GetByName(name)
172+
ent, err := decorator.repo.GetByName(ctx, name)
171173
if err != nil {
172174
return ent, err
173175
}

api/handlers/search_handler_test.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package handlers_test
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"io"
@@ -19,6 +20,7 @@ import (
1920
)
2021

2122
func TestSearchHandler(t *testing.T) {
23+
ctx := context.Background()
2224
// todo: pass testCase to ValidateResponse
2325
type testCase struct {
2426
Title string
@@ -51,7 +53,7 @@ func TestSearchHandler(t *testing.T) {
5153
var withTypes = func(ents ...models.Type) func(tc testCase, repo *mock.TypeRepository) {
5254
return func(tc testCase, repo *mock.TypeRepository) {
5355
for _, ent := range ents {
54-
repo.On("GetByName", ent.Name).Return(ent, nil)
56+
repo.On("GetByName", ctx, ent.Name).Return(ent, nil)
5557
}
5658
return
5759
}
@@ -69,7 +71,7 @@ func TestSearchHandler(t *testing.T) {
6971
SearchText: "test",
7072
InitSearcher: func(tc testCase, searcher *mock.RecordSearcher) {
7173
err := fmt.Errorf("service unavailable")
72-
searcher.On("Search", testifyMock.AnythingOfType("models.SearchConfig")).
74+
searcher.On("Search", ctx, testifyMock.AnythingOfType("models.SearchConfig")).
7375
Return([]models.SearchResult{}, err)
7476
},
7577
ExpectStatus: http.StatusInternalServerError,
@@ -84,11 +86,11 @@ func TestSearchHandler(t *testing.T) {
8486
Record: models.Record{},
8587
},
8688
}
87-
searcher.On("Search", testifyMock.AnythingOfType("models.SearchConfig")).
89+
searcher.On("Search", ctx, testifyMock.AnythingOfType("models.SearchConfig")).
8890
Return(results, nil)
8991
},
9092
InitRepo: func(tc testCase, repo *mock.TypeRepository) {
91-
repo.On("GetByName", testifyMock.AnythingOfType("string")).
93+
repo.On("GetByName", ctx, testifyMock.AnythingOfType("string")).
9294
Return(models.Type{}, models.ErrNoSuchType{})
9395
},
9496
ExpectStatus: http.StatusInternalServerError,
@@ -112,7 +114,7 @@ func TestSearchHandler(t *testing.T) {
112114
},
113115
},
114116
}
115-
searcher.On("Search", cfg).Return(response, nil)
117+
searcher.On("Search", ctx, cfg).Return(response, nil)
116118
},
117119
InitRepo: withTypes(testdata.Type),
118120
ValidateResponse: func(tc testCase, body io.Reader) error {
@@ -183,7 +185,7 @@ func TestSearchHandler(t *testing.T) {
183185
},
184186
},
185187
}
186-
searcher.On("Search", cfg).Return(results, nil)
188+
searcher.On("Search", ctx, cfg).Return(results, nil)
187189
},
188190
InitRepo: withTypes(testdata.Type),
189191
ValidateResponse: func(tc testCase, body io.Reader) error {
@@ -250,7 +252,7 @@ func TestSearchHandler(t *testing.T) {
250252
results = append(results, result)
251253
}
252254

253-
searcher.On("Search", cfg).Return(results, nil)
255+
searcher.On("Search", ctx, cfg).Return(results, nil)
254256
return
255257
},
256258
ValidateResponse: func(tc testCase, body io.Reader) error {
@@ -297,7 +299,7 @@ func TestSearchHandler(t *testing.T) {
297299
},
298300
},
299301
}
300-
searcher.On("Search", cfg).Return(results, nil)
302+
searcher.On("Search", ctx, cfg).Return(results, nil)
301303
return
302304
},
303305
ValidateResponse: func(tc testCase, body io.Reader) error {
@@ -351,7 +353,7 @@ func TestSearchHandler(t *testing.T) {
351353
},
352354
},
353355
}
354-
searcher.On("Search", cfg).Return(results, nil)
356+
searcher.On("Search", ctx, cfg).Return(results, nil)
355357
return
356358
},
357359
ValidateResponse: func(tc testCase, body io.Reader) error {
@@ -405,7 +407,7 @@ func TestSearchHandler(t *testing.T) {
405407
},
406408
},
407409
}
408-
searcher.On("Search", cfg).Return(results, nil)
410+
searcher.On("Search", ctx, cfg).Return(results, nil)
409411
},
410412
ValidateResponse: func(tc testCase, body io.Reader) error {
411413
var actualResults []handlers.SearchResponse
@@ -456,7 +458,7 @@ func TestSearchHandler(t *testing.T) {
456458
},
457459
},
458460
}
459-
searcher.On("Search", cfg).Return(results, nil)
461+
searcher.On("Search", ctx, cfg).Return(results, nil)
460462
return
461463
},
462464
ValidateResponse: func(tc testCase, body io.Reader) error {
@@ -505,7 +507,7 @@ func TestSearchHandler(t *testing.T) {
505507
},
506508
},
507509
}
508-
searcher.On("Search", cfg).Return(results, nil)
510+
searcher.On("Search", ctx, cfg).Return(results, nil)
509511
return
510512
},
511513
ValidateResponse: func(tc testCase, body io.Reader) error {
@@ -558,7 +560,7 @@ func TestSearchHandler(t *testing.T) {
558560
},
559561
},
560562
}
561-
searcher.On("Search", cfg).Return(results, nil)
563+
searcher.On("Search", ctx, cfg).Return(results, nil)
562564
return
563565
},
564566
},

api/handlers/type_handler.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func NewTypeHandler(log logrus.FieldLogger, er models.TypeRepository, rrf models
5050
}
5151

5252
func (handler *TypeHandler) GetAll(w http.ResponseWriter, r *http.Request) {
53-
types, err := handler.typeRepo.GetAll()
53+
types, err := handler.typeRepo.GetAll(r.Context())
5454
if err != nil {
5555
handler.log.
5656
Errorf("error fetching types: %v", err)
@@ -63,7 +63,7 @@ func (handler *TypeHandler) GetAll(w http.ResponseWriter, r *http.Request) {
6363

6464
func (handler *TypeHandler) GetType(w http.ResponseWriter, r *http.Request) {
6565
name := mux.Vars(r)["name"]
66-
recordType, err := handler.typeRepo.GetByName(name)
66+
recordType, err := handler.typeRepo.GetByName(r.Context(), name)
6767
if err != nil {
6868
handler.log.
6969
Errorf("error fetching type \"%s\": %v", name, err)
@@ -99,7 +99,7 @@ func (handler *TypeHandler) CreateOrReplaceType(w http.ResponseWriter, r *http.R
9999
return
100100
}
101101

102-
err = handler.typeRepo.CreateOrReplace(payload)
102+
err = handler.typeRepo.CreateOrReplace(r.Context(), payload)
103103
if err != nil {
104104
handler.log.
105105
WithField("type", payload.Name).
@@ -124,7 +124,7 @@ func (handler *TypeHandler) CreateOrReplaceType(w http.ResponseWriter, r *http.R
124124

125125
func (handler *TypeHandler) DeleteType(w http.ResponseWriter, r *http.Request) {
126126
name := mux.Vars(r)["name"]
127-
err := handler.typeRepo.Delete(name)
127+
err := handler.typeRepo.Delete(r.Context(), name)
128128
if err != nil {
129129
handler.log.
130130
Errorf("error deleting type \"%s\": %v", name, err)
@@ -157,7 +157,7 @@ func (handler *TypeHandler) DeleteRecord(w http.ResponseWriter, r *http.Request)
157157
statusCode := http.StatusInternalServerError
158158
errMessage := fmt.Sprintf("error deleting record \"%s\" with type \"%s\"", recordID, typeName)
159159

160-
recordType, err := handler.typeRepo.GetByName(typeName)
160+
recordType, err := handler.typeRepo.GetByName(r.Context(), typeName)
161161
if err != nil {
162162
handler.log.
163163
Errorf("error getting type \"%s\": %v", typeName, err)
@@ -179,7 +179,7 @@ func (handler *TypeHandler) DeleteRecord(w http.ResponseWriter, r *http.Request)
179179
return
180180
}
181181

182-
err = recordRepoFactory.Delete(recordID)
182+
err = recordRepoFactory.Delete(r.Context(), recordID)
183183
if err != nil {
184184
handler.log.
185185
Errorf("error deleting record \"%s\": %v", typeName, err)
@@ -199,7 +199,7 @@ func (handler *TypeHandler) DeleteRecord(w http.ResponseWriter, r *http.Request)
199199

200200
func (handler *TypeHandler) IngestRecord(w http.ResponseWriter, r *http.Request) {
201201
name := mux.Vars(r)["name"]
202-
recordType, err := handler.typeRepo.GetByName(name)
202+
recordType, err := handler.typeRepo.GetByName(r.Context(), name)
203203
if err != nil {
204204
status := http.StatusInternalServerError
205205
if _, ok := err.(models.ErrNoSuchType); ok {
@@ -239,7 +239,7 @@ func (handler *TypeHandler) IngestRecord(w http.ResponseWriter, r *http.Request)
239239
writeJSONError(w, status, http.StatusText(status))
240240
return
241241
}
242-
if err = recordRepo.CreateOrReplaceMany(records); err != nil {
242+
if err = recordRepo.CreateOrReplaceMany(r.Context(), records); err != nil {
243243
handler.log.WithField("type", recordType.Name).
244244
Errorf("error creating/updating records: %v", err)
245245

@@ -253,7 +253,7 @@ func (handler *TypeHandler) IngestRecord(w http.ResponseWriter, r *http.Request)
253253

254254
func (handler *TypeHandler) ListTypeRecords(w http.ResponseWriter, r *http.Request) {
255255
name := mux.Vars(r)["name"]
256-
recordType, err := handler.typeRepo.GetByName(name)
256+
recordType, err := handler.typeRepo.GetByName(r.Context(), name)
257257
if err != nil {
258258
status, message := handler.responseStatusForError(err)
259259
writeJSONError(w, status, message)
@@ -270,7 +270,7 @@ func (handler *TypeHandler) ListTypeRecords(w http.ResponseWriter, r *http.Reque
270270
}
271271
filterCfg := filterConfigFromValues(r.URL.Query())
272272

273-
records, err := recordRepo.GetAll(filterCfg)
273+
records, err := recordRepo.GetAll(r.Context(), filterCfg)
274274
if err != nil {
275275
handler.log.WithField("type", recordType).
276276
Errorf("error fetching records: GetAll: %v", err)
@@ -292,7 +292,7 @@ func (handler *TypeHandler) GetTypeRecord(w http.ResponseWriter, r *http.Request
292292
typeName = vars["name"]
293293
recordID = vars["id"]
294294
)
295-
recordType, err := handler.typeRepo.GetByName(typeName)
295+
recordType, err := handler.typeRepo.GetByName(r.Context(), typeName)
296296

297297
// TODO(Aman): make error handling a bit more DRY
298298
if err != nil {
@@ -313,7 +313,7 @@ func (handler *TypeHandler) GetTypeRecord(w http.ResponseWriter, r *http.Request
313313
return
314314
}
315315

316-
record, err := recordRepo.GetByID(recordID)
316+
record, err := recordRepo.GetByID(r.Context(), recordID)
317317
if err != nil {
318318
handler.log.WithField("type", typeName).
319319
WithField("record", recordID).

0 commit comments

Comments
 (0)