diff --git a/helix.go b/helix.go index e36dad5..17d230a 100644 --- a/helix.go +++ b/helix.go @@ -345,6 +345,7 @@ func (c *Client) doRequest(req *http.Request, resp *Response) error { c.setRequestHeaders(req) rateLimitFunc := c.opts.RateLimitFunc + attempt := 0 for { if c.lastResponse != nil && rateLimitFunc != nil { @@ -354,10 +355,22 @@ func (c *Client) doRequest(req *http.Request, resp *Response) error { } } + if attempt > 0 && + req.Body != nil && + req.GetBody != nil { + + var err error + req.Body, err = req.GetBody() + if err != nil { + return err + } + } + response, err := c.opts.HTTPClient.Do(req) if err != nil { return fmt.Errorf("Failed to execute API request: %s", err.Error()) } + attempt++ defer response.Body.Close() resp.Header = response.Header diff --git a/helix_test.go b/helix_test.go index 914a6cc..8ab8c44 100644 --- a/helix_test.go +++ b/helix_test.go @@ -3,6 +3,7 @@ package helix import ( "context" "errors" + "io" "log" "net/http" "net/http/httptest" @@ -19,6 +20,14 @@ type mockHTTPClient struct { func (mtc *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { rr := httptest.NewRecorder() + if req.Body != nil { + defer req.Body.Close() + if out, err := io.ReadAll(req.Body); err != nil { + return nil, err + } else if len(out) != int(req.ContentLength) { + return nil, errors.New("content length mismatch") + } + } handler := http.HandlerFunc(mtc.mockHandler) handler.ServeHTTP(rr, req) @@ -380,6 +389,50 @@ func TestAutomaticUserTokenRefresh(t *testing.T) { } } +func TestAutomaticUserTokenRefreshWithRequestBody(t *testing.T) { + t.Parallel() + + options := &Options{ + ClientID: "client-id", + ClientSecret: "old-client-secret", + UserAccessToken: "old-user-token", + RefreshToken: "old-refresh-token", + } + client := newMockClient(options, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/oauth2/token") { + w.Write([]byte(`{"access_token":"new-access-token","expires_in":14154,"refresh_token":"new-refresh-token","scope":["analytics:read:games","bits:read","clips:edit","user:edit","user:read:email"]}`)) + } else if strings.Contains(r.URL.Path, "/eventsub/subscriptions") { + if strings.Contains(r.Header.Get("Authorization"), "old-user-token") { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"Unauthorized","status":401,"message":"Invalid OAuth token"}`)) + } else { + w.Write([]byte(`{"total":8,"data":[],"pagination":{}}`)) + } + } else { + log.Printf("Unknown URL sent to test server: %s", r.URL.Path) + } + }) + + _, err := client.CreateEventSubSubscription(&EventSubSubscription{ + Transport: EventSubTransport{ + Method: "webhook", + Callback: "https://localhost", + }, + }) // any method works + if err != nil { + t.Fatalf("Did not expect an error, got \"%s\"", err.Error()) + } + + time.Sleep(5 * time.Millisecond) + + if client.opts.UserAccessToken != "new-access-token" { + t.Errorf("expected UserAccessToken to be %q, got %q", "new-access-token", client.opts.UserAccessToken) + } + if client.opts.RefreshToken != "new-refresh-token" { + t.Errorf("expected RefreshToken to be %q, got %q", "new-refresh-token", client.opts.RefreshToken) + } +} + func TestSetRequestHeaders(t *testing.T) { t.Parallel()