-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient_store.go
172 lines (138 loc) · 4.09 KB
/
client_store.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
package pgstore
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/go-oauth2/oauth2/v4"
"github.com/go-oauth2/oauth2/v4/models"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// DefaultClientStoreTable is the table name a ClientStore uses when no
// other table is configured.
const DefaultClientStoreTable = "oauth2_clients"
// ClientStoreOption is a functional option that NewClientStore applies to the
// ClientStore being built. A non-nil error returned by an option aborts
// construction and is returned to the caller of NewClientStore.
type ClientStoreOption func(*ClientStore) error
// WithClientStoreTable sets the name of the table in which clients are
// stored. Applying the option fails with ErrNoTable when table is empty.
func WithClientStoreTable(table string) ClientStoreOption {
	// The argument never changes after this call, so it can be validated once
	// here instead of every time the option is applied.
	if table == "" {
		return func(*ClientStore) error { return ErrNoTable }
	}
	return func(store *ClientStore) error {
		store.table = table
		return nil
	}
}
// WithClientStoreConnPool sets the pgx connection pool the store runs its
// queries on. Applying the option fails with ErrNoConnPool when pool is nil.
func WithClientStoreConnPool(pool *pgxpool.Pool) ClientStoreOption {
	// The argument never changes after this call, so it can be validated once
	// here instead of every time the option is applied.
	if pool == nil {
		return func(*ClientStore) error { return ErrNoConnPool }
	}
	return func(store *ClientStore) error {
		store.pool = pool
		return nil
	}
}
// WithClientStoreLogger sets the logger the store reports to. Applying the
// option fails with ErrNoLogger when logger is nil.
func WithClientStoreLogger(logger Logger) ClientStoreOption {
	// The argument never changes after this call, so it can be validated once
	// here instead of every time the option is applied.
	if logger == nil {
		return func(*ClientStore) error { return ErrNoLogger }
	}
	return func(store *ClientStore) error {
		store.logger = logger
		return nil
	}
}
// ClientStoreItem mirrors one row of the client store table. Its fields
// correspond, in order, to the columns created by InitTable and scanned by
// scanToClientInfo.
type ClientStoreItem struct {
// ID is the client identifier (the table's primary key).
ID string `db:"id"`
// Secret is the client secret as passed to Create via info.GetSecret().
Secret string `db:"secret"`
// Domain is the client domain as passed to Create via info.GetDomain().
Domain string `db:"domain"`
// Data is the JSON encoding of the client info; Create writes it with
// json.Marshal and scanToClientInfo decodes it into a models.Client.
Data []byte `db:"data"`
// CreatedAt is the insertion time; Create sets it to time.Now().
CreatedAt time.Time `db:"created_at"`
}
// ClientStore stores and retrieves oauth2 client information in a database
// table through a pgx connection pool. Build one with NewClientStore.
type ClientStore struct {
// pool runs every query; NewClientStore refuses to build a store without it.
pool *pgxpool.Pool
// table is the table name, formatted directly into the SQL text of each
// statement. Defaults to DefaultClientStoreTable.
table string
// logger receives the store's debug and error messages. NewClientStore
// defaults it to a NoopLogger.
logger Logger
}
// scanToClientInfo reads one row (id, secret, domain, data, created_at — in
// that order) and decodes its JSON data column into a models.Client.
//
// The returned client is built from the data column only; the other columns
// are scanned but, apart from the id used in the debug log, not used. Errors
// from row.Scan and json.Unmarshal are returned unchanged.
func (s *ClientStore) scanToClientInfo(ctx context.Context, row pgx.Row) (oauth2.ClientInfo, error) {
	var rec ClientStoreItem
	if err := row.Scan(&rec.ID, &rec.Secret, &rec.Domain, &rec.Data, &rec.CreatedAt); err != nil {
		return nil, err
	}

	client := new(models.Client)
	if err := json.Unmarshal(rec.Data, client); err != nil {
		return nil, err
	}

	s.logger.Log(ctx, LogLevelDebug, "client found", "id", rec.ID)
	return client, nil
}
// InitTable creates the client store table and its domain index when they do
// not exist yet. Both statements are sent in a single Exec call; a failure is
// logged at error level and returned unchanged.
func (s *ClientStore) InitTable(ctx context.Context) error {
	s.logger.Log(ctx, LogLevelDebug, "initializing client store table", "table", s.table)

	// The table name is formatted into the statement text.
	ddl := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %[1]s (
id VARCHAR(255) PRIMARY KEY,
secret VARCHAR(255) NOT NULL,
domain VARCHAR(255) NOT NULL,
data JSONB NOT NULL,
created_at TIMESTAMPTZ NOT NULL
);
CREATE INDEX IF NOT EXISTS %[1]s_domain_idx ON %[1]s (domain);`,
		s.table,
	)

	if _, err := s.pool.Exec(ctx, ddl); err != nil {
		s.logger.Log(ctx, LogLevelError, err.Error())
		return err
	}
	return nil
}
// Create inserts a new client row into the store.
//
// The row holds the client's id, secret and domain, the JSON encoding of info
// in the data column, and the current time as created_at. The method takes no
// context, so the insert runs with context.Background().
//
// Errors from json.Marshal and from the insert are returned unchanged. On an
// insert failure only the client id and the error are logged: info carries
// the client secret, so logging it whole could expose the secret in the logs.
func (s *ClientStore) Create(info oauth2.ClientInfo) error {
	ctx := context.Background()
	id := info.GetID()
	s.logger.Log(ctx, LogLevelDebug, "creating client", "id", id)

	data, err := json.Marshal(info)
	if err != nil {
		return err
	}

	// Only the table name is formatted into the statement; all values are
	// passed as query parameters.
	query := fmt.Sprintf(`
INSERT INTO %[1]s (id, secret, domain, data, created_at)
VALUES ($1, $2, $3, $4, $5)`,
		s.table,
	)

	_, err = s.pool.Exec(ctx, query, id, info.GetSecret(), info.GetDomain(), data, time.Now())
	if err != nil {
		s.logger.Log(ctx, LogLevelError, "creating client failed", "id", id, "error", err)
		return err
	}

	s.logger.Log(ctx, LogLevelDebug, "client created")
	return nil
}
// GetByID looks up a single client by its id.
//
// The columns are selected explicitly, in the order scanToClientInfo scans
// them, so the lookup does not depend on the table's physical column order
// and keeps working if further columns are added to the table.
//
// Any error from scanning the row or decoding the stored JSON is returned
// unchanged. NOTE(review): this presumably includes pgx's no-rows error when
// the id is unknown — confirm against the pgx documentation.
func (s *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
	s.logger.Log(ctx, LogLevelDebug, "getting client by id", "id", id)

	// Only the table name is formatted into the statement; the id is passed
	// as a query parameter.
	query := fmt.Sprintf(
		"SELECT id, secret, domain, data, created_at FROM %s WHERE id = $1",
		s.table,
	)
	row := s.pool.QueryRow(ctx, query, id)
	return s.scanToClientInfo(ctx, row)
}
// NewClientStore builds a ClientStore, starting from the default table name
// and a no-op logger and then applying opts in order.
//
// The first option that fails stops construction and its error is returned.
// A store without a connection pool is rejected with ErrNoConnPool, so
// callers must pass WithClientStoreConnPool.
func NewClientStore(opts ...ClientStoreOption) (*ClientStore, error) {
	store := new(ClientStore)
	store.table = DefaultClientStoreTable
	store.logger = new(NoopLogger)

	for i := range opts {
		err := opts[i](store)
		if err != nil {
			return nil, err
		}
	}

	// No default exists for the pool; it has to come from an option.
	if store.pool == nil {
		return nil, ErrNoConnPool
	}
	return store, nil
}