Skip to content

Commit 8e7ad48

Browse files
authored
Merge pull request #174 from kbase/credentials-for-dbs
Created an auth.Credential type and threaded it through DTS.
2 parents 276ac9a + 10d1a9d commit 8e7ad48

17 files changed

+221
-291
lines changed

auth/auth.go

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,7 @@
2121

2222
package auth
2323

24-
// A record containing information about a DTS client. A DTS client is a KBase
25-
// user whose KBase developer token is used to authorize with the DTS.
26-
type Client struct {
27-
// client name (human-readable and display-friendly)
28-
Name string
29-
// KBase username (if any) used by client to access DTS
30-
Username string
31-
// client email address
32-
Email string
33-
// ORCID identifier associated with this client
34-
Orcid string
35-
// organization with which this client is affiliated
36-
Organization string
37-
}
38-
39-
// A record containing information about a DTS user using a DTS client to
40-
// request file transfers. A DTS user need not have a KBase developer token
41-
// (but should have a KBase account if they are requesting files be transferred
42-
// to KBase).
24+
// A record containing information about a DTS user using a DTS client to request file transfers.
4325
type User struct {
4426
// name (human-readable and display-friendly)
4527
Name string
@@ -52,3 +34,12 @@ type User struct {
5234
// true if this user is a Superuser
5335
IsSuper bool
5436
}
37+
38+
// A credential used for authorization and authentication
39+
type Credential struct {
40+
// the ID used for authorization (username or UUID)
41+
Id string `yaml:"id"`
42+
// the secret used for authentication (e.g. password)
43+
// DO NOT STORE THIS IN A CONFIG FILE! Use an environment variable instead
44+
Secret string `yaml:"secret"`
45+
}

auth/authenticator.go

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,10 @@ import (
2626
"encoding/csv"
2727
"errors"
2828
"os"
29-
"path/filepath"
3029
"strings"
3130
"time"
3231

3332
"github.com/fernet/fernet-go"
34-
35-
"github.com/kbase/dts/config"
3633
)
3734

3835
// This type accepts a valid access token in exchange for a user record. It is
@@ -44,19 +41,20 @@ type Authenticator struct {
4441
TimeOfLastRead time.Time
4542
RereadInterval time.Duration
4643
AccessTokenFile string
44+
Secret string
4745
}
4846

4947
const (
5048
// how often to reread the access token file, in minutes
5149
defaultRereadInterval = time.Minute
52-
// name of the access token file
53-
defaultAccessTokenFile = "access.dat"
5450
)
5551

56-
func NewAuthenticator() (*Authenticator, error) {
52+
// Creates a new authenticator by reading an access token file and decrypting it with a secret.
53+
func NewAuthenticator(accessTokenFile, secret string) (*Authenticator, error) {
5754
var a Authenticator
5855
a.RereadInterval = defaultRereadInterval
59-
a.AccessTokenFile = defaultAccessTokenFile
56+
a.AccessTokenFile = accessTokenFile
57+
a.Secret = secret
6058
err := a.readAccessTokenFile()
6159
if err != nil {
6260
return nil, err
@@ -83,13 +81,12 @@ func (a *Authenticator) GetUser(accessToken string) (User, error) {
8381
}
8482

8583
func (a *Authenticator) readAccessTokenFile() error {
86-
tokenFilePath := filepath.Join(config.Service.DataDirectory, a.AccessTokenFile)
87-
key, err := fernet.DecodeKey(config.Service.Secret)
84+
key, err := fernet.DecodeKey(a.Secret)
8885
if err != nil {
8986
return err
9087
}
9188

92-
cipherText, err := os.ReadFile(tokenFilePath)
89+
cipherText, err := os.ReadFile(a.AccessTokenFile)
9390
if err != nil {
9491
return err
9592
}

auth/authenticator_test.go

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,14 @@ package auth
2727
import (
2828
"fmt"
2929
"log"
30+
"log/slog"
3031
"os"
3132
"path/filepath"
3233
"testing"
3334
"time"
3435

3536
"github.com/fernet/fernet-go"
3637
"github.com/stretchr/testify/assert"
37-
38-
"github.com/kbase/dts/config"
39-
"github.com/kbase/dts/dtstest"
4038
)
4139

4240
// runs setup, runs all tests, and does breakdown
@@ -51,7 +49,6 @@ func TestMain(m *testing.M) {
5149
func TestRunner(t *testing.T) {
5250
tester := SerialTests{Test: t}
5351
tester.TestNewAuthenticator()
54-
tester.TestInvalidDataDirectory()
5552
tester.TestGetUser()
5653
tester.TestGetUserAfterReread()
5754
tester.TestGetUserAfterBadReread()
@@ -64,6 +61,9 @@ var TestKey fernet.Key
6461
// temporary testing directory
6562
var TestDir string
6663

64+
// testing access token file
65+
var TestAccessTokenFile string
66+
6767
// testing access token
6868
var TestAccessToken string
6969

@@ -77,21 +77,19 @@ var TestUser = User{
7777
}
7878

7979
func setup() {
80-
dtstest.EnableDebugLogging()
80+
enableDebugLogging()
8181

8282
log.Print("Creating testing directory...\n")
8383
var err error
8484
TestDir, err = os.MkdirTemp(os.TempDir(), "data-transfer-service-tests-")
8585
if err != nil {
8686
log.Panicf("Couldn't create testing directory: %s", err.Error())
8787
}
88-
config.Service.DataDirectory = TestDir
8988

9089
err = TestKey.Generate()
9190
if err != nil {
9291
log.Panicf("Couldn't generate encryption key: %s", err.Error())
9392
}
94-
config.Service.Secret = TestKey.Encode()
9593

9694
TestAccessToken = "7029c1877e9c2dd3dab814cc0f2763af"
9795

@@ -106,7 +104,8 @@ func setup() {
106104
log.Panicf("Couldn't encrypt test access data: %s", err.Error())
107105
}
108106

109-
output, err := os.Create(filepath.Join(TestDir, "access.dat"))
107+
TestAccessTokenFile = filepath.Join(TestDir, "access.dat")
108+
output, err := os.Create(TestAccessTokenFile)
110109
if err != nil {
111110
log.Panicf("Couldn't open test access data file: %s", err.Error())
112111
}
@@ -119,6 +118,13 @@ func setup() {
119118
setupKBaseAuthServerTests()
120119
}
121120

121+
func enableDebugLogging() {
122+
logLevel := new(slog.LevelVar)
123+
logLevel.Set(slog.LevelDebug)
124+
h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})
125+
slog.SetDefault(slog.New(h))
126+
}
127+
122128
// To run the tests serially, we attach them to a SerialTests type and
123129
// have them run by a a single test runner.
124130
type SerialTests struct{ Test *testing.T }
@@ -127,27 +133,16 @@ type SerialTests struct{ Test *testing.T }
127133
// constructed
128134
func (t *SerialTests) TestNewAuthenticator() {
129135
assert := assert.New(t.Test)
130-
auth, err := NewAuthenticator()
136+
auth, err := NewAuthenticator(TestAccessTokenFile, TestKey.Encode())
131137
assert.NotNil(auth, "Authenticator not created")
132138
assert.Nil(err, "Authenticator constructor triggered an error")
133139
}
134140

135-
// tests the case in which a directory without an encrpyted access.dat file has
136-
// been configured for the authenticator
137-
func (t *SerialTests) TestInvalidDataDirectory() {
138-
assert := assert.New(t.Test)
139-
config.Service.DataDirectory = os.Getenv("HOME")
140-
auth, err := NewAuthenticator()
141-
assert.Nil(auth, "Authenticator created with invalid data directory")
142-
assert.NotNil(err, "Invalid data directory for authenticator triggered no error")
143-
config.Service.DataDirectory = TestDir
144-
}
145-
146141
// tests whether the authenticator server can return information for the
147142
// the user associated with a valid ORCID
148143
func (t *SerialTests) TestGetUser() {
149144
assert := assert.New(t.Test)
150-
auth, err := NewAuthenticator()
145+
auth, err := NewAuthenticator(TestAccessTokenFile, TestKey.Encode())
151146
assert.NotNil(auth)
152147
assert.Nil(err)
153148

@@ -166,7 +161,7 @@ func (t *SerialTests) TestGetUser() {
166161
// user after enough time has passed to trigger a re-read of the access file
167162
func (t *SerialTests) TestGetUserAfterReread() {
168163
assert := assert.New(t.Test)
169-
auth, err := NewAuthenticator()
164+
auth, err := NewAuthenticator(TestAccessTokenFile, TestKey.Encode())
170165
assert.NotNil(auth)
171166
assert.Nil(err)
172167

@@ -188,7 +183,7 @@ func (t *SerialTests) TestGetUserAfterReread() {
188183
// tests whether the authenticator server handles a bad re-read correctly
189184
func (t *SerialTests) TestGetUserAfterBadReread() {
190185
assert := assert.New(t.Test)
191-
auth, err := NewAuthenticator()
186+
auth, err := NewAuthenticator(TestAccessTokenFile, TestKey.Encode())
192187
assert.NotNil(auth)
193188
assert.Nil(err)
194189

@@ -209,7 +204,7 @@ func (t *SerialTests) TestGetUserAfterBadReread() {
209204
// (fictitious ORCID: https://orcid.org/0000-0001-5109-3700)
210205
func (t *SerialTests) TestGetInvalidUser() {
211206
assert := assert.New(t.Test)
212-
auth, _ := NewAuthenticator()
207+
auth, _ := NewAuthenticator(TestAccessTokenFile, TestKey.Encode())
213208
badAccessToken := "c5683570c1412b77eabcb9d6eb0aae2a"
214209
_, err := auth.GetUser(badAccessToken)
215210
assert.NotNil(err)

auth/kbase_auth_server.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,24 +97,23 @@ func NewKBaseAuthServer(accessToken string, options ...KBaseAuthServerOption) (*
9797
}
9898

9999
// returns a normalized user record for the current KBase user
100-
func (server KBaseAuthServer) Client() (Client, error) {
100+
func (server KBaseAuthServer) User() (User, error) {
101101
kbUser, err := server.kbaseUser()
102102
if err != nil {
103-
return Client{}, err
103+
return User{}, err
104104
}
105-
client := Client{
106-
Name: kbUser.Display,
107-
Username: kbUser.Username,
108-
Email: kbUser.Email,
105+
user := User{
106+
Name: kbUser.Display,
107+
Email: kbUser.Email,
109108
}
110109
for _, pid := range kbUser.Idents {
111110
// grab the first ORCID associated with the user
112111
if pid.Provider == "OrcID" {
113-
client.Orcid = pid.UserName
112+
user.Orcid = pid.UserName
114113
break
115114
}
116115
}
117-
return client, nil
116+
return user, nil
118117
}
119118

120119
//-----------

auth/kbase_auth_server_test.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,11 @@ func TestClient(t *testing.T) {
200200
if len(devToken) > 0 {
201201
server, _ := NewKBaseAuthServer(devToken)
202202
assert.NotNil(server)
203-
client, err := server.Client()
203+
user, err := server.User()
204204
assert.Nil(err)
205205

206-
assert.True(len(client.Username) > 0)
207-
assert.True(len(client.Email) > 0)
208-
assert.Equal(os.Getenv("DTS_KBASE_TEST_ORCID"), client.Orcid)
206+
assert.True(len(user.Email) > 0)
207+
assert.Equal(os.Getenv("DTS_KBASE_TEST_ORCID"), user.Orcid)
209208
}
210209

211210
// test with the mock server
@@ -214,11 +213,10 @@ func TestClient(t *testing.T) {
214213
cfg.BaseURL = mockKBaseServer.URL
215214
})
216215
assert.NotNil(server, "Authentication server not created with valid token")
217-
client, err := server.Client()
218-
assert.Nil(err, "Client() triggered an error with valid token")
216+
user, err := server.User()
217+
assert.Nil(err, "User() triggered an error with valid token")
219218

220-
assert.Equal("testuser", client.Username)
221-
assert.Equal("Test User", client.Name)
222-
assert.Equal("[email protected]", client.Email)
223-
assert.Equal("testuser", client.Orcid)
219+
assert.Equal("Test User", user.Name)
220+
assert.Equal("[email protected]", user.Email)
221+
assert.Equal("testuser", user.Orcid)
224222
}

config/config.go

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import (
2929

3030
"github.com/google/uuid"
3131
"gopkg.in/yaml.v3"
32+
33+
"github.com/kbase/dts/auth"
3234
)
3335

3436
// a type with service configuration parameters
@@ -66,7 +68,7 @@ type serviceConfig struct {
6668

6769
// global config variables
6870
var Service serviceConfig
69-
var Credentials map[string]credentialConfig
71+
var Credentials map[string]auth.Credential
7072
var Endpoints map[string]endpointConfig
7173
var Databases map[string]databaseConfig
7274

@@ -77,10 +79,10 @@ var Databases map[string]databaseConfig
7779
// data around internally, but this is not yet complete. Once that is done, the
7880
// global variables above can be removed.
7981
type Config struct {
80-
Service serviceConfig `yaml:"service"`
81-
Credentials map[string]credentialConfig `yaml:"credentials"`
82-
Databases map[string]databaseConfig `yaml:"databases"`
83-
Endpoints map[string]endpointConfig `yaml:"endpoints"`
82+
Service serviceConfig `yaml:"service"`
83+
Credentials map[string]auth.Credential `yaml:"credentials"`
84+
Databases map[string]databaseConfig `yaml:"databases"`
85+
Endpoints map[string]endpointConfig `yaml:"endpoints"`
8486
}
8587

8688
// This helper locates and reads the selected sections in a configuration file,
@@ -167,16 +169,6 @@ func (params serviceConfig) Validate() error {
167169
return nil
168170
}
169171

170-
func (credential credentialConfig) Validate(name string) error {
171-
if credential.Id == "" {
172-
return &InvalidCredentialConfigError{
173-
Credential: name,
174-
Message: "Invalid credential ID",
175-
}
176-
}
177-
return nil
178-
}
179-
180172
func (endpoint endpointConfig) Validate(name string) error {
181173
if endpoint.Id == uuid.Nil { // invalid endpoint UUID
182174
return &InvalidEndpointConfigError{
@@ -236,15 +228,6 @@ func (c Config) Validate(service, credentials, databases, endpoints bool) error
236228
}
237229
}
238230

239-
if credentials {
240-
for name, credential := range c.Credentials {
241-
err = credential.Validate(name)
242-
if err != nil {
243-
return err
244-
}
245-
}
246-
}
247-
248231
if endpoints {
249232
if len(c.Endpoints) == 0 {
250233
return &InvalidServiceConfigError{

0 commit comments

Comments
 (0)