Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

frank/update driftpy #34

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ python = "^3.10"
python-dotenv = "^1.0.0"
solana = "^0.30.1"
anchorpy = "^0.17.1"
driftpy = "^0.7.19"
driftpy = "^0.7.20"

[build-system]
requires = ["poetry-core"]
Expand Down
42 changes: 21 additions & 21 deletions python/sdk/jit_proxy/jit_proxy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Optional, cast

from borsh_construct.enum import _rust_enum
from sumtypes import constructor
from sumtypes import constructor # type: ignore

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from anchorpy import Context, Program

Expand Down Expand Up @@ -74,7 +74,7 @@ async def jit(self, params: JitIxParams):
await self.init()

sub_account_id = self.drift_client.get_sub_account_id_for_ix(
params.sub_account_id
params.sub_account_id # type: ignore
)

order = next(
Expand All @@ -90,11 +90,11 @@ async def jit(self, params: JitIxParams):
params.taker,
self.drift_client.get_user_account(sub_account_id),
],
writable_spot_market_indexes=[order.market_index, QUOTE_SPOT_MARKET_INDEX]
if is_variant(order.market_type, "Spot")
writable_spot_market_indexes=[order.market_index, QUOTE_SPOT_MARKET_INDEX] # type: ignore
if is_variant(order.market_type, "Spot") # type: ignore
else [],
writable_perp_market_indexes=[order.market_index]
if is_variant(order.market_type, "Perp")
writable_perp_market_indexes=[order.market_index] # type: ignore
if is_variant(order.market_type, "Perp") # type: ignore
else [],
)

Expand All @@ -114,35 +114,35 @@ async def jit(self, params: JitIxParams):
)
)

