Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

frank/sequence & position enforcers #165

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 40 additions & 120 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from driftpy.constants.numeric_constants import (
QUOTE_SPOT_MARKET_INDEX,
)
from driftpy.enforcers.position_enforcer import PositionEnforcer
from driftpy.enforcers.sequence_enforcer import SequenceEnforcer
from driftpy.decode.utils import decode_name
from driftpy.drift_user import DriftUser
from driftpy.accounts import *
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
market_lookup_table: Optional[Pubkey] = None,
jito_params: Optional[JitoParams] = None,
enforce_tx_sequencing: bool = False,
enforce_position_sizing: bool = False,
):
"""Initializes the drift client object

Expand Down Expand Up @@ -155,29 +158,13 @@ def __init__(

self.tx_version = tx_version if tx_version is not None else Legacy

self.enforce_tx_sequencing = enforce_tx_sequencing
if self.enforce_tx_sequencing is True:
file = Path(str(driftpy.__path__[0]) + "/idl/sequence_enforcer.json")
with file.open() as f:
raw = file.read_text()
idl = Idl.from_json(raw)

provider = Provider(connection, wallet, opts)
self.sequence_enforcer_pid = (
SEQUENCER_PROGRAM_ID
if env == "mainnet"
else DEVNET_SEQUENCER_PROGRAM_ID
)
self.sequence_enforcer_program = Program(
idl,
self.sequence_enforcer_pid,
provider,
)
self.sequence_number_by_subaccount = {}
self.sequence_bump_by_subaccount = {}
self.sequence_initialized_by_subaccount = {}
self.sequence_address_by_subaccount = {}
self.resetting_sequence = False
self.sequence_enforcer = None
if enforce_tx_sequencing is True:
self.sequence_enforcer = SequenceEnforcer(self.connection, self.wallet, env)

self.position_enforcer = None
if enforce_position_sizing is True:
self.position_enforcer = PositionEnforcer(self.connection, self.wallet)

if jito_params is not None:
from driftpy.tx.jito_tx_sender import JitoTxSender
Expand All @@ -199,8 +186,8 @@ def __init__(

async def subscribe(self):
await self.account_subscriber.subscribe()
if self.enforce_tx_sequencing:
await self.load_sequence_info()
if self.sequence_enforcer:
await self.sequence_enforcer.load_sequence_info(self.sub_account_ids)
for sub_account_id in self.sub_account_ids:
await self.add_user(sub_account_id)

Expand Down Expand Up @@ -357,12 +344,18 @@ async def send_ixs(
subaccount = sequencer_subaccount or self.active_sub_account_id

if (
self.enforce_tx_sequencing
and self.sequence_initialized_by_subaccount[subaccount]
and not self.resetting_sequence
self.sequence_enforcer
and self.sequence_enforcer.get_sequence_init_for_subaccount(subaccount)
is True
and not self.sequence_enforcer.get_resetting_sequence()
):
sequence_instruction = self.get_check_and_set_sequence_number_ix(
self.sequence_number_by_subaccount[subaccount], subaccount
sequence_instruction = (
self.sequence_enforcer.get_check_and_set_sequence_number_ix(
self.sequence_enforcer.get_sequence_number_for_subaccount(
subaccount
),
subaccount,
)
)
ixs.insert(len(compute_unit_instructions), sequence_instruction)

Expand Down Expand Up @@ -886,6 +879,7 @@ async def place_perp_order(
self,
order_params: OrderParams,
sub_account_id: int = None,
expected_size: Optional[int] = None,
):
tx_sig_and_slot = await self.send_ixs(
[
Expand Down Expand Up @@ -2846,106 +2840,32 @@ def get_update_prelaunch_oracle_ix(self, market_index: int):
)

async def init_sequence(self, subaccount: int = 0) -> Signature:
if self.sequence_enforcer is None:
raise Exception("Sequence enforcer is not initialized")
try:
sig = (await self.send_ixs([self.get_sequence_init_ix(subaccount)])).tx_sig
self.sequence_initialized_by_subaccount[subaccount] = True
sig = (
await self.send_ixs(
[self.sequence_enforcer.get_sequence_init_ix(subaccount)]
)
).tx_sig
self.sequence_enforcer.set_sequence_init_for_subaccount(subaccount, True)
return sig
except Exception as e:
print(f"WARNING: failed to initialize sequence: {e}")

def get_sequence_init_ix(self, subaccount: int = 0) -> Instruction:
if self.enforce_tx_sequencing is False:
raise ValueError("tx sequencing is disabled")
return self.sequence_enforcer_program.instruction["initialize"](
self.sequence_bump_by_subaccount[subaccount],
str(subaccount),
ctx=Context(
accounts={
"sequence_account": self.sequence_address_by_subaccount[subaccount],
"authority": self.wallet.payer.pubkey(),
"system_program": ID,
}
),
)

async def reset_sequence_number(
self, sequence_number: int = 0, subaccount: int = 0
) -> Signature:
if self.sequence_enforcer is None:
raise Exception("Sequence enforcer is not initialized")
try:
ix = self.get_reset_sequence_number_ix(sequence_number)
self.resetting_sequence = True
ix = self.sequence_enforcer.get_reset_sequence_number_ix(sequence_number)
self.sequence_enforcer.set_resetting_sequence(True)
sig = (await self.send_ixs(ix)).tx_sig
self.resetting_sequence = False
self.sequence_number_by_subaccount[subaccount] = sequence_number
self.sequence_enforcer.set_resetting_sequence(False)
self.sequence_enforcer.set_sequence_number_for_subaccount(
subaccount, sequence_number
)
return sig
except Exception as e:
print(f"WARNING: failed to reset sequence number: {e}")

def get_reset_sequence_number_ix(
self, sequence_number: int, subaccount: int = 0
) -> Instruction:
if self.enforce_tx_sequencing is False:
raise ValueError("tx sequencing is disabled")
return self.sequence_enforcer_program.instruction["reset_sequence_number"](
sequence_number,
ctx=Context(
accounts={
"sequence_account": self.sequence_address_by_subaccount[subaccount],
"authority": self.wallet.payer.pubkey(),
}
),
)

def get_check_and_set_sequence_number_ix(
self, sequence_number: Optional[int] = None, subaccount: int = 0
):
if self.enforce_tx_sequencing is False:
raise ValueError("tx sequencing is disabled")
sequence_number = (
sequence_number or self.sequence_number_by_subaccount[subaccount]
)

if (
sequence_number < self.sequence_number_by_subaccount[subaccount] - 1
): # we increment after creating the ix, so we check - 1
print(
f"WARNING: sequence number {sequence_number} < last used {self.sequence_number_by_subaccount[subaccount] - 1}"
)

ix = self.sequence_enforcer_program.instruction[
"check_and_set_sequence_number"
](
sequence_number,
ctx=Context(
accounts={
"sequence_account": self.sequence_address_by_subaccount[subaccount],
"authority": self.wallet.payer.pubkey(),
}
),
)

self.sequence_number_by_subaccount[subaccount] += 1
return ix

async def load_sequence_info(self):
for subaccount in self.sub_account_ids:
address, bump = get_sequencer_public_key_and_bump(
self.sequence_enforcer_pid, self.wallet.payer.pubkey(), subaccount
)
try:
sequence_account_raw = await self.sequence_enforcer_program.account[
"SequenceAccount"
].fetch(address)
except anchorpy.error.AccountDoesNotExistError as e:
self.sequence_address_by_subaccount[subaccount] = address
self.sequence_number_by_subaccount[subaccount] = 1
self.sequence_bump_by_subaccount[subaccount] = bump
self.sequence_initialized_by_subaccount[subaccount] = False
continue
sequence_account = cast(SequenceAccount, sequence_account_raw)
self.sequence_number_by_subaccount[subaccount] = (
sequence_account.sequence_num + 1
)
self.sequence_bump_by_subaccount[subaccount] = bump
self.sequence_initialized_by_subaccount[subaccount] = True
self.sequence_address_by_subaccount[subaccount] = address
68 changes: 68 additions & 0 deletions src/driftpy/enforcers/position_enforcer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Union
from driftpy.types import (
OrderParams,
SpotPosition,
PerpPosition,
MarketType,
UserAccount,
is_variant,
)


class PositionEnforcer:
def __init__(self):
pass

def set_and_check_order_params(
self, expected_size: int, order_params: OrderParams, user: UserAccount
) -> OrderParams:
size_adjustment = self._get_size_adjustment(
expected_size, order_params.market_index, order_params.market_type, user
)
order_params.base_asset_amount = max(
order_params.base_asset_amount + size_adjustment, 0
)
if order_params.base_asset_amount == 0:
print("WARNING: PositionEnforcer has reduced order size to ZERO.")
return order_params

def _get_size_adjustment(
self,
expected_size: int,
market_index: int,
market_type: MarketType,
user: UserAccount,
) -> int:
position: Union[SpotPosition, PerpPosition]
if is_variant(market_type, "Perp"):
position = next(
(
pos
for pos in user.perp_positions
if pos.market_index == market_index
),
None,
)
if position is None:
raise Exception(
f"Position market_index: {market_index} market_type: {market_type} not found"
)

difference = position.base_asset_amount - expected_size
return difference * -1 # positive if too short, negative if too long
else:
position = next(
(
pos
for pos in user.spot_positions
if pos.market_index == market_index
),
None,
)
if position is None:
raise Exception(
f"Position market_index: {market_index} market_type: {market_type} not found"
)

difference = position.scaled_balance - expected_size
return difference * -1 # positive if too short, negative if too long
Loading