Skip to content

Commit 893fe91

Browse files
committed
chart integration tests
1 parent a5ec56e commit 893fe91

File tree

5 files changed

+264
-3
lines changed

5 files changed

+264
-3
lines changed

openbb_platform/obbject_extensions/charting/integration/test_charting_api.py

+39
Original file line numberDiff line numberDiff line change
@@ -719,3 +719,42 @@ def test_charting_etf_holdings(params, headers):
719719
assert chart
720720
assert not fig
721721
assert list(chart.keys()) == ["content", "format"]
722+
723+
724+
@parametrize(
725+
"params",
726+
[
727+
(
728+
{
729+
"provider": "econdb",
730+
"country": "united_kingdom",
731+
"date": None,
732+
"chart": True,
733+
}
734+
),
735+
(
736+
{
737+
"provider": "fred",
738+
"date": "2023-05-10,2024-05-10",
739+
"chart": True,
740+
}
741+
),
742+
],
743+
)
744+
@pytest.mark.integration
745+
def test_charting_fixedincome_government_yield_curve(params, headers):
746+
"""Test chart fixedincome government yield curve."""
747+
params = {p: v for p, v in params.items() if v}
748+
body = (json.dumps({"extra_params": {"chart_params": {"title": "test chart"}}}),)
749+
query_str = get_querystring(params, [])
750+
url = f"http://0.0.0.0:8000/api/v1/fixedincome/government/yield_curve?{query_str}"
751+
result = requests.get(url, headers=headers, timeout=10, json=body)
752+
assert isinstance(result, requests.Response)
753+
assert result.status_code == 200
754+
755+
chart = result.json()["chart"]
756+
fig = chart.pop("fig", {})
757+
758+
assert chart
759+
assert not fig
760+
assert list(chart.keys()) == ["content", "format"]

openbb_platform/obbject_extensions/charting/integration/test_charting_python.py

+31
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,34 @@ def test_charting_etf_holdings(params, obb):
590590
assert len(result.results) > 0
591591
assert result.chart.content
592592
assert isinstance(result.chart.fig, OpenBBFigure)
593+
594+
595+
@parametrize(
596+
"params",
597+
[
598+
(
599+
{
600+
"provider": "econdb",
601+
"country": "united_kingdom",
602+
"date": None,
603+
"chart": True,
604+
}
605+
),
606+
(
607+
{
608+
"provider": "fred",
609+
"date": "2023-05-10,2024-05-10",
610+
"chart": True,
611+
}
612+
),
613+
],
614+
)
615+
@pytest.mark.integration
616+
def test_charting_fixedincome_government_yield_curve(params, obb):
617+
"""Test chart fixedincome government yield curve."""
618+
result = obb.fixedincome.government.yield_curve(**params)
619+
assert result
620+
assert isinstance(result, OBBject)
621+
assert len(result.results) > 0
622+
assert result.chart.content
623+
assert isinstance(result.chart.fig, OpenBBFigure)

openbb_platform/obbject_extensions/charting/openbb_charting/charting_router.py

+158-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
from openbb_core.app.model.charts.chart import ChartFormat
1010
from openbb_core.app.utils import basemodel_to_df
11+
from openbb_core.provider.abstract.data import Data
1112
from plotly.graph_objs import Figure
1213

1314
from openbb_charting.core.chart_style import ChartStyle
@@ -19,6 +20,7 @@
1920
from openbb_charting.utils.generic_charts import bar_chart
2021
from openbb_charting.utils.helpers import (
2122
calculate_returns,
23+
duration_sorter,
2224
heikin_ashi,
2325
should_share_axis,
2426
z_score_standardization,
@@ -406,7 +408,7 @@ def equity_price_historical( # noqa: PLR0912
406408
name=data[col].name,
407409
mode="lines",
408410
hovertemplate=hovertemplate,
409-
line=dict(width=1, color=LARGE_CYCLER[i % len(LARGE_CYCLER)]),
411+
line=dict(width=2, color=LARGE_CYCLER[i % len(LARGE_CYCLER)]),
410412
yaxis=yaxis,
411413
)
412414

@@ -608,7 +610,7 @@ def _ta_ma(**kwargs):
608610
name=name,
609611
mode="lines",
610612
hovertemplate=f"{name}: %{{y}}<extra></extra>",
611-
line=dict(width=1, color=LARGE_CYCLER[color]),
613+
line=dict(width=2, color=LARGE_CYCLER[color]),
612614
showlegend=True,
613615
)
614616
color += 1
@@ -1119,7 +1121,7 @@ def economy_fred_series( # noqa: PLR0912
11191121
name=df_ta[col].name,
11201122
mode="lines",
11211123
hovertemplate=f"{df_ta[col].name}: %{{y}}<extra></extra>",
1122-
line=dict(width=1, color=LARGE_CYCLER[i % len(LARGE_CYCLER)]),
1124+
line=dict(width=2, color=LARGE_CYCLER[i % len(LARGE_CYCLER)]),
11231125
yaxis="y1" if kwargs.get("same_axis") is True else yaxes,
11241126
)
11251127

