Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for diff oracle source #211

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ defaults:
working-directory: .

jobs:
black:
ruff-format:
runs-on: ubicloud
steps:
- uses: actions/checkout@v3
Expand All @@ -23,9 +23,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install black
- name: Run Black
run: black --check .
pip install ruff
- name: Run ruff format
run: ruff format --check .

tests:
runs-on: ubicloud
Expand Down Expand Up @@ -54,7 +54,7 @@ jobs:

bump-version:
runs-on: ubicloud
needs: [black, tests]
needs: [ruff-format, tests]
if: github.event_name == 'push' && github.ref == 'refs/heads/master'
steps:
- uses: actions/checkout@v3
Expand Down
36 changes: 25 additions & 11 deletions examples/fetch_all_markets.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import asyncio
import os

from anchorpy import Provider
from anchorpy import Wallet
from driftpy.drift_client import AccountSubscriptionConfig
from driftpy.drift_client import DriftClient
from anchorpy.provider import Provider, Wallet
from solana.rpc.async_api import AsyncClient
from solders.keypair import Keypair

from driftpy.drift_client import AccountSubscriptionConfig, DriftClient


async def get_all_market_names():
env = "mainnet-beta" # 'devnet'
rpc = os.environ.get("MAINNET_RPC_ENDPOINT")
kp = Keypair() # random wallet
wallet = Wallet(kp)
Expand All @@ -19,10 +17,10 @@ async def get_all_market_names():
drift_client = DriftClient(
provider.connection,
provider.wallet,
env.split("-")[0],
account_subscription=AccountSubscriptionConfig("cached"),
"mainnet",
account_subscription=AccountSubscriptionConfig("websocket"),
)

await drift_client.subscribe()
all_perps_markets = await drift_client.program.account["PerpMarket"].all()
sorted_all_perps_markets = sorted(
all_perps_markets, key=lambda x: x.account.market_index
Expand All @@ -46,10 +44,26 @@ async def get_all_market_names():
print(market)

result = result_perp + result_spot[1:]

print("Here are some prices:")
print(drift_client.get_oracle_price_data_for_perp_market(0))
print(drift_client.get_oracle_price_data_for_spot_market(0))
await drift_client.unsubscribe()
return result


if __name__ == "__main__":
loop = asyncio.new_event_loop()
answer = loop.run_until_complete(get_all_market_names())
print(answer)
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
answer = loop.run_until_complete(get_all_market_names())
print(answer)
finally:
# Clean up pending tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()

# Run loop until tasks complete/cancel
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
781 changes: 415 additions & 366 deletions poetry.lock

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ urllib3 = "1.26.13"
websockets = "10.4"
yarl = "1.8.2"
zstandard = "0.18.0"
jinja2 = "^3.1.2"
jinja2 = "^3.0.2"
mypy = "^1.7.0"
deprecated = "^1.2.14"
events = "^0.5"
Expand All @@ -91,12 +91,13 @@ pytest = "^7.2.0"
flake8 = "6.0.0"
black = "24.4.2"
pytest-asyncio = "^0.21.0"
mkdocs = "^1.3.0"
mkdocstrings = "^0.17.0"
mkdocs = "^1.6.0"
mkdocstrings = "^0.27.0"
mkdocs-material = "^8.1.8"
bump2version = "^1.0.1"
autopep8 = "^2.0.4"
mypy = "^1.7.0"
icecream = "^2.1.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.4"
Expand All @@ -110,7 +111,12 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
asyncio_mode = "strict"


[tool.ruff]
exclude = [".git", "__pycache__", "docs/source/conf.py", "old", "build", "dist"]
[tool.ruff.pycodestyle]

[tool.ruff.lint.pycodestyle]
max-line-length = 88

[tool.pyright]
reportMissingModuleSource = false
58 changes: 34 additions & 24 deletions src/driftpy/account_subscription_config.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from typing import Literal, Optional
from typing import Literal, Optional, cast

from solders.pubkey import Pubkey
from anchorpy.program.core import Program
from solana.rpc.commitment import Commitment
from solders.pubkey import Pubkey

from driftpy.accounts.bulk_account_loader import BulkAccountLoader
from driftpy.accounts.cache import (
CachedDriftClientAccountSubscriber,
CachedUserAccountSubscriber,
)
from driftpy.accounts.demo import (
DemoDriftClientAccountSubscriber,
DemoUserAccountSubscriber,
)
from driftpy.accounts.polling import (
PollingDriftClientAccountSubscriber,
PollingUserAccountSubscriber,
)
from anchorpy import Program

from driftpy.accounts.types import FullOracleWrapper
from driftpy.accounts.ws import (
WebsocketDriftClientAccountSubscriber,
WebsocketUserAccountSubscriber,
)
from driftpy.accounts.demo import (
DemoDriftClientAccountSubscriber,
DemoUserAccountSubscriber,
)
from driftpy.types import OracleInfo


Expand All @@ -32,31 +32,33 @@ def default():

def __init__(
self,
type: Literal["polling", "websocket", "cached", "demo"],
account_subscription_type: Literal["polling", "websocket", "cached", "demo"],
bulk_account_loader: Optional[BulkAccountLoader] = None,
commitment: Commitment = None,
commitment: Commitment = Commitment("confirmed"),
):
self.type = type
self.type = account_subscription_type
self.commitment = commitment
self.bulk_account_loader = None

if self.type == "polling":
if bulk_account_loader is None:
raise ValueError("polling subscription requires bulk account loader")
if self.type != "polling":
return

if commitment is not None and commitment != bulk_account_loader.commitment:
raise ValueError(
f"bulk account loader commitment {bulk_account_loader.commitment} != commitment passed {commitment}"
)
if bulk_account_loader is None:
raise ValueError("polling subscription requires bulk account loader")

self.bulk_account_loader = bulk_account_loader
if commitment != bulk_account_loader.commitment:
raise ValueError(
f"bulk account loader commitment {bulk_account_loader.commitment} != commitment passed {commitment}"
)

self.commitment = commitment
self.bulk_account_loader = bulk_account_loader

def get_drift_client_subscriber(
self,
program: Program,
perp_market_indexes: list[int] = None,
spot_market_indexes: list[int] = None,
oracle_infos: list[OracleInfo] = None,
perp_market_indexes: list[int] | None = None,
spot_market_indexes: list[int] | None = None,
oracle_infos: list[OracleInfo] | None = None,
):
should_find_all_markets_and_oracles = (
perp_market_indexes is None
Expand All @@ -69,6 +71,10 @@ def get_drift_client_subscriber(

match self.type:
case "polling":
if self.bulk_account_loader is None:
raise ValueError(
"polling subscription requires bulk account loader"
)
return PollingDriftClientAccountSubscriber(
program,
self.bulk_account_loader,
Expand All @@ -82,7 +88,7 @@ def get_drift_client_subscriber(
program,
perp_market_indexes,
spot_market_indexes,
oracle_infos,
cast(list[FullOracleWrapper], oracle_infos),
should_find_all_markets_and_oracles,
self.commitment,
)
Expand Down Expand Up @@ -115,6 +121,10 @@ def get_drift_client_subscriber(
def get_user_client_subscriber(self, program: Program, user_pubkey: Pubkey):
match self.type:
case "polling":
if self.bulk_account_loader is None:
raise ValueError(
"polling subscription requires bulk account loader"
)
return PollingUserAccountSubscriber(
user_pubkey, program, self.bulk_account_loader
)
Expand Down
Loading
Loading