diff --git a/channel/consensus_channel/consensus_channel.go b/channel/consensus_channel/consensus_channel.go index aed6fe36a..7477351ea 100644 --- a/channel/consensus_channel/consensus_channel.go +++ b/channel/consensus_channel/consensus_channel.go @@ -214,7 +214,7 @@ func (c *ConsensusChannel) Follower() common.Address { // FundingTargets returns a list of channels funded by the ConsensusChannel func (c *ConsensusChannel) FundingTargets() []types.Destination { - return c.current.Outcome.fundingTargets() + return c.current.Outcome.FundingTargets() } func (c *ConsensusChannel) Accept(p SignedProposal) error { @@ -496,8 +496,8 @@ func (o *LedgerOutcome) AsOutcome() outcome.Exit { } } -// fundingTargets returns a list of channels funded by the LedgerOutcome -func (o *LedgerOutcome) fundingTargets() []types.Destination { +// FundingTargets returns a list of channels funded by the LedgerOutcome +func (o LedgerOutcome) FundingTargets() []types.Destination { targets := []types.Destination{} for dest := range o.guarantees { diff --git a/client/client.go b/client/client.go index 51ce10008..1b9d2a9a3 100644 --- a/client/client.go +++ b/client/client.go @@ -255,6 +255,16 @@ func (c *Client) GetPaymentChannel(id types.Destination) (query.PaymentChannelIn return query.GetPaymentChannelInfo(id, c.store, c.vm) } +// GetPaymentChannelsByLedger returns all active payment channels that are funded by the given ledger channel. +func (c *Client) GetPaymentChannelsByLedger(ledgerId types.Destination) ([]query.PaymentChannelInfo, error) { + return query.GetPaymentChannelsByLedger(ledgerId, c.store, c.vm) +} + +// GetAllLedgerChannels returns all ledger channels. +func (c *Client) GetAllLedgerChannels() ([]query.LedgerChannelInfo, error) { + return query.GetAllLedgerChannels(c.store, c.engine.GetConsensusAppAddress()) +} + // GetLedgerChannel returns the ledger channel with the given id. // If no ledger channel exists with the given id an error is returned. func (c *Client) GetLedgerChannel(id types.Destination) (query.LedgerChannelInfo, error) { diff --git a/client/engine/store/durablestore.go b/client/engine/store/durablestore.go index e32fd031e..d56fe3792 100644 --- a/client/engine/store/durablestore.go +++ b/client/engine/store/durablestore.go @@ -266,6 +266,73 @@ func (ds *DurableStore) getChannelById(id types.Destination) (channel.Channel, e return ch, nil } +// GetChannelsByIds returns any channels with ids in the supplied list. +func (ds *DurableStore) GetChannelsByIds(ids []types.Destination) ([]*channel.Channel, error) { + toReturn := []*channel.Channel{} + // We know every channel has a unique id + // so we can stop looking once we've found the correct number of channels + + var err error + + txError := ds.channels.View(func(tx *buntdb.Tx) error { + return tx.Ascend("", func(key, chJSON string) bool { + var ch channel.Channel + err = json.Unmarshal([]byte(chJSON), &ch) + if err != nil { + return false + } + + // If the channel is one of the ones we're looking for, add it to the list + if contains(ids, ch.Id) { + toReturn = append(toReturn, &ch) + } + + // If we've found all the channels we need, stop looking + if len(toReturn) == len(ids) { + return false + } + return true // otherwise, continue looking + }) + }) + + if txError != nil { + return []*channel.Channel{}, txError + } + if err != nil { + return []*channel.Channel{}, err + } + + return toReturn, nil +} + +// GetChannelsByAppDefinition returns any channels that include the given app definition +func (ds *DurableStore) GetChannelsByAppDefinition(appDef types.Address) ([]*channel.Channel, error) { + toReturn := []*channel.Channel{} + var unmarshErr error + err := ds.channels.View(func(tx *buntdb.Tx) error { + return tx.Ascend("", func(key, chJSON string) bool { + var ch channel.Channel + unmarshErr = json.Unmarshal([]byte(chJSON), &ch) + if unmarshErr != nil { + return false + } + + if ch.AppDefinition == appDef { + toReturn = append(toReturn, &ch) + } + + return true + }) + }) + if err != nil { + return []*channel.Channel{}, err + } + if unmarshErr != nil { + return []*channel.Channel{}, unmarshErr + } + return toReturn, nil +} + // GetChannelsByParticipant returns any channels that include the given participant func (ds *DurableStore) GetChannelsByParticipant(participant types.Address) []*channel.Channel { toReturn := []*channel.Channel{} @@ -292,6 +359,31 @@ func (ds *DurableStore) GetChannelsByParticipant(participant types.Address) []*c return toReturn } +func (ds *DurableStore) GetAllConsensusChannels() ([]*consensus_channel.ConsensusChannel, error) { + toReturn := []*consensus_channel.ConsensusChannel{} + var unmarshErr error + err := ds.consensusChannels.View(func(tx *buntdb.Tx) error { + return tx.Ascend("", func(key, chJSON string) bool { + var ch consensus_channel.ConsensusChannel + + unmarshErr = json.Unmarshal([]byte(chJSON), &ch) + if unmarshErr != nil { + return false + } + toReturn = append(toReturn, &ch) + return true + }) + }) + if err != nil { + return []*consensus_channel.ConsensusChannel{}, err + } + + if unmarshErr != nil { + return []*consensus_channel.ConsensusChannel{}, unmarshErr + } + return toReturn, nil +} + // GetConsensusChannelById returns a ConsensusChannel with the given channel id func (ds *DurableStore) GetConsensusChannelById(id types.Destination) (channel *consensus_channel.ConsensusChannel, err error) { var ch *consensus_channel.ConsensusChannel diff --git a/client/engine/store/memstore.go b/client/engine/store/memstore.go index 542d0ab0b..ddbc63954 100644 --- a/client/engine/store/memstore.go +++ b/client/engine/store/memstore.go @@ -179,6 +179,61 @@ func (ms *MemStore) getChannelById(id types.Destination) (channel.Channel, error return ch, nil } +// GetChannelsByIds returns a collection of channels with the given ids +func (ms *MemStore) GetChannelsByIds(ids []types.Destination) ([]*channel.Channel, error) { + toReturn := []*channel.Channel{} + + var err error + + ms.channels.Range(func(key string, chJSON []byte) bool { + var ch channel.Channel + err = json.Unmarshal(chJSON, &ch) + if err != nil { + return false + } + + // If the channel is one of the ones we're looking for, add it to the list + if contains(ids, ch.Id) { + toReturn = append(toReturn, &ch) + } + + // If we've found all the channels we need, stop looking + if len(toReturn) == len(ids) { + return false + } + + return true // otherwise, continue looking + }) + if err != nil { + return []*channel.Channel{}, err + } + return toReturn, nil +} + +// GetChannelsByAppDefinition returns any channels that include the given app definition +func (ms *MemStore) GetChannelsByAppDefinition(appDef types.Address) ([]*channel.Channel, error) { + toReturn := []*channel.Channel{} + var err error + ms.channels.Range(func(key string, chJSON []byte) bool { + var ch channel.Channel + err = json.Unmarshal(chJSON, &ch) + if err != nil { + return false + } + if ch.AppDefinition == appDef { + toReturn = append(toReturn, &ch) + } + + return true // channel not found: continue looking + }) + + if err != nil { + return []*channel.Channel{}, err + } + + return toReturn, nil +} + // GetChannelsByParticipant returns any channels that include the given participant func (ms *MemStore) GetChannelsByParticipant(participant types.Address) []*channel.Channel { toReturn := []*channel.Channel{} @@ -245,6 +300,26 @@ func (ms *MemStore) GetConsensusChannel(counterparty types.Address) (channel *co return } +func (ms *MemStore) GetAllConsensusChannels() ([]*consensus_channel.ConsensusChannel, error) { + toReturn := []*consensus_channel.ConsensusChannel{} + var err error + ms.consensusChannels.Range(func(key string, chJSON []byte) bool { + var ch consensus_channel.ConsensusChannel + + err = json.Unmarshal(chJSON, &ch) + if err != nil { + return false + } + + toReturn = append(toReturn, &ch) + return true // channel not found: continue looking + }) + if err != nil { + return nil, err + } + return toReturn, nil +} + func (ms *MemStore) GetObjectiveByChannelId(channelId types.Destination) (protocols.Objective, bool) { // todo: locking id, found := ms.channelToObjective.Load(channelId.String()) @@ -404,3 +479,13 @@ func (ms *MemStore) RemoveVoucherInfo(channelId types.Destination) error { ms.vouchers.Delete(channelId.String()) return nil } + +// contains is a helper function which returns true if the given item is included in col +func contains[T types.Destination | protocols.ObjectiveId](col []T, item T) bool { + for _, i := range col { + if i == item { + return true + } + } + return false +} diff --git a/client/engine/store/store.go b/client/engine/store/store.go index 361662e09..a2694f526 100644 --- a/client/engine/store/store.go +++ b/client/engine/store/store.go @@ -19,19 +19,18 @@ var ( // Store is responsible for persisting objectives, objective metadata, states, signatures, private keys and blockchain data type Store interface { - GetChannelSecretKey() *[]byte // Get a pointer to a secret key for signing channel updates - GetAddress() *types.Address // Get the (Ethereum) address associated with the ChannelSecretKey - + GetChannelSecretKey() *[]byte // Get a pointer to a secret key for signing channel updates + GetAddress() *types.Address // Get the (Ethereum) address associated with the ChannelSecretKey GetObjectiveById(protocols.ObjectiveId) (protocols.Objective, error) // Read an existing objective GetObjectiveByChannelId(types.Destination) (obj protocols.Objective, ok bool) // Get the objective that currently owns the channel with the supplied ChannelId SetObjective(protocols.Objective) error // Write an objective - + GetChannelsByIds(ids []types.Destination) ([]*channel.Channel, error) // Returns a collection of channels with the given ids GetChannelById(id types.Destination) (c *channel.Channel, ok bool) GetChannelsByParticipant(participant types.Address) []*channel.Channel // Returns any channels that includes the given participant SetChannel(*channel.Channel) error DestroyChannel(id types.Destination) - - ReleaseChannelFromOwnership(types.Destination) // Release channel from being owned by any objective + GetChannelsByAppDefinition(appDef types.Address) ([]*channel.Channel, error) // Returns any channels that includes the given app definition + ReleaseChannelFromOwnership(types.Destination) // Release channel from being owned by any objective ConsensusChannelStore payments.VoucherStore @@ -39,6 +38,7 @@ type Store interface { } type ConsensusChannelStore interface { + GetAllConsensusChannels() ([]*consensus_channel.ConsensusChannel, error) GetConsensusChannel(counterparty types.Address) (channel *consensus_channel.ConsensusChannel, ok bool) GetConsensusChannelById(id types.Destination) (channel *consensus_channel.ConsensusChannel, err error) SetConsensusChannel(*consensus_channel.ConsensusChannel) error diff --git a/client/query/query.go b/client/query/query.go index 7d7cfe196..67e29eb60 100644 --- a/client/query/query.go +++ b/client/query/query.go @@ -1,6 +1,7 @@ package query import ( + "errors" "fmt" "math/big" @@ -127,6 +128,65 @@ func GetPaymentChannelInfo(id types.Destination, store store.Store, vm *payments return PaymentChannelInfo{}, fmt.Errorf("could not find channel with id %v", id) } +// GetAllLedgerChannels returns a `LedgerChannelInfo` for each ledger channel in the store. +func GetAllLedgerChannels(store store.Store, consensusAppDefinition types.Address) ([]LedgerChannelInfo, error) { + toReturn := []LedgerChannelInfo{} + + allConsensus, err := store.GetAllConsensusChannels() + if err != nil { + return []LedgerChannelInfo{}, err + } + for _, con := range allConsensus { + toReturn = append(toReturn, ConstructLedgerInfoFromConsensus(con)) + } + allChannels, err := store.GetChannelsByAppDefinition(consensusAppDefinition) + if err != nil { + return []LedgerChannelInfo{}, err + } + for _, c := range allChannels { + toReturn = append(toReturn, ConstructLedgerInfoFromChannel(c)) + } + return toReturn, nil +} + +// GetPaymentChannelsByLedger returns a `PaymentChannelInfo` for each active payment channel funded by the given ledger channel. +func GetPaymentChannelsByLedger(ledgerId types.Destination, s store.Store, vm *payments.VoucherManager) ([]PaymentChannelInfo, error) { + // If a ledger channel is actively funding payment channels it must be in the form of a consensus channel + con, err := s.GetConsensusChannelById(ledgerId) + // If the ledger channel is not a consensus channel we know that there are no payment channels funded by it + if errors.Is(err, store.ErrNoSuchChannel) { + return []PaymentChannelInfo{}, nil + } + if err != nil { + return []PaymentChannelInfo{}, fmt.Errorf("could not find any payment channels funded by %s: %w", ledgerId, err) + } + + toQuery := con.ConsensusVars().Outcome.FundingTargets() + + paymentChannels, err := s.GetChannelsByIds(toQuery) + if err != nil { + return []PaymentChannelInfo{}, fmt.Errorf("could not query the store about ids %v: %w", toQuery, err) + } + + toReturn := []PaymentChannelInfo{} + for _, p := range paymentChannels { + paid, remaining, err := GetVoucherBalance(p.Id, vm) + if err != nil { + return []PaymentChannelInfo{}, err + } + // TODO: n+1 query problem + // We should query for the vfos in bulk, rather than one at a time + // Or we should be able to determine the status soley from the channel + vfo, _ := GetVirtualFundObjective(p.Id, s) + info, err := ConstructPaymentInfo(p, vfo, paid, remaining) + if err != nil { + return []PaymentChannelInfo{}, err + } + toReturn = append(toReturn, info) + } + return toReturn, nil +} + // GetLedgerChannelInfo returns the LedgerChannelInfo for the given channel // It does this by querying the provided store func GetLedgerChannelInfo(id types.Destination, store store.Store) (LedgerChannelInfo, error) { @@ -166,7 +226,6 @@ func ConstructLedgerInfoFromChannel(c *channel.Channel) LedgerChannelInfo { func ConstructPaymentInfo(c *channel.Channel, vfo *virtualfund.Objective, paid, remaining *big.Int) (PaymentChannelInfo, error) { status := getStatusFromChannel(c) - if vfo != nil && vfo.Status == protocols.Completed { // This means intermediaries may not have a fully signed postfund state even though the channel is "ready" // To determine the the correct status we check the status of the virtual fund objective diff --git a/client_test/rpc_test.go b/client_test/rpc_test.go index c87cd45f8..5e4ec92d6 100644 --- a/client_test/rpc_test.go +++ b/client_test/rpc_test.go @@ -84,9 +84,11 @@ func executeRpcTest(t *testing.T, connectionType transport.TransportType) { expectedAliceLedger := expectedLedgerInfo(res.ChannelId, aliceLedgerOutcome, query.Ready) checkQueryInfo(t, expectedAliceLedger, rpcClientA.GetLedgerChannel(res.ChannelId)) + checkQueryInfoCollection(t, expectedAliceLedger, 1, rpcClientA.GetAllLedgerChannels()) expectedBobLedger := expectedLedgerInfo(bobResponse.ChannelId, bobLedgerOutcome, query.Ready) checkQueryInfo(t, expectedBobLedger, rpcClientB.GetLedgerChannel(bobResponse.ChannelId)) + checkQueryInfoCollection(t, expectedBobLedger, 1, rpcClientB.GetAllLedgerChannels()) initialOutcome := testdata.Outcomes.Create(ta.Alice.Address(), ta.Bob.Address(), 100, 0, types.Address{}) vRes := rpcClientA.CreateVirtual( @@ -105,11 +107,16 @@ func executeRpcTest(t *testing.T, connectionType transport.TransportType) { expectedVirtual := expectedPaymentInfo(vRes.ChannelId, initialOutcome, query.Ready) aliceVirtual := rpcClientA.GetVirtualChannel(vRes.ChannelId) checkQueryInfo(t, expectedVirtual, aliceVirtual) + checkQueryInfoCollection(t, expectedVirtual, 1, rpcClientA.GetPaymentChannelsByLedger(res.ChannelId)) + bobVirtual := rpcClientB.GetVirtualChannel(vRes.ChannelId) checkQueryInfo(t, expectedVirtual, bobVirtual) + checkQueryInfoCollection(t, expectedVirtual, 1, rpcClientB.GetPaymentChannelsByLedger(bobResponse.ChannelId)) + ireneVirtual := rpcClientI.GetVirtualChannel(vRes.ChannelId) checkQueryInfo(t, expectedVirtual, ireneVirtual) - + checkQueryInfoCollection(t, expectedVirtual, 1, rpcClientI.GetPaymentChannelsByLedger(bobResponse.ChannelId)) + checkQueryInfoCollection(t, expectedVirtual, 1, rpcClientI.GetPaymentChannelsByLedger(res.ChannelId)) rpcClientA.Pay(vRes.ChannelId, 1) closeVId := rpcClientA.CloseVirtual(vRes.ChannelId) @@ -125,6 +132,19 @@ func executeRpcTest(t *testing.T, connectionType transport.TransportType) { <-rpcClientB.ObjectiveCompleteChan(closeIdB) <-rpcClientI.ObjectiveCompleteChan(closeIdB) + if len(rpcClientA.GetPaymentChannelsByLedger(res.ChannelId)) != 0 { + t.Error("Alice should not have any payment channels open") + } + if len(rpcClientB.GetPaymentChannelsByLedger(bobResponse.ChannelId)) != 0 { + t.Error("Bob should not have any payment channels open") + } + if len(rpcClientI.GetPaymentChannelsByLedger(res.ChannelId)) != 0 { + t.Error("Irene should not have any payment channels open") + } + if len(rpcClientI.GetPaymentChannelsByLedger(bobResponse.ChannelId)) != 0 { + t.Error("Irene should not have any payment channels open") + } + expectedAliceLedgerNotifs := []query.LedgerChannelInfo{ expectedLedgerInfo(res.ChannelId, simpleOutcome(ta.Alice.Address(), ta.Irene.Address(), 100, 100), query.Proposed), expectedLedgerInfo(res.ChannelId, simpleOutcome(ta.Alice.Address(), ta.Irene.Address(), 100, 100), query.Ready), @@ -235,6 +255,22 @@ func checkQueryInfo[T channelInfo](t *testing.T, expected T, fetched T) { } } +func checkQueryInfoCollection[T channelInfo](t *testing.T, expected T, expectedLength int, fetched []T) { + if len(fetched) != expectedLength { + t.Fatalf("expected %d channel infos, got %d", expectedLength, len(fetched)) + } + found := false + for _, fetched := range fetched { + if cmp.Equal(expected, fetched, cmp.AllowUnexported(big.Int{})) { + found = true + break + } + } + if !found { + panic(fmt.Errorf("did not find info %v in channel infos: %v", expected, fetched)) + } +} + // checkNotifications checks that the expected notifications are received on the notifChan. // Due to the async nature of RPC notifications (and how quickly are clients communicate), the order of the notifications is not guaranteed. // This function checks that all the expected notifications are received, but not in any particular order. diff --git a/rpc/client.go b/rpc/client.go index 94cfa9a0a..53cec701d 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -51,7 +51,7 @@ func NewRpcClient(rpcServerUrl string, myAddress types.Address, logger zerolog.L func (rc *RpcClient) GetVirtualChannel(id types.Destination) query.PaymentChannelInfo { req := serde.GetPaymentChannelRequest{Id: id} - return waitForRequest[serde.GetPaymentChannelRequest, query.PaymentChannelInfo](rc, req) + return waitForRequest[serde.GetPaymentChannelRequest, query.PaymentChannelInfo](rc, serde.GetPaymentChannelRequestMethod, req) } // CreateLedger creates a new ledger channel @@ -64,7 +64,7 @@ func (rc *RpcClient) CreateVirtual(intermediaries []types.Address, counterparty rand.Uint64(), common.Address{}) - return waitForRequest[virtualfund.ObjectiveRequest, virtualfund.ObjectiveResponse](rc, objReq) + return waitForRequest[virtualfund.ObjectiveRequest, virtualfund.ObjectiveResponse](rc, serde.VirtualFundRequestMethod, objReq) } // CloseVirtual closes a virtual channel @@ -72,13 +72,23 @@ func (rc *RpcClient) CloseVirtual(id types.Destination) protocols.ObjectiveId { objReq := virtualdefund.NewObjectiveRequest( id) - return waitForRequest[virtualdefund.ObjectiveRequest, protocols.ObjectiveId](rc, objReq) + return waitForRequest[virtualdefund.ObjectiveRequest, protocols.ObjectiveId](rc, serde.VirtualDefundRequestMethod, objReq) } func (rc *RpcClient) GetLedgerChannel(id types.Destination) query.LedgerChannelInfo { req := serde.GetLedgerChannelRequest{Id: id} - return waitForRequest[serde.GetLedgerChannelRequest, query.LedgerChannelInfo](rc, req) + return waitForRequest[serde.GetLedgerChannelRequest, query.LedgerChannelInfo](rc, serde.GetLedgerChannelRequestMethod, req) +} + +// GetAllLedgerChannels returns all ledger channels +func (rc *RpcClient) GetAllLedgerChannels() []query.LedgerChannelInfo { + return waitForRequest[serde.NoPayloadRequest, []query.LedgerChannelInfo](rc, serde.GetAllLedgerChannelsMethod, struct{}{}) +} + +// GetPaymentChannelsByLedger returns all active payment channels for a given ledger channel +func (rc *RpcClient) GetPaymentChannelsByLedger(ledgerId types.Destination) []query.PaymentChannelInfo { + return waitForRequest[serde.GetPaymentChannelsByLedgerRequest, []query.PaymentChannelInfo](rc, serde.GetPaymentChannelsByLedgerMethod, serde.GetPaymentChannelsByLedgerRequest{LedgerId: ledgerId}) } // CreateLedger creates a new ledger channel @@ -90,21 +100,21 @@ func (rc *RpcClient) CreateLedger(counterparty types.Address, ChallengeDuration rand.Uint64(), common.Address{}) - return waitForRequest[directfund.ObjectiveRequest, directfund.ObjectiveResponse](rc, objReq) + return waitForRequest[directfund.ObjectiveRequest, directfund.ObjectiveResponse](rc, serde.DirectFundRequestMethod, objReq) } // CloseLedger closes a ledger channel func (rc *RpcClient) CloseLedger(id types.Destination) protocols.ObjectiveId { objReq := directdefund.NewObjectiveRequest(id) - return waitForRequest[directdefund.ObjectiveRequest, protocols.ObjectiveId](rc, objReq) + return waitForRequest[directdefund.ObjectiveRequest, protocols.ObjectiveId](rc, serde.DirectDefundRequestMethod, objReq) } // Pay uses the specified channel to pay the specified amount func (rc *RpcClient) Pay(id types.Destination, amount uint64) { pReq := serde.PaymentRequest{Amount: amount, Channel: id} - waitForRequest[serde.PaymentRequest, serde.PaymentRequest](rc, pReq) + waitForRequest[serde.PaymentRequest, serde.PaymentRequest](rc, serde.PayRequestMethod, pReq) } func (rc *RpcClient) Close() { @@ -154,8 +164,8 @@ func (rc *RpcClient) subscribeToNotifications() error { return err } -func waitForRequest[T serde.RequestPayload, U serde.ResponsePayload](rc *RpcClient, requestData T) U { - resChan, err := request[T, U](rc.transport, requestData, rc.logger) +func waitForRequest[T serde.RequestPayload, U serde.ResponsePayload](rc *RpcClient, method serde.RequestMethod, requestData T) U { + resChan, err := request[T, U](rc.transport, method, requestData, rc.logger) if err != nil { panic(err) } @@ -188,28 +198,13 @@ func (rc *RpcClient) PaymentChannelUpdatesChan(paymentChannelId types.Destinatio // request uses the supplied transport and payload to send a non-blocking JSONRPC request. // It returns a channel that sends a response payload. If the request fails to send, an error is returned. -func request[T serde.RequestPayload, U serde.ResponsePayload](trans transport.Requester, request T, logger zerolog.Logger) (<-chan response[U], error) { +func request[T serde.RequestPayload, U serde.ResponsePayload](trans transport.Requester, method serde.RequestMethod, reqPayload T, logger zerolog.Logger) (<-chan response[U], error) { + return sendRPCRequest[T, U](method, reqPayload, trans, logger) +} + +func sendRPCRequest[T serde.RequestPayload, U serde.ResponsePayload](method serde.RequestMethod, request T, trans transport.Requester, logger zerolog.Logger) (<-chan response[U], error) { returnChan := make(chan response[U], 1) - var method serde.RequestMethod - switch any(request).(type) { - case directfund.ObjectiveRequest: - method = serde.DirectFundRequestMethod - case directdefund.ObjectiveRequest: - method = serde.DirectDefundRequestMethod - case virtualfund.ObjectiveRequest: - method = serde.VirtualFundRequestMethod - case virtualdefund.ObjectiveRequest: - method = serde.VirtualDefundRequestMethod - case serde.PaymentRequest: - method = serde.PayRequestMethod - case serde.GetLedgerChannelRequest: - method = serde.GetLedgerChannelRequestMethod - case serde.GetPaymentChannelRequest: - method = serde.GetPaymentChannelRequestMethod - default: - return nil, fmt.Errorf("unknown request type %v", request) - } requestId := rand.Uint64() message := serde.NewJsonRpcRequest(requestId, method, request) data, err := json.Marshal(message) @@ -237,7 +232,6 @@ func request[T serde.RequestPayload, U serde.ResponsePayload](trans transport.Re returnChan <- response[U]{jsonResponse.Result, nil} }() - return returnChan, nil } diff --git a/rpc/serde/jsonrpc.go b/rpc/serde/jsonrpc.go index 53fc4935f..645dc206e 100644 --- a/rpc/serde/jsonrpc.go +++ b/rpc/serde/jsonrpc.go @@ -13,15 +13,17 @@ import ( type RequestMethod string const ( - GetAddressMethod RequestMethod = "get_address" - VersionMethod RequestMethod = "version" - DirectFundRequestMethod RequestMethod = "direct_fund" - DirectDefundRequestMethod RequestMethod = "direct_defund" - VirtualFundRequestMethod RequestMethod = "virtual_fund" - VirtualDefundRequestMethod RequestMethod = "virtual_defund" - PayRequestMethod RequestMethod = "pay" - GetPaymentChannelRequestMethod RequestMethod = "get_payment_channel" - GetLedgerChannelRequestMethod RequestMethod = "get_ledger_channel" + GetAddressMethod RequestMethod = "get_address" + VersionMethod RequestMethod = "version" + DirectFundRequestMethod RequestMethod = "direct_fund" + DirectDefundRequestMethod RequestMethod = "direct_defund" + VirtualFundRequestMethod RequestMethod = "virtual_fund" + VirtualDefundRequestMethod RequestMethod = "virtual_defund" + PayRequestMethod RequestMethod = "pay" + GetPaymentChannelRequestMethod RequestMethod = "get_payment_channel" + GetLedgerChannelRequestMethod RequestMethod = "get_ledger_channel" + GetPaymentChannelsByLedgerMethod RequestMethod = "get_payment_channels_by_ledger" + GetAllLedgerChannelsMethod RequestMethod = "get_all_ledger_channels" ) type NotificationMethod string @@ -48,6 +50,9 @@ type GetPaymentChannelRequest struct { type GetLedgerChannelRequest struct { Id types.Destination } +type GetPaymentChannelsByLedgerRequest struct { + LedgerId types.Destination +} type ( NoPayloadRequest = struct{} @@ -61,6 +66,7 @@ type RequestPayload interface { PaymentRequest | GetLedgerChannelRequest | GetPaymentChannelRequest | + GetPaymentChannelsByLedgerRequest | NoPayloadRequest } @@ -77,19 +83,25 @@ type JsonRpcRequest[T RequestPayload | NotificationPayload] struct { Params T `json:"params"` } +type VersionResponse = string + type ( - VersionResponse = string - ResponsePayload interface { - directfund.ObjectiveResponse | - protocols.ObjectiveId | - virtualfund.ObjectiveResponse | - PaymentRequest | - query.PaymentChannelInfo | - query.LedgerChannelInfo | - VersionResponse - } + GetAllLedgersResponse = []query.LedgerChannelInfo + GetPaymentChannelsByLedgerResponse = []query.PaymentChannelInfo ) +type ResponsePayload interface { + directfund.ObjectiveResponse | + protocols.ObjectiveId | + virtualfund.ObjectiveResponse | + PaymentRequest | + query.PaymentChannelInfo | + query.LedgerChannelInfo | + VersionResponse | + GetAllLedgersResponse | + GetPaymentChannelsByLedgerResponse +} + type JsonRpcResponse[T ResponsePayload] struct { Jsonrpc string `json:"jsonrpc"` Id uint64 `json:"id"` diff --git a/rpc/server.go b/rpc/server.go index 94d2a6995..5c3655e70 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -118,6 +118,24 @@ func (rs *RpcServer) registerHandlers() (err error) { } return l }) + case serde.GetAllLedgerChannelsMethod: + return processRequest(rs, requestData, func(r serde.NoPayloadRequest) []query.LedgerChannelInfo { + ledgers, err := rs.client.GetAllLedgerChannels() + if err != nil { + // TODO: What's the best way to handle this error? + panic(err) + } + return ledgers + }) + case serde.GetPaymentChannelsByLedgerMethod: + return processRequest(rs, requestData, func(r serde.GetPaymentChannelsByLedgerRequest) []query.PaymentChannelInfo { + payChs, err := rs.client.GetPaymentChannelsByLedger(r.LedgerId) + if err != nil { + // TODO: What's the best way to handle this error? + panic(err) + } + return payChs + }) default: responseErr := methodNotFoundError responseErr.Id = validationResult.Id