Skip to content

Commit

Permalink
[Community] handle index fetch and profile update
Browse files Browse the repository at this point in the history
  • Loading branch information
GuillaumeDSM committed Jul 3, 2024
1 parent 68cf640 commit e2d537f
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 15 deletions.
35 changes: 32 additions & 3 deletions octobot/community/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, config=None, backend_url=None, backend_key=None, use_as_singl

self.initialized_event = None
self._login_completed = None
self._fetched_private_data = None
self._startup_info = None

self._fetch_account_task = None
Expand Down Expand Up @@ -119,8 +120,10 @@ async def get_strategy(self, strategy_id, reload=False) -> strategy_data.Strateg
await self.init_public_data(reset=reload)
return self.public_data.get_strategy(strategy_id)

async def get_strategy_profile_data(self, strategy_id: str) -> commons_profiles.ProfileData:
return await self.supabase_client.fetch_product_config(strategy_id)
async def get_strategy_profile_data(
self, strategy_id: str, product_slug: str = None
) -> commons_profiles.ProfileData:
return await self.supabase_client.fetch_product_config(strategy_id, product_slug=product_slug)

def is_feed_connected(self):
return self._community_feed is not None and self._community_feed.is_connected_to_remote_feed()
Expand Down Expand Up @@ -217,6 +220,11 @@ async def _ensure_async_loop(self):
self._login_completed = asyncio.Event()
if should_set:
self._login_completed.set()
if self._fetched_private_data is not None:
should_set = self._fetched_private_data.is_set()
self._fetched_private_data = asyncio.Event()
if should_set:
self._fetched_private_data.set()
# changed event loop: restart client
await self.supabase_client.close()
self.user_account.flush()
Expand All @@ -230,6 +238,8 @@ def is_initialized(self):
return self.initialized_event is not None and self.initialized_event.is_set()

def init_account(self, fetch_private_data):
if fetch_private_data and self._fetched_private_data is None:
self._fetched_private_data = asyncio.Event()
self._fetch_account_task = asyncio.create_task(self._initialize_account(fetch_private_data=fetch_private_data))

async def async_init_account(self, fetch_private_data):
Expand Down Expand Up @@ -271,6 +281,11 @@ async def wait_for_login_if_processing(self):
# ensure login details have been fetched
await asyncio.wait_for(self._login_completed.wait(), self.LOGIN_TIMEOUT)

async def wait_for_private_data_fetch_if_processing(self):
if self._fetched_private_data is not None and not self._fetched_private_data.is_set():
# ensure login details have been fetched
await asyncio.wait_for(self._fetched_private_data.wait(), constants.COMMUNITY_FETCH_TIMEOUT)

