Skip to content

Commit

Permalink
client handler: allow to set/get application specific data
Browse files Browse the repository at this point in the history
  • Loading branch information
drakkan committed Aug 17, 2023
1 parent 19e8ada commit dfab33c
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 0 deletions.
9 changes: 9 additions & 0 deletions client_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ type clientHandler struct {
isTransferOpen bool // indicate if the transfer connection is opened
isTransferAborted bool // indicate if the transfer was aborted
tlsRequirement TLSRequirement // TLS requirement to respect
extra any // Additional application-specific data
paramsMutex sync.RWMutex // mutex to protect the parameters exposed to the library users
}

Expand Down Expand Up @@ -245,6 +246,14 @@ func (c *clientHandler) HasTLSForTransfers() bool {
return c.transferTLS
}

func (c *clientHandler) SetExtra(extra any) {
c.extra = extra
}

func (c *clientHandler) Extra() any {
return c.extra
}

func (c *clientHandler) setTLSForTransfer(value bool) {
c.paramsMutex.Lock()
defer c.paramsMutex.Unlock()
Expand Down
33 changes: 33 additions & 0 deletions client_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,36 @@ func TestDataConnectionRequirements(t *testing.T) {
assert.Contains(t, err.Error(), "unhandled data connection requirement")
}
}

func TestExtraData(t *testing.T) {
driver := &TestServerDriver{
Debug: false,
}
s := NewTestServerWithDriver(t, driver)

conf := goftp.Config{
User: authUser,
Password: authPass,
}

c, err := goftp.DialConfig(conf, s.Addr())
require.NoError(t, err, "Couldn't connect")

defer func() { panicOnError(c.Close()) }()

raw, err := c.OpenRawConn()
require.NoError(t, err)

defer func() { require.NoError(t, raw.Close()) }()

info := driver.GetClientsInfo()
require.Len(t, info, 1)

for k, v := range info {
ccInfo, ok := v.(map[string]interface{})
require.True(t, ok)
extra, ok := ccInfo["extra"].(uint32)
require.True(t, ok)
require.Equal(t, k, extra)
}
}
6 changes: 6 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ type ClientContext interface {
// If you want to enforce the same requirement for all
// clients, use the TLSRequired parameter defined in server settings instead
SetTLSRequirement(requirement TLSRequirement) error

// SetExtra allows to set application specific data
SetExtra(extra any)

// Extra returns application specific data set using SetExtra
Extra() any
}

// FileTransfer defines the inferface for file transfers.
Expand Down
3 changes: 3 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ func (driver *TestServerDriver) ClientConnected(cc ClientContext) (string, error
}

cc.SetDebug(driver.Debug)
// we set the client id as extra data just for testing
cc.SetExtra(cc.ID())
driver.Clients = append(driver.Clients, cc)
// This will remain the official name for now
return "TEST Server", err
Expand Down Expand Up @@ -267,6 +269,7 @@ func (driver *TestServerDriver) GetClientsInfo() map[uint32]interface{} {
ccInfo["hasTLSForTransfers"] = cc.HasTLSForTransfers()
ccInfo["lastCommand"] = cc.GetLastCommand()
ccInfo["debug"] = cc.Debug()
ccInfo["extra"] = cc.Extra()

info[cc.ID()] = ccInfo
}
Expand Down

0 comments on commit dfab33c

Please sign in to comment.