@@ -1292,3 +1294,156 @@ def technical_relative_rotation(
12921294
content = figure.to_plotly_json()
12931295

12941296
return figure, content
1297+
1298+
1299+
def fixedincome_government_yield_curve( # noqa: PLR0912
1300+
**kwargs,
1301+
) -> Tuple[OpenBBFigure, Dict[str, Any]]:
1302+
"""Government Yield Curve Chart."""
1303+
data = kwargs.get("data", None)
1304+
df: pd.DataFrame = pd.DataFrame()
1305+
if data:
1306+
if isinstance(data, pd.DataFrame) and not data.empty: # noqa: SIM108
1307+
df = data
1308+
elif isinstance(data, (list, Data)):
1309+
df = basemodel_to_df(data, index=None) # type: ignore
1310+
else:
1311+
pass
1312+
else:
1313+
df = pd.DataFrame([d.model_dump() for d in kwargs["obbject_item"]]) # type: ignore
1314+
1315+
if df.empty:
1316+
raise ValueError("Error: No data to plot.")
1317+
1318+
if "maturity" not in df.columns:
1319+
raise ValueError("Error: Maturity column not found in the data.")
1320+
1321+
if "rate" not in df.columns:
1322+
raise ValueError("Error: Rate column not found in the data.")
1323+
1324+
if "date" not in df.columns:
1325+
raise ValueError("Error: Date column not found in the data.")
1326+
1327+
provider = kwargs.get("provider")
1328+
df["date"] = df["date"].astype(str)
1329+
maturities = duration_sorter(df["maturity"].unique().tolist())
1330+
1331+
# Use the supplied colors, if any.
1332+
colors = kwargs.get("colors", [])
1333+
if not colors:
1334+
colors = LARGE_CYCLER
1335+
color_count = 0
1336+
1337+
figure = OpenBBFigure().create_subplots(shared_xaxes=False)
1338+
figure.update_layout(ChartStyle().plotly_template.get("layout", {}))
1339+
1340+
def create_fig(figure, df, dates, color_count, country: Optional[str] = None):
1341+
"""Create a scatter for each date in the data."""
1342+
for date in dates:
1343+
color = colors[color_count % len(colors)]
1344+
plot_df = df[df["date"] == date].copy()
1345+
plot_df["rate"] = plot_df["rate"].apply(lambda x: x * 100)
1346+
plot_df = (
1347+
plot_df.drop(columns=["date"])
1348+
.set_index("maturity")
1349+
.filter(items=maturities, axis=0)
1350+
.reset_index()
1351+
.rename(columns={"maturity": "Maturity", "rate": "Yield"})
1352+
)
1353+
plot_df["Maturity"] = [
1354+
(
1355+
d.split("_")[1] + " " + d.split("_")[0].title()
1356+
if d != "long_term"
1357+
else "Long Term"
1358+
)
1359+
for d in plot_df["Maturity"]
1360+
]
1361+
figure.add_scatter(
1362+
x=plot_df["Maturity"],
1363+
y=plot_df["Yield"],
1364+
# fill=fill,
1365+
mode="lines+markers",
1366+
name=f"{country} - {date}" if country else date,
1367+
line=dict(width=3, color=color),
1368+
marker=dict(size=10, color=color),
1369+
hovertemplate=(
1370+
"Maturity: %{x}<br>Yield: %{y}%<extra></extra>"
1371+
if len(dates) == 1
1372+
else "%{fullData.name}<br>Maturity: %{x}<br>Yield: %{y}%<extra></extra>"
1373+
),
1374+
)
1375+
color_count += 1
1376+
return figure, color_count
1377+
1378+
dates = df.date.unique().tolist()
1379+
figure, color_count = create_fig(figure, df, dates, color_count)
1380+
1381+
# Set the title for the chart
1382+
country: str = ""
1383+
if provider in ("federal_reserve", "fmp"):
1384+
country = "United States"
1385+
elif provider == "ecb":
1386+
curve_type = (
1387+
getattr(kwargs["extra_params"], "yield_curve_type", "")
1388+
.replace("_", " ")
1389+
.title()
1390+
)
1391+
grade = getattr(kwargs["extra_params"], "rating", "").replace("_", " ")
1392+
grade = grade.upper() if grade == "aaa" else "All Ratings"
1393+
country = f"Euro Area ({grade}) {curve_type}"
1394+
elif provider == "fred":
1395+
curve_type = getattr(kwargs["extra_params"], "yield_curve_type", "")
1396+
curve_type = (
1397+
"Real Rates"
1398+
if curve_type == "real"
1399+
else curve_type.replace("_", " ").title()
1400+
)
1401+
country = f"United States {curve_type}"
1402+
elif provider == "econdb":
1403+
country = kwargs["standard_params"].get("country")
1404+
country = country.replace("_", " ").title() if country else "United States"
1405+
country = country + " " if country else ""
1406+
title = kwargs.get("title", "")
1407+
if not title:
1408+
title = f"{country}Yield Curve"
1409+
if len(dates) == 1:
1410+
title = f"{country} Yield Curve - {dates[0]}"
1411+
1412+
# Update the layout of the figure.
1413+
figure.update_layout(
1414+
title=dict(text=title, x=0.5, font=dict(size=20)),
1415+
plot_bgcolor="rgba(0,0,0,0)",
1416+
xaxis=dict(
1417+
title="Maturity",
1418+
ticklen=0,
1419+
showgrid=False,
1420+
),
1421+
yaxis=dict(
1422+
title="Yield (%)",
1423+
ticklen=0,
1424+
showgrid=True,
1425+
gridcolor="rgba(128,128,128,0.3)",
1426+
),
1427+
legend=dict(
1428+
orientation="v",
1429+
yanchor="top",
1430+
xanchor="right",
1431+
y=0.95,
1432+
x=0,
1433+
xref="paper",
1434+
font=dict(size=12),
1435+
bgcolor="rgba(0,0,0,0)",
1436+
),
1437+
margin=dict(
1438+
b=25,
1439+
t=10,
1440+
),
1441+
)
1442+
1443+
layout_kwargs = kwargs.get("layout_kwargs", {})
1444+
if layout_kwargs:
1445+
figure.update_layout(layout_kwargs)
1446+
1447+
content = figure.show(external=True)
1448+
1449+
return figure, content

openbb_platform/obbject_extensions/charting/openbb_charting/query_params.py

+17
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,23 @@ class MAQueryParams(ChartQueryParams):
262262
)
263263

