Skip to content

Commit

Permalink
Merge branch 'develop' into feature/docs-cli
Browse files Browse the repository at this point in the history
  • Loading branch information
hjoaquim authored May 14, 2024
2 parents ecc6b3b + 29dfc7b commit a3d51a0
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 56 deletions.
4 changes: 3 additions & 1 deletion cli/openbb_cli/argparse_translator/obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ def _contains_obbject(uuid: str, obbjects: List[OBBject]) -> bool:
"""Check if obbject with uuid is in the registry."""
return any(obbject.id == uuid for obbject in obbjects)

def register(self, obbject: OBBject):
def register(self, obbject: OBBject) -> bool:
"""Designed to add an OBBject instance to the registry."""
if isinstance(obbject, OBBject) and not self._contains_obbject(
obbject.id, self._obbjects
):
self._obbjects.append(obbject)
return True
return False

def get(self, idx: int) -> OBBject:
"""Return the obbject at index idx."""
Expand Down
94 changes: 51 additions & 43 deletions cli/openbb_cli/controllers/base_platform_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,54 +153,65 @@ def method(self, other_args: List[str], translator=translator):
ns_parser = self._intersect_data_processing_commands(ns_parser)

obbject = translator.execute_func(parsed_args=ns_parser)
df: pd.DataFrame = None
df: pd.DataFrame = pd.DataFrame()
fig: OpenBBFigure = None
title = f"{self.PATH}{translator.func.__name__}"

if obbject:
if session.max_obbjects_exceeded():
session.obbject_registry.remove()

# use the obbject to store the command so we can display it later on results
obbject.extra["command"] = f"{title} {' '.join(other_args)}"

session.obbject_registry.register(obbject)
# we need to force to re-link so that the new obbject
# is immediately available for data processing commands
self._link_obbject_to_data_processing_commands()
# also update the completer
self.update_completer(self.choices_default)

if session.settings.SHOW_MSG_OBBJECT_REGISTRY and isinstance(
obbject, OBBject
):
session.console.print("Added OBBject to registry.")

if hasattr(ns_parser, "chart") and ns_parser.chart:
obbject.show()
fig = obbject.chart.fig
if hasattr(obbject, "to_dataframe"):

if isinstance(obbject, OBBject) and obbject.results:
if session.max_obbjects_exceeded():
session.obbject_registry.remove()
session.console.print(
"[yellow]Maximum number of OBBjects reached. The oldest entry was removed.[yellow]"
)

# use the obbject to store the command so we can display it later on results
obbject.extra["command"] = f"{title} {' '.join(other_args)}"

register_result = session.obbject_registry.register(obbject)

# we need to force to re-link so that the new obbject
# is immediately available for data processing commands
self._link_obbject_to_data_processing_commands()
# also update the completer
self.update_completer(self.choices_default)

if (
session.settings.SHOW_MSG_OBBJECT_REGISTRY
and register_result
):
session.console.print("Added OBBject to registry.")

# making the dataframe available
# either for printing or exporting (or both)
df = obbject.to_dataframe()
elif isinstance(obbject, dict):
df = pd.DataFrame.from_dict(obbject, orient="index")
else:
df = None

elif hasattr(obbject, "to_dataframe"):
df = obbject.to_dataframe()
if isinstance(df.columns, pd.RangeIndex):
df.columns = [str(i) for i in df.columns]
print_rich_table(df=df, show_index=True, title=title)
if hasattr(ns_parser, "chart") and ns_parser.chart:
obbject.show()
fig = obbject.chart.fig if obbject.chart else None
else:
if isinstance(df.columns, pd.RangeIndex):
df.columns = [str(i) for i in df.columns]

print_rich_table(df=df, show_index=True, title=title)

elif isinstance(obbject, dict):
df = pd.DataFrame.from_dict(obbject, orient="index")
print_rich_table(df=df, show_index=True, title=title)
elif isinstance(obbject, dict):
df = pd.DataFrame.from_dict(obbject, orient="columns")
print_rich_table(df=df, show_index=True, title=title)

elif obbject:
session.console.print(obbject)
elif not isinstance(obbject, OBBject):
session.console.print(obbject)

if hasattr(ns_parser, "export") and ns_parser.export:
if (
hasattr(ns_parser, "export")
and ns_parser.export
and not df.empty
):
sheet_name = getattr(ns_parser, "sheet_name", None)
if sheet_name and isinstance(sheet_name, list):
sheet_name = sheet_name[0]

export_data(
export_type=",".join(ns_parser.export),
dir_path=os.path.dirname(os.path.abspath(__file__)),
Expand All @@ -209,11 +220,8 @@ def method(self, other_args: List[str], translator=translator):
sheet_name=sheet_name,
figure=fig,
)

if session.max_obbjects_exceeded():
session.console.print(
"[yellow]\nMaximum number of OBBjects reached. The oldest entry was removed.[yellow]"
)
elif hasattr(ns_parser, "export") and ns_parser.export and df.empty:
session.console.print("[yellow]No data to export.[/yellow]")

except Exception as e:
session.console.print(f"[red]{e}[/]\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def get_function_hint_type_list(cls, func: Callable) -> List[Type]:
hint_type_list.append(parameter.annotation)

if return_type:
if not issubclass(return_type, OBBject):
raise ValueError("Return type must be an OBBject.")
hint_type = get_args(get_type_hints(return_type)["results"])[0]
hint_type_list.append(hint_type)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

# pylint: disable=unused-argument

import warnings
from datetime import date as dateType
from typing import Any, Dict, List, Literal, Optional, Union
from warnings import warn

from openbb_core.provider.abstract.fetcher import Fetcher
from openbb_core.provider.standard_models.historical_eps import (
Expand All @@ -20,8 +20,6 @@
)
from pydantic import Field, field_validator

_warn = warnings.warn


class AlphaVantageHistoricalEpsQueryParams(HistoricalEpsQueryParams):
"""
Expand Down Expand Up @@ -103,17 +101,13 @@ async def aextract_data(
**kwargs: Any,
) -> List[Dict]:
"""Return the raw data from the AlphaVantage endpoint."""

api_key = credentials.get("alpha_vantage_api_key") if credentials else ""

BASE_URL = "https://www.alphavantage.co/query?function=EARNINGS&"

# We are allowing multiple symbols to be passed in the query, so we need to handle that.
symbols = query.symbol.split(",")

urls = [f"{BASE_URL}symbol={symbol}&apikey={api_key}" for symbol in symbols]

results = []
results: List = []
messages: List = []

# We need to make a custom callback function for this async request.
async def response_callback(response: ClientResponse, _: ClientSession):
Expand All @@ -123,7 +117,11 @@ async def response_callback(response: ClientResponse, _: ClientSession):
target = (
"annualEarnings" if query.period == "annual" else "quarterlyEarnings"
)
result = []
message = data.get("Information", "")
if message:
messages.append(message)
warn(f"Symbol Error for {symbol}: {message}")
result: List = []
# If data is returned, append it to the results list.
if data:
result = [
Expand All @@ -137,13 +135,15 @@ async def response_callback(response: ClientResponse, _: ClientSession):
results.extend(result[: query.limit])
else:
results.extend(result)

# If no data is returned, raise a warning and move on to the next symbol.
if not data:
_warn(f"Symbol Error: No data found for {symbol}")
warn(f"Symbol Error: No data found for {symbol}")

await amake_requests(urls, response_callback, **kwargs) # type: ignore

if not results:
raise EmptyDataError(f"No data was returned -> \n{messages[-1]}")

return results

@staticmethod
Expand Down

0 comments on commit a3d51a0

Please sign in to comment.