|
1 | 1 | from typing import Any, Callable, Dict, List, Optional, Tuple |
2 | 2 | import pandas as pd |
3 | 3 | import numpy as np |
4 | | -from functools import reduce |
5 | 4 | from .definitions import evaluated_cols |
6 | 5 | from .checks import _run_checks |
7 | 6 |
|
@@ -170,13 +169,25 @@ def _calculate_otm_pct(data: pd.DataFrame) -> pd.DataFrame: |
170 | 169 | ) |
171 | 170 |
|
172 | 171 |
|
| 172 | +def _get_leg_quantity(leg: Tuple) -> int: |
| 173 | + """Get quantity for a leg, defaulting to 1 if not specified.""" |
| 174 | + return leg[2] if len(leg) > 2 else 1 |
| 175 | + |
| 176 | + |
173 | 177 | def _apply_ratios(data: pd.DataFrame, leg_def: List[Tuple]) -> pd.DataFrame: |
174 | | - """Apply position ratios (long/short multipliers) to entry and exit prices.""" |
| 178 | + """Apply position ratios (long/short multipliers) and quantities to entry and exit prices.""" |
175 | 179 | for idx in range(1, len(leg_def) + 1): |
176 | 180 | entry_col = f"entry_leg{idx}" |
177 | 181 | exit_col = f"exit_leg{idx}" |
178 | | - entry_kwargs = {entry_col: lambda r: r[entry_col] * leg_def[idx - 1][0].value} |
179 | | - exit_kwargs = {exit_col: lambda r: r[exit_col] * leg_def[idx - 1][0].value} |
| 182 | + leg = leg_def[idx - 1] |
| 183 | + multiplier = leg[0].value * _get_leg_quantity(leg) |
| 184 | + # Use default arguments to capture values at each iteration (avoid late binding) |
| 185 | + entry_kwargs = { |
| 186 | + entry_col: lambda r, col=entry_col, m=multiplier: r[col] * m |
| 187 | + } |
| 188 | + exit_kwargs = { |
| 189 | + exit_col: lambda r, col=exit_col, m=multiplier: r[col] * m |
| 190 | + } |
180 | 191 | data = data.assign(**entry_kwargs).assign(**exit_kwargs) |
181 | 192 |
|
182 | 193 | return data |
@@ -206,6 +217,16 @@ def _assign_profit( |
206 | 217 | return data |
207 | 218 |
|
208 | 219 |
|
| 220 | +def _rename_leg_columns( |
| 221 | + data: pd.DataFrame, leg_idx: int, join_on: List[str] |
| 222 | +) -> pd.DataFrame: |
| 223 | + """Rename columns with leg suffix, excluding join columns.""" |
| 224 | + rename_map = { |
| 225 | + col: f"{col}_leg{leg_idx}" for col in data.columns if col not in join_on |
| 226 | + } |
| 227 | + return data.rename(columns=rename_map) |
| 228 | + |
| 229 | + |
209 | 230 | def _strategy_engine( |
210 | 231 | data: pd.DataFrame, |
211 | 232 | leg_def: List[Tuple], |
@@ -237,19 +258,19 @@ def _rule_func( |
237 | 258 | ) -> pd.DataFrame: |
238 | 259 | return d if r is None else r(d, ld) |
239 | 260 |
|
240 | | - partials = [leg[1](data) for leg in leg_def] |
| 261 | + # Pre-rename columns for each leg to avoid suffix issues with 3+ legs |
| 262 | + partials = [ |
| 263 | + _rename_leg_columns(leg[1](data).copy(), idx, join_on or []) |
| 264 | + for idx, leg in enumerate(leg_def, start=1) |
| 265 | + ] |
241 | 266 | suffixes = [f"_leg{idx}" for idx in range(1, len(leg_def) + 1)] |
242 | 267 |
|
243 | | - return ( |
244 | | - reduce( |
245 | | - lambda left, right: pd.merge( |
246 | | - left, right, on=join_on, how="inner", suffixes=suffixes |
247 | | - ), |
248 | | - partials, |
249 | | - ) |
250 | | - .pipe(_rule_func, rules, leg_def) |
251 | | - .pipe(_assign_profit, leg_def, suffixes) |
252 | | - ) |
| 268 | + # Merge all legs sequentially |
| 269 | + result = partials[0] |
| 270 | + for partial in partials[1:]: |
| 271 | + result = pd.merge(result, partial, on=join_on, how="inner") |
| 272 | + |
| 273 | + return result.pipe(_rule_func, rules, leg_def).pipe(_assign_profit, leg_def, suffixes) |
253 | 274 |
|
254 | 275 |
|
255 | 276 | def _process_strategy(data: pd.DataFrame, **context: Any) -> pd.DataFrame: |
|
0 commit comments