Skip to content

Commit 1eb5d62

Browse files
Better configuration (#79)
* Configurable Transport (#75) * new functions to allow HTTPClient configuration * updated go.mod for testing from remote * updated go.mod for remote testing * revert go.mod replace directives * Fixed NewOrgClientWithTransport comment * Make client fully configurable * make empty messages limit configurable #70 #71 * make auth token private in config * add docs * lint --------- Co-authored-by: Michael Fox <[email protected]>
1 parent 133d2c9 commit 1eb5d62

File tree

10 files changed

+89
-53
lines changed

10 files changed

+89
-53
lines changed

api.go

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,34 @@ import (
66
"net/http"
77
)
88

9-
const apiURLv1 = "https://api.openai.com/v1"
10-
11-
func newTransport() *http.Client {
12-
return &http.Client{}
13-
}
14-
159
// Client is OpenAI GPT-3 API client.
1610
type Client struct {
17-
BaseURL string
18-
HTTPClient *http.Client
19-
authToken string
20-
idOrg string
11+
config ClientConfig
2112
}
2213

2314
// NewClient creates new OpenAI API client.
2415
func NewClient(authToken string) *Client {
25-
return &Client{
26-
BaseURL: apiURLv1,
27-
HTTPClient: newTransport(),
28-
authToken: authToken,
29-
idOrg: "",
30-
}
16+
config := DefaultConfig(authToken)
17+
return &Client{config}
18+
}
19+
20+
// NewClientWithConfig creates new OpenAI API client for specified config.
21+
func NewClientWithConfig(config ClientConfig) *Client {
22+
return &Client{config}
3123
}
3224

3325
// NewOrgClient creates new OpenAI API client for specified Organization ID.
26+
//
27+
// Deprecated: Please use NewClientWithConfig.
3428
func NewOrgClient(authToken, org string) *Client {
35-
return &Client{
36-
BaseURL: apiURLv1,
37-
HTTPClient: newTransport(),
38-
authToken: authToken,
39-
idOrg: org,
40-
}
29+
config := DefaultConfig(authToken)
30+
config.OrgID = org
31+
return &Client{config}
4132
}
4233

4334
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
4435
req.Header.Set("Accept", "application/json; charset=utf-8")
45-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
36+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
4637

4738
// Check whether Content-Type is already set, Upload Files API requires
4839
// Content-Type == multipart/form-data
@@ -51,11 +42,11 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
5142
req.Header.Set("Content-Type", "application/json; charset=utf-8")
5243
}
5344

54-
if len(c.idOrg) > 0 {
55-
req.Header.Set("OpenAI-Organization", c.idOrg)
45+
if len(c.config.OrgID) > 0 {
46+
req.Header.Set("OpenAI-Organization", c.config.OrgID)
5647
}
5748

58-
res, err := c.HTTPClient.Do(req)
49+
res, err := c.config.HTTPClient.Do(req)
5950
if err != nil {
6051
return err
6152
}
@@ -86,5 +77,5 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
8677
}
8778

8879
func (c *Client) fullURL(suffix string) string {
89-
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
80+
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
9081
}

api_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ func TestAPIError(t *testing.T) {
110110

111111
func TestRequestError(t *testing.T) {
112112
var err error
113-
c := NewClient("dummy")
114-
c.BaseURL = "https://httpbin.org/status/418?"
113+
114+
config := DefaultConfig("dummy")
115+
config.BaseURL = "https://httpbin.org/status/418?"
116+
c := NewClientWithConfig(config)
115117
ctx := context.Background()
116118
_, err = c.ListEngines(ctx)
117119
if err == nil {

completion_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ func TestCompletions(t *testing.T) {
2525
ts.Start()
2626
defer ts.Close()
2727

28-
client := NewClient(test.GetTestToken())
28+
config := DefaultConfig(test.GetTestToken())
29+
config.BaseURL = ts.URL + "/v1"
30+
client := NewClientWithConfig(config)
2931
ctx := context.Background()
30-
client.BaseURL = ts.URL + "/v1"
3132

3233
req := CompletionRequest{
3334
MaxTokens: 5,

config.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package gogpt
2+
3+
import (
4+
"net/http"
5+
)
6+
7+
const (
8+
apiURLv1 = "https://api.openai.com/v1"
9+
defaultEmptyMessagesLimit uint = 300
10+
)
11+
12+
// ClientConfig is a configuration of a client.
13+
type ClientConfig struct {
14+
authToken string
15+
16+
HTTPClient *http.Client
17+
18+
BaseURL string
19+
OrgID string
20+
21+
EmptyMessagesLimit uint
22+
}
23+
24+
func DefaultConfig(authToken string) ClientConfig {
25+
return ClientConfig{
26+
HTTPClient: &http.Client{},
27+
BaseURL: apiURLv1,
28+
OrgID: "",
29+
authToken: authToken,
30+
31+
EmptyMessagesLimit: defaultEmptyMessagesLimit,
32+
}
33+
}

edits_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ func TestEdits(t *testing.T) {
2323
ts.Start()
2424
defer ts.Close()
2525

26-
client := NewClient(test.GetTestToken())
26+
config := DefaultConfig(test.GetTestToken())
27+
config.BaseURL = ts.URL + "/v1"
28+
client := NewClientWithConfig(config)
2729
ctx := context.Background()
28-
client.BaseURL = ts.URL + "/v1"
2930

3031
// create an edit request
3132
model := "ada"

files_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ func TestFileUpload(t *testing.T) {
2222
ts.Start()
2323
defer ts.Close()
2424

25-
client := NewClient(test.GetTestToken())
25+
config := DefaultConfig(test.GetTestToken())
26+
config.BaseURL = ts.URL + "/v1"
27+
client := NewClientWithConfig(config)
2628
ctx := context.Background()
27-
client.BaseURL = ts.URL + "/v1"
2829

2930
req := FileRequest{
3031
FileName: "test.go",

image_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ func TestImages(t *testing.T) {
2323
ts.Start()
2424
defer ts.Close()
2525

26-
client := NewClient(test.GetTestToken())
26+
config := DefaultConfig(test.GetTestToken())
27+
config.BaseURL = ts.URL + "/v1"
28+
client := NewClientWithConfig(config)
2729
ctx := context.Background()
28-
client.BaseURL = ts.URL + "/v1"
2930

3031
req := ImageRequest{}
3132
req.Prompt = "Lorem ipsum"
@@ -94,9 +95,10 @@ func TestImageEdit(t *testing.T) {
9495
ts.Start()
9596
defer ts.Close()
9697

97-
client := NewClient(test.GetTestToken())
98+
config := DefaultConfig(test.GetTestToken())
99+
config.BaseURL = ts.URL + "/v1"
100+
client := NewClientWithConfig(config)
98101
ctx := context.Background()
99-
client.BaseURL = ts.URL + "/v1"
100102

101103
origin, err := os.Create("image.png")
102104
if err != nil {

moderation_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ func TestModerations(t *testing.T) {
2525
ts.Start()
2626
defer ts.Close()
2727

28-
client := NewClient(test.GetTestToken())
28+
config := DefaultConfig(test.GetTestToken())
29+
config.BaseURL = ts.URL + "/v1"
30+
client := NewClientWithConfig(config)
2931
ctx := context.Background()
30-
client.BaseURL = ts.URL + "/v1"
3132

3233
// create an edit request
3334
model := "text-moderation-stable"

stream.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,18 @@ import (
1111
)
1212

1313
var (
14-
emptyMessagesLimit = 300
1514
ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
1615
)
1716

1817
type CompletionStream struct {
18+
emptyMessagesLimit uint
19+
1920
reader *bufio.Reader
2021
response *http.Response
2122
}
2223

2324
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
24-
emptyMessagesCount := 0
25+
var emptyMessagesCount uint
2526

2627
waitForData:
2728
line, err := stream.reader.ReadBytes('\n')
@@ -33,7 +34,7 @@ waitForData:
3334
line = bytes.TrimSpace(line)
3435
if !bytes.HasPrefix(line, headerData) {
3536
emptyMessagesCount++
36-
if emptyMessagesCount > emptyMessagesLimit {
37+
if emptyMessagesCount > stream.emptyMessagesLimit {
3738
err = ErrTooManyEmptyStreamMessages
3839
return
3940
}
@@ -74,18 +75,20 @@ func (c *Client) CreateCompletionStream(
7475
req.Header.Set("Accept", "text/event-stream")
7576
req.Header.Set("Cache-Control", "no-cache")
7677
req.Header.Set("Connection", "keep-alive")
77-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
78+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
7879
if err != nil {
7980
return
8081
}
8182

8283
req = req.WithContext(ctx)
83-
resp, err := c.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
84+
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
8485
if err != nil {
8586
return
8687
}
8788

8889
stream = &CompletionStream{
90+
emptyMessagesLimit: c.config.EmptyMessagesLimit,
91+
8992
reader: bufio.NewReader(resp.Body),
9093
response: resp,
9194
}

stream_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,15 @@ func TestCreateCompletionStream(t *testing.T) {
3737
defer server.Close()
3838

3939
// Client portion of the test
40-
client := NewClient(test.GetTestToken())
40+
config := DefaultConfig(test.GetTestToken())
41+
config.BaseURL = server.URL + "/v1"
42+
config.HTTPClient.Transport = &tokenRoundTripper{
43+
test.GetTestToken(),
44+
http.DefaultTransport,
45+
}
46+
47+
client := NewClientWithConfig(config)
4148
ctx := context.Background()
42-
client.BaseURL = server.URL + "/v1"
4349

4450
request := CompletionRequest{
4551
Prompt: "Ex falso quodlibet",
@@ -48,11 +54,6 @@ func TestCreateCompletionStream(t *testing.T) {
4854
Stream: true,
4955
}
5056

51-
client.HTTPClient.Transport = &tokenRoundTripper{
52-
test.GetTestToken(),
53-
http.DefaultTransport,
54-
}
55-
5657
stream, err := client.CreateCompletionStream(ctx, request)
5758
if err != nil {
5859
t.Errorf("CreateCompletionStream returned error: %v", err)

0 commit comments

Comments
 (0)