264264

265+
class FixedincomeGovernmentYieldCurve(ChartQueryParams):
266+
"""Fixed Income Government Yield Curve Chart Query Params."""
267+
268+
title: Optional[str] = Field(
269+
default=None,
270+
description="Title of the chart.",
271+
)
272+
colors: Optional[List[str]] = Field(
273+
default=None,
274+
description="List of colors to use for the lines.",
275+
)
276+
layout_kwargs: Optional[Dict[str, Any]] = Field(
277+
default=None,
278+
description="Additional keyword arguments to pass to the Plotly `update_layout` method.",
279+
)
280+
281+
265282
class TechnicalSMAChartQueryParams(MAQueryParams):
266283
"""Technical SMA Chart Query Params."""
267284

openbb_platform/obbject_extensions/charting/openbb_charting/utils/helpers.py

+19
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,22 @@ def heikin_ashi(data: pd.DataFrame) -> pd.DataFrame:
9292
df[item] = ha[f"HA_{item}"]
9393

9494
return df
95+
96+
97+
def duration_sorter(durations: list) -> list:
98+
"""Sort durations labeled as month_5, year_5, etc."""
99+
100+
def duration_to_months(duration):
101+
"""Convert duration to months."""
102+
if duration == "long_term":
103+
return 360
104+
parts = duration.split("_")
105+
months = 0
106+
for i in range(0, len(parts), 2):
107+
number = int(parts[i + 1])
108+
if parts[i] == "year":
109+
number *= 12 # Convert years to months
110+
months += number
111+
return months
112+
113+
return sorted(durations, key=duration_to_months)

0 commit comments

Comments
 (0)