def can_authenticate(self):
return bool(
identifiers_provider.IdentifiersProvider.BACKEND_URL
Expand Down Expand Up @@ -381,6 +396,9 @@ async def get_current_bot_products_subscription(self) -> dict:
def get_owned_packages(self) -> list[str]:
return self.user_account.owned_packages

def has_open_source_package(self) -> bool:
return bool(self.get_owned_packages())

def has_owned_packages_to_install(self) -> bool:
return self.user_account.has_pending_packages_to_install

Expand Down Expand Up @@ -490,7 +508,13 @@ async def _init_community_data(self, fetch_private_data):

async def init_public_data(self, reset=False):
if reset or not self.public_data.products.fetched:
self.public_data.set_products(await self.supabase_client.fetch_products())
await self._refresh_products()

async def _refresh_products(self):
category_types = ["profile"]
if self.has_open_source_package():
category_types.append("index")
self.public_data.set_products(await self.supabase_client.fetch_products(category_types))

async def fetch_private_data(self, reset=False):
try:
Expand All @@ -516,6 +540,11 @@ async def fetch_private_data(self, reset=False):
self.save_mqtt_device_uuid(fetched_mqtt_uuid)
except Exception as err:
self.logger.exception(err, True, f"Error when fetching package urls: {err}")
finally:
self._fetched_private_data.set()
if self.has_open_source_package():
# fetch indexes as well
await self._refresh_products()

async def _fetch_package_urls(self, mqtt_uuid: typing.Optional[str]) -> (list[str], str):
resp = await self.supabase_client.http_get(
Expand Down
4 changes: 2 additions & 2 deletions octobot/community/models/community_public_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import octobot.community.models.strategy_data as strategy_data


STRATEGY_CATEGORY_TYPE = "profile"
STRATEGY_CATEGORY_TYPES = ["profile", "index"]


class CommunityPublicData:
Expand All @@ -36,7 +36,7 @@ def get_strategies(self) -> list[strategy_data.StrategyData]:
return [
strategy_data.StrategyData.from_dict(strategy_dict)
for strategy_dict in self.products.value.values()
if self._get_category_type(strategy_dict) == STRATEGY_CATEGORY_TYPE
if self._get_category_type(strategy_dict) in STRATEGY_CATEGORY_TYPES
]

def _get_category_type(self, product: dict):
Expand Down
41 changes: 40 additions & 1 deletion octobot/community/models/strategy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
import octobot_commons.enums as commons_enums


CATEGORY_NAME_TRANSLATIONS_BY_SLUG = {
"coingecko-index": {"en": "Crypto Basket"}
}
FORCED_URL_PATH_BY_SLUG = {
"coingecko-index": "features/crypto-basket",
}
DEFAULT_LOGO_NAME_BY_SLUG = {
"coingecko-index": "crypto-basket.png",
}
AUTO_UPDATED_CATEGORIES = ["coingecko-index"]
DEFAULT_LOGO_NAME = "default_strategy.png"
EXTENSION_CATEGORIES = ["coingecko-index"]


@dataclasses.dataclass
class CategoryData(commons_dataclasses.FlexibleDataclass):
slug: str = ""
Expand All @@ -33,8 +47,18 @@ def get_url(self) -> str:
if external_links:
if blog_slug := external_links.get("blog"):
return f"{identifiers_provider.IdentifiersProvider.COMMUNITY_LANDING_URL}/en/blog/{blog_slug}"
if features_slug := external_links.get("features"):
return f"{identifiers_provider.IdentifiersProvider.COMMUNITY_LANDING_URL}/features/{features_slug}"
return ""

def get_default_logo_url(self) -> str:
return DEFAULT_LOGO_NAME_BY_SLUG.get(self.slug, DEFAULT_LOGO_NAME)

def get_name(self, locale, default_locale=constants.DEFAULT_LOCALE):
return CATEGORY_NAME_TRANSLATIONS_BY_SLUG.get(self.slug, self.name_translations).get(locale, default_locale)

def is_auto_updated(self) -> bool:
return self.slug in AUTO_UPDATED_CATEGORIES

@dataclasses.dataclass
class ResultsData(commons_dataclasses.FlexibleDataclass):
Expand Down Expand Up @@ -76,7 +100,11 @@ def get_name(self, locale, default_locale=constants.DEFAULT_LOCALE):
return self.content["name_translations"].get(locale, default_locale)

def get_url(self) -> str:
return f"{identifiers_provider.IdentifiersProvider.COMMUNITY_URL}/strategies/{self.slug}"
path = FORCED_URL_PATH_BY_SLUG.get(self.category.slug, f"strategies/{self.slug}")
return f"{identifiers_provider.IdentifiersProvider.COMMUNITY_LANDING_URL}/{path}"

def get_product_url(self) -> str:
return f"{identifiers_provider.IdentifiersProvider.COMMUNITY_LANDING_URL}/strategies/{self.slug}"

def get_risk(self) -> commons_enums.ProfileRisk:
risk = self.attributes['risk'].upper()
Expand All @@ -86,3 +114,14 @@ def get_risk(self) -> commons_enums.ProfileRisk:
return commons_enums.ProfileRisk[risk]
except KeyError:
return commons_enums.ProfileRisk.MODERATE

def get_logo_url(self, prefix: str) -> str:
if self.logo_url:
return self.logo_url
return f"{prefix}{self.category.get_default_logo_url()}"

def is_auto_updated(self) -> bool:
return self.category.is_auto_updated()

def is_extension_only(self) -> bool:
return self.category.slug in EXTENSION_CATEGORIES
20 changes: 11 additions & 9 deletions octobot/community/supabase_backend/community_supabase_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ async def fetch_startup_info(self, bot_id) -> dict:
(await self.postgres_functions().invoke("get_startup_info", {"body": {"bot_id": bot_id}}))["data"]
)[0]

async def fetch_products(self) -> list:
async def fetch_products(self, category_types: list[str]) -> list:
return (
await self.table("products").select(
"*,"
Expand All @@ -243,10 +243,9 @@ async def fetch_products(self) -> list:
" profitability,"
" reference_market_profitability"
")"
).match({
enums.ProductKeys.VISIBILITY.value: "public",
"category.type": "profile",
})
).eq(
enums.ProductKeys.VISIBILITY.value, "public"
).in_("category.type", category_types)
.execute()
).data

