Skip to content

Commit

Permalink
feat: update subscription allocation during session update
Browse files Browse the repository at this point in the history
  • Loading branch information
ironman0x7b2 committed Jan 11, 2025
1 parent 2cb0e2e commit 387da58
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 33 deletions.
2 changes: 2 additions & 0 deletions x/session/expected/keeper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package expected

import (
sdkmath "cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"

Expand All @@ -27,5 +28,6 @@ type NodeKeeper interface {

type SubscriptionKeeper interface {
SessionInactivePreHook(ctx sdk.Context, id uint64) error
SessionUpdatePreHook(ctx sdk.Context, id uint64, currBytes sdkmath.Int) error
UpdateSessionMaxValues(ctx sdk.Context, session v3.Session) error
}
9 changes: 9 additions & 0 deletions x/session/keeper/alias.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keeper

import (
sdkmath "cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/sentinel-official/hub/v12/x/session/types/v3"
Expand Down Expand Up @@ -33,6 +34,14 @@ func (k *Keeper) SessionInactivePreHook(ctx sdk.Context, id uint64) error {
return nil
}

func (k *Keeper) SessionUpdatePreHook(ctx sdk.Context, id uint64, currBytes sdkmath.Int) error {
if err := k.subscription.SessionUpdatePreHook(ctx, id, currBytes); err != nil {
return err
}

return nil
}

func (k *Keeper) UpdateMaxValues(ctx sdk.Context, session v3.Session) error {
if err := k.node.UpdateSessionMaxValues(ctx, session); err != nil {
return err
Expand Down
16 changes: 15 additions & 1 deletion x/session/keeper/msg_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,17 @@ func (k *Keeper) HandleMsgUpdateSession(ctx sdk.Context, msg *v3.MsgUpdateSessio
return nil, types.NewErrorInvalidSessionStatus(session.GetID(), session.GetStatus())
}

if k.ProofVerificationEnabled(ctx) {
if msg.DownloadBytes.LT(session.GetDownloadBytes()) {
return nil, types.NewErrorInvalidDownloadBytes(msg.DownloadBytes)
}
if msg.UploadBytes.LT(session.GetUploadBytes()) {
return nil, types.NewErrorInvalidUploadBytes(msg.UploadBytes)
}
if msg.Duration < session.GetDuration() {
return nil, types.NewErrorInvalidDuration(msg.Duration)
}

if ok := k.ProofVerificationEnabled(ctx); ok {
accAddr, err := sdk.AccAddressFromBech32(session.GetAccAddress())
if err != nil {
return nil, err
Expand All @@ -77,6 +87,10 @@ func (k *Keeper) HandleMsgUpdateSession(ctx sdk.Context, msg *v3.MsgUpdateSessio
}
}

if err := k.SessionUpdatePreHook(ctx, session.GetID(), msg.Bytes()); err != nil {
return nil, err
}

if session.GetStatus().Equal(v1base.StatusActive) {
k.DeleteSessionForInactiveAt(ctx, session.GetInactiveAt(), session.GetID())
}
Expand Down
29 changes: 25 additions & 4 deletions x/session/types/errors.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
package types

import (
"time"

sdkerrors "cosmossdk.io/errors"
sdkmath "cosmossdk.io/math"

v1base "github.com/sentinel-official/hub/v12/types/v1"
)

var (
ErrInvalidMessage = sdkerrors.Register(ModuleName, 101, "invalid message")

ErrInvalidSessionStatus = sdkerrors.Register(ModuleName, 201, "invalid session status")
ErrInvalidSignature = sdkerrors.Register(ModuleName, 202, "invalid signature")
ErrSessionNotFound = sdkerrors.Register(ModuleName, 203, "session not found")
ErrUnauthorized = sdkerrors.Register(ModuleName, 204, "unauthorized")
ErrInvalidDownloadBytes = sdkerrors.Register(ModuleName, 201, "invalid download bytes")
ErrInvalidDuration = sdkerrors.Register(ModuleName, 202, "invalid duration")
ErrInvalidSessionStatus = sdkerrors.Register(ModuleName, 203, "invalid session status")
ErrInvalidSignature = sdkerrors.Register(ModuleName, 204, "invalid signature")
ErrInvalidUploadBytes = sdkerrors.Register(ModuleName, 205, "invalid upload bytes")
ErrSessionNotFound = sdkerrors.Register(ModuleName, 206, "session not found")
ErrUnauthorized = sdkerrors.Register(ModuleName, 207, "unauthorized")
)

// NewErrorInvalidDownloadBytes returns an error indicating that the download bytes are invalid.
func NewErrorInvalidDownloadBytes(bytes sdkmath.Int) error {
return sdkerrors.Wrapf(ErrInvalidDownloadBytes, "invalid download bytes %s", bytes)
}

// NewErrorInvalidDuration returns an error indicating that the specified duration is invalid.
func NewErrorInvalidDuration(duration time.Duration) error {
return sdkerrors.Wrapf(ErrInvalidDuration, "invalid duration %d", duration)
}

// NewErrorInvalidSessionStatus returns an error indicating that the provided status is invalid for the session.
func NewErrorInvalidSessionStatus(id uint64, status v1base.Status) error {
return sdkerrors.Wrapf(ErrInvalidSessionStatus, "invalid status %s for session %d", status, id)
Expand All @@ -25,6 +41,11 @@ func NewErrorInvalidSignature(signature []byte) error {
return sdkerrors.Wrapf(ErrInvalidSignature, "invalid signature %X", signature)
}

// NewErrorInvalidUploadBytes returns an error indicating that the upload bytes are invalid.
func NewErrorInvalidUploadBytes(bytes sdkmath.Int) error {
return sdkerrors.Wrapf(ErrInvalidUploadBytes, "invalid upload bytes %s", bytes)
}

// NewErrorSessionNotFound returns an error indicating that the specified session does not exist.
func NewErrorSessionNotFound(id uint64) error {
return sdkerrors.Wrapf(ErrSessionNotFound, "session %d does not exist", id)
Expand Down
4 changes: 4 additions & 0 deletions x/session/types/v3/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ func NewMsgUpdateSessionRequest(from base.NodeAddress, id uint64, downloadBytes,
}
}

func (m *MsgUpdateSessionRequest) Bytes() sdkmath.Int {
return m.DownloadBytes.Add(m.UploadBytes)
}

func (m *MsgUpdateSessionRequest) Proof() *Proof {
return &Proof{
ID: m.ID,
Expand Down
6 changes: 6 additions & 0 deletions x/session/types/v3/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
type Session interface {
proto.Message

Bytes() sdkmath.Int

GetID() uint64
GetAccAddress() string
GetNodeAddress() string
Expand Down Expand Up @@ -39,6 +41,10 @@ type Session interface {
SetStatusAt(v time.Time)
}

func (m *BaseSession) Bytes() sdkmath.Int {
return m.GetDownloadBytes().Add(m.GetUploadBytes())
}

func (m *BaseSession) GetID() uint64 { return m.ID }
func (m *BaseSession) GetAccAddress() string { return m.AccAddress }
func (m *BaseSession) GetNodeAddress() string { return m.NodeAddress }
Expand Down
82 changes: 54 additions & 28 deletions x/subscription/keeper/hooks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keeper

import (
sdkmath "cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"

base "github.com/sentinel-official/hub/v12/types"
Expand All @@ -9,58 +10,94 @@ import (
"github.com/sentinel-official/hub/v12/x/subscription/types/v3"
)

// SessionInactivePreHook handles the necessary operations when a session becomes inactive.
// SessionInactivePreHook performs cleanup operations when a session transitions to an inactive state.
func (k *Keeper) SessionInactivePreHook(ctx sdk.Context, id uint64) error {
k.Logger(ctx).Info("Running session inactive pre-hook", "id", id)

// Retrieve the session by ID; return an error if not found.
// Retrieve the session by ID; return an error if it doesn't exist.
item, found := k.GetSession(ctx, id)
if !found {
return types.NewErrorSessionNotFound(id)
}

// Assert the retrieved session to the v3.Session type; return nil if the assertion fails.
// Ensure the session is of type v3.Session; do nothing if it's not.
session, ok := item.(*v3.Session)
if !ok {
return nil
}

// Ensure the session status is "InactivePending"; return an error if it has a different status.
// Verify that the session's status is "InactivePending"; otherwise, return an error.
if !session.Status.Equal(v1base.StatusInactivePending) {
return types.NewErrorInvalidSessionStatus(session.ID, session.Status)
}

// Retrieve the subscription associated with the session; return an error if not found.
// Fetch the subscription associated with the session; return an error if it doesn't exist.
subscription, found := k.GetSubscription(ctx, session.SubscriptionID)
if !found {
return types.NewErrorSubscriptionNotFound(session.SubscriptionID)
}

// Convert the session's account address from Bech32 format.
// Decode the session's account address from Bech32 format.
accAddr, err := sdk.AccAddressFromBech32(session.AccAddress)
if err != nil {
return err
}

// Retrieve the allocation for the subscription and account; return an error if not found.
alloc, found := k.GetAllocation(ctx, subscription.ID, accAddr)
// Decode the session's node address from Bech32 format.
nodeAddr, err := base.NodeAddressFromBech32(session.NodeAddress)
if err != nil {
return err
}

// Remove session references for allocation, node, plan, and subscription.
k.DeleteSessionForAllocation(ctx, subscription.ID, accAddr, session.ID)
k.DeleteSessionForPlanByNode(ctx, subscription.PlanID, nodeAddr, session.ID)
k.DeleteSessionForSubscription(ctx, subscription.ID, session.ID)

return nil
}

// SessionUpdatePreHook updates session and allocation details during a session update.
func (k *Keeper) SessionUpdatePreHook(ctx sdk.Context, id uint64, currBytes sdkmath.Int) error {
k.Logger(ctx).Info("Running session update pre-hook", "id", id)

// Retrieve the session by ID; return an error if it doesn't exist.
item, found := k.GetSession(ctx, id)
if !found {
return types.NewErrorAllocationNotFound(subscription.ID, accAddr)
return types.NewErrorSessionNotFound(id)
}

// Ensure the session is of type v3.Session; do nothing if it's not.
session, ok := item.(*v3.Session)
if !ok {
return nil
}

// Calculate the total utilised bytes as the sum of download and upload bytes.
utilisedBytes := session.DownloadBytes.Add(session.UploadBytes)
// Ensure the session is not in the "Inactive" state; return an error if it is.
if session.Status.Equal(v1base.StatusInactive) {
return types.NewErrorInvalidSessionStatus(session.ID, session.Status)
}

// Decode the session's account address from Bech32 format.
accAddr, err := sdk.AccAddressFromBech32(session.AccAddress)
if err != nil {
return err
}

// Update the utilised bytes in the allocation; cap it at the granted bytes if it exceeds the limit.
alloc.UtilisedBytes = alloc.UtilisedBytes.Add(utilisedBytes)
if alloc.UtilisedBytes.GT(alloc.GrantedBytes) {
alloc.UtilisedBytes = alloc.GrantedBytes
// Fetch the allocation for the subscription and account; return an error if it doesn't exist.
alloc, found := k.GetAllocation(ctx, session.SubscriptionID, accAddr)
if !found {
return types.NewErrorAllocationNotFound(session.SubscriptionID, accAddr)
}

// Save the updated allocation in the store.
// Update allocation's utilised bytes based on the difference between current and previous session bytes.
diffBytes := currBytes.Sub(session.Bytes())
alloc.UtilisedBytes = alloc.UtilisedBytes.Add(diffBytes)

// Store the updated allocation in the keeper.
k.SetAllocation(ctx, alloc)

// Emit an event to log the allocation update.
// Emit an event logging the updated allocation details.
ctx.EventManager().EmitTypedEvent(
&v3.EventAllocate{
ID: alloc.ID,
Expand All @@ -70,16 +107,5 @@ func (k *Keeper) SessionInactivePreHook(ctx sdk.Context, id uint64) error {
},
)

// Convert the session's node address from Bech32 format.
nodeAddr, err := base.NodeAddressFromBech32(session.NodeAddress)
if err != nil {
return err
}

// Delete the session records associated with allocation, node, plan, and subscription from the store.
k.DeleteSessionForAllocation(ctx, subscription.ID, accAddr, session.ID)
k.DeleteSessionForPlanByNode(ctx, subscription.PlanID, nodeAddr, session.ID)
k.DeleteSessionForSubscription(ctx, subscription.ID, session.ID)

return nil
}

0 comments on commit 387da58

Please sign in to comment.