if is_variant(order.market_type, "Spot"):
if is_variant(order.market_type, "Spot"): # type: ignore
remaining_accounts.append(
AccountMeta(
pubkey=self.drift_client.get_spot_market_account(
order.market_index
pubkey=self.drift_client.get_spot_market_account( # type: ignore
order.market_index # type: ignore
).vault,
is_writable=False,
is_signer=False,
)
)
remaining_accounts.append(
AccountMeta(
pubkey=self.drift_client.get_quote_spot_market_account().vault,
pubkey=self.drift_client.get_quote_spot_market_account().vault, # type: ignore
is_writable=False,
is_signer=False,
)
)

jit_params = self.program.type["JitParams"](
jit_params = self.program.type["JitParams"]( # type: ignore
taker_order_id=params.taker_order_id,
max_position=cast(int, params.max_position),
min_position=cast(int, params.min_position),
bid=cast(int, params.bid),
ask=cast(int, params.ask),
price_type=self.get_price_type(params.price_type),
price_type=self.get_price_type(params.price_type), # type: ignore
post_only=self.get_post_only(params.post_only),
)

ix = self.program.instruction["jit"](
ix = self.program.instruction["jit"]( # type: ignore
jit_params,
ctx=Context(
accounts={
Expand All @@ -156,7 +156,7 @@ async def jit(self, params: JitIxParams):
"authority": self.drift_client.wallet.public_key,
"drift_program": self.drift_client.program_id,
},
signers={self.drift_client.wallet},
signers={self.drift_client.wallet}, # type: ignore
remaining_accounts=remaining_accounts,
),
)
Expand All @@ -167,16 +167,16 @@ async def jit(self, params: JitIxParams):

def get_price_type(self, price_type: PriceType):
if is_variant(price_type, "Oracle"):
return self.program.type["PriceType"].Oracle()
return self.program.type["PriceType"].Oracle() # type: ignore
elif is_variant(price_type, "Limit"):
return self.program.type["PriceType"].Limit()
else:
return self.program.type["PriceType"].Limit() # type: ignore
else:
raise ValueError(f"Unknown price type: {str(price_type)}")

def get_post_only(self, post_only: PostOnlyParams):
if is_variant(post_only, "MustPostOnly"):
return self.program.type["PostOnlyParam"].MustPostOnly()
return self.program.type["PostOnlyParam"].MustPostOnly() # type: ignore
elif is_variant(post_only, "TryPostOnly"):
return self.program.type["PostOnlyParam"].TryPostOnly()
return self.program.type["PostOnlyParam"].TryPostOnly() # type: ignore
elif is_variant(post_only, "Slide"):
return self.program.type["PostOnlyParam"].Slide()
return self.program.type["PostOnlyParam"].Slide() # type: ignore
10 changes: 5 additions & 5 deletions python/sdk/jit_proxy/jitter/base_jitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from driftpy.types import is_variant, UserAccount, Order, UserStatsAccount, ReferrerInfo
from driftpy.drift_client import DriftClient
Expand Down Expand Up @@ -72,7 +72,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i
taker_key_str = str(taker_key)

taker_stats_key = get_user_stats_account_public_key(
self.drift_client.program_id, taker.authority
self.drift_client.program_id, taker.authority # type: ignore
)

self.logger.info(f"Taker: {taker.authority}")
Expand Down Expand Up @@ -110,7 +110,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i

if (
order.base_asset_amount - order.base_asset_amount_filled
<= perp_market_account.amm.min_order_size
<= perp_market_account.amm.min_order_size # type: ignore
):
self.logger.info("Order filled within min_order_size")
self.logger.info("----------------------------")
Expand Down Expand Up @@ -138,7 +138,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i

if (
order.base_asset_amount - order.base_asset_amount_filled
<= spot_market_account.min_order_size
<= spot_market_account.min_order_size # type: ignore
):
self.logger.info("Order filled within min_order_size")
self.logger.info("----------------------------")
Expand Down Expand Up @@ -177,7 +177,7 @@ async def create_try_fill(
order: Order,
order_sig: str,
):
future = asyncio.Future()
future = asyncio.Future() # type: ignore
future.set_result(None)
return future

Expand Down
2 changes: 1 addition & 1 deletion python/sdk/jit_proxy/jitter/jitter_shotgun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Coroutine

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from driftpy.drift_client import DriftClient
from driftpy.auction_subscriber.auction_subscriber import AuctionSubscriber
Expand Down
26 changes: 13 additions & 13 deletions python/sdk/jit_proxy/jitter/jitter_sniper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Any, Coroutine

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from driftpy.drift_client import DriftClient
from driftpy.auction_subscriber.auction_subscriber import AuctionSubscriber
Expand Down Expand Up @@ -243,7 +243,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:

auction_start_price = convert_to_number(
get_auction_price_for_oracle_offset_auction(
order, order.slot, oracle_price.price
order, order.slot, oracle_price.price # type: ignore
)
if is_variant(order.order_type, "Oracle")
else order.auction_start_price,
Expand All @@ -252,23 +252,23 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:

auction_end_price = convert_to_number(
get_auction_price_for_oracle_offset_auction(
order, order.slot + order.auction_duration - 1, oracle_price.price
order, order.slot + order.auction_duration - 1, oracle_price.price # type: ignore
)
if is_variant(order.order_type, "Oracle")
else order.auction_end_price,
PRICE_PRECISION,
)

bid = (
convert_to_number(oracle_price.price + params.bid, PRICE_PRECISION)
if is_variant(params.price_type, "Oracle")
else convert_to_number(params.bid, PRICE_PRECISION)
convert_to_number(oracle_price.price + params.bid, PRICE_PRECISION) # type: ignore
if is_variant(params.price_type, "Oracle") # type: ignore
else convert_to_number(params.bid, PRICE_PRECISION) # type: ignore
)

ask = (
convert_to_number(oracle_price.price + params.ask, PRICE_PRECISION)
if is_variant(params.price_type, "Oracle")
else convert_to_number(params.ask, PRICE_PRECISION)
convert_to_number(oracle_price.price + params.ask, PRICE_PRECISION) # type: ignore
if is_variant(params.price_type, "Oracle") # type: ignore
else convert_to_number(params.ask, PRICE_PRECISION) # type: ignore
)

slots_until_cross = 0
Expand All @@ -282,7 +282,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:
if (
convert_to_number(
get_auction_price(
order, order.slot + slots_until_cross, oracle_price.price
order, order.slot + slots_until_cross, oracle_price.price # type: ignore
),
PRICE_PRECISION,
)
Expand All @@ -294,7 +294,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:
if (
convert_to_number(
get_auction_price(
order, order.slot + slots_until_cross, oracle_price.price
order, order.slot + slots_until_cross, oracle_price.price # type: ignore
),
PRICE_PRECISION,
)
Expand All @@ -312,12 +312,12 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:
auction_start_price,
auction_end_price,
step_size,
oracle_price,
oracle_price, # type: ignore
)

async def wait_for_slot_or_cross_or_expiry(
self, target_slot: int, order: Order, initial_details: AuctionAndOrderDetails
) -> (int, AuctionAndOrderDetails):
) -> (int, AuctionAndOrderDetails): # type: ignore
auction_end_slot = order.auction_duration + order.slot
current_details: AuctionAndOrderDetails = initial_details
will_cross = initial_details.will_cross
Expand Down