Expand Down Expand Up @@ -365,14 +364,17 @@ async def fetch_exchanges(self, exchange_ids: list) -> list:
f"{enums.ExchangeKeys.INTERNAL_NAME.value}"
).in_(enums.ExchangeKeys.ID.value, exchange_ids).execute()).data

async def fetch_product_config(self, product_id: str) -> commons_profiles.ProfileData:
if not product_id:
async def fetch_product_config(self, product_id: str, product_slug: str = None) -> commons_profiles.ProfileData:
if not product_id and not product_slug:
raise errors.MissingProductConfigError(f"product_id is '{product_id}'")
try:
product = (await self.table("products").select(
query = self.table("products").select(
"slug, "
"product_config:product_configs!current_config_id(config, version)"
).eq(enums.ProductKeys.ID.value, product_id).execute()).data[0]
)
query = query.eq(enums.ProductKeys.SLUG.value, product_slug) if product_slug \
else query.eq(enums.ProductKeys.ID.value, product_id)
product = (await query.execute()).data[0]
except IndexError:
raise errors.MissingProductConfigError(f"Missing product with id '{product_id}'")
profile_data = commons_profiles.ProfileData.from_dict(
Expand Down
1 change: 1 addition & 0 deletions octobot/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
# Profiles to force select at startup, identified by profile id, download url or name
FORCED_PROFILE = os.getenv("FORCED_PROFILE", None)
RUN_IN_MAIN_THREAD = os.getenv("RUN_IN_MAIN_THREAD", False)
PROFILE_UPDATE_RESTART_MIN = float(os.getenv("PROFILE_UPDATE_RESTART_MIN", 5))

OCTOBOT_BINARY_PROJECT_NAME = "OctoBot-Binary"

Expand Down
25 changes: 25 additions & 0 deletions octobot/octobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import octobot_commons.os_clock_sync as os_clock_sync
import octobot_commons.system_resources_watcher as system_resources_watcher
import octobot_commons.aiohttp_util as aiohttp_util
import octobot_commons.profiles as profiles

import octobot_services.api as service_api
import octobot_trading.api as trading_api
Expand Down Expand Up @@ -157,6 +158,7 @@ async def _post_initialize(self):

self.automation = automation.Automation(self.bot_id, self.tentacles_setup_config)
self._init_metadata_run_task = asyncio.create_task(self._store_run_metadata_when_available())
await self._init_profile_synchronizer()

async def _wait_for_run_data_init(self, exchange_managers, timeout):
for exchange_manager in exchange_managers:
Expand Down Expand Up @@ -209,6 +211,7 @@ async def stop(self):
await self.exchange_producer.stop()
await self.community_auth.stop()
await self.service_feed_producer.stop()
await profiles.stop_profile_synchronizer()
await os_clock_sync.stop_clock_synchronizer()
await system_resources_watcher.stop_system_resources_watcher()
await service_api.stop_services()
Expand All @@ -233,6 +236,28 @@ async def _ensure_clock(self):
if trading_api.is_trader_enabled_in_config(self.config) and constants.ENABLE_CLOCK_SYNCH:
await os_clock_sync.start_clock_synchronizer()

async def _init_profile_synchronizer(self):
await profiles.start_profile_synchronizer(
self.get_edited_config(constants.CONFIG_KEY, dict_only=False),
self._on_profile_update
)

async def delayed_restart(self, delay):
await asyncio.sleep(delay)
self.octobot_api.restart_bot()

async def _on_profile_update(self, profile_name: str):
await service_api.send_notification(
service_api.create_notification(
f"{constants.PROJECT_NAME} will restart in {constants.PROFILE_UPDATE_RESTART_MIN} minutes "
f"to apply the {profile_name} profile update.",
markdown_format=commons_enums.MarkdownFormat.ITALIC
)
)
asyncio.create_task(self.delayed_restart(
constants.PROFILE_UPDATE_RESTART_MIN * commons_constants.MINUTE_TO_SECONDS
))

async def _ensure_watchers(self):
if constants.ENABLE_SYSTEM_WATCHER:
await system_resources_watcher.start_system_resources_watcher(
Expand Down

0 comments on commit e2d537f

Please sign in to comment.