diff --git a/examples/benchmarks/TRA/src/model.py b/examples/benchmarks/TRA/src/model.py index affb115a10..ebafd6a521 100644 --- a/examples/benchmarks/TRA/src/model.py +++ b/examples/benchmarks/TRA/src/model.py @@ -324,7 +324,6 @@ def predict(self, dataset, segment="test"): class LSTM(nn.Module): - """LSTM Model Args: @@ -414,7 +413,6 @@ def forward(self, x): class Transformer(nn.Module): - """Transformer Model Args: @@ -475,7 +473,6 @@ def forward(self, x): class TRA(nn.Module): - """Temporal Routing Adaptor (TRA) TRA takes historical prediction errors & latent representation as inputs, diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index d784aed57e..9daba91153 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -162,13 +162,15 @@ def create_account_instance( init_cash=init_cash, position_dict=position_dict, pos_type=pos_type, - benchmark_config={} - if benchmark is None - else { - "benchmark": benchmark, - "start_time": start_time, - "end_time": end_time, - }, + benchmark_config=( + {} + if benchmark is None + else { + "benchmark": benchmark, + "start_time": start_time, + "end_time": end_time, + } + ), ) diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 8e7440ba9e..e7c6041efd 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -622,9 +622,11 @@ def cal_trade_indicators( print( "[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format( freq, - trade_start_time - if isinstance(trade_start_time, str) - else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"), + ( + trade_start_time + if isinstance(trade_start_time, str) + else trade_start_time.strftime("%Y-%m-%d %H:%M:%S") + ), fulfill_rate, price_advantage, positive_rate, diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index 95ec9b91e9..86d366d205 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -3,6 +3,7 @@ The interface should be redesigned carefully in the future. """ + import pandas as pd from typing import Tuple from qlib import get_module_logger diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index 964febf11c..bc9a6aa977 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -511,7 +511,6 @@ def predict(self, dataset, segment="test"): class RNN(nn.Module): - """RNN Model Args: @@ -601,7 +600,6 @@ def forward(self, x): class Transformer(nn.Module): - """Transformer Model Args: @@ -649,7 +647,6 @@ def forward(self, x): class TRA(nn.Module): - """Temporal Routing Adaptor (TRA) TRA takes historical prediction errors & latent representation as inputs, diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index 9ba960eebd..bad19ddfdc 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -373,7 +373,6 @@ def generate_trade_decision(self, execute_result=None): class EnhancedIndexingStrategy(WeightStrategyBase): - """Enhanced Indexing Strategy Enhanced indexing combines the arts of active management and passive management, diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 76f3ed4048..f19dfe08fa 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -71,15 +71,11 @@ def fetch_df_by_index( if fetch_orig: for slc in idx_slc: if slc != slice(None, None): - return df.loc[ - pd.IndexSlice[idx_slc], - ] # noqa: E231 + return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231 else: # pylint: disable=W0120 return df else: - return df.loc[ - pd.IndexSlice[idx_slc], - ] # noqa: E231 + return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231 def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame: diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index ede1f8e3ad..1ebb16f18b 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -30,7 +30,6 @@ def __call__(self, ensemble_dict: dict, *args, **kwargs): class SingleKeyEnsemble(Ensemble): - """ Extract the object if there is only one key and value in the dict. Make the result more readable. {Only key: Only value} -> Only value @@ -64,7 +63,6 @@ def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) - class RollingEnsemble(Ensemble): - """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". diff --git a/qlib/model/riskmodel/shrink.py b/qlib/model/riskmodel/shrink.py index b2594f707d..c3c0e48ef8 100644 --- a/qlib/model/riskmodel/shrink.py +++ b/qlib/model/riskmodel/shrink.py @@ -247,9 +247,7 @@ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np v1 = y.T.dot(z) / t - cov_mkt[:, None] * S roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt v3 = z.T.dot(z) / t - var_mkt * S - roff3 = ( - np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 - ) + roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 roff = 2 * roff1 - roff3 rho = rdiag + roff diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index f2988d843f..d545e4bc9a 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -90,7 +90,6 @@ def get_collector(self) -> Collector: class RollingStrategy(OnlineStrategy): - """ This example strategy always uses the latest rolling model sas online models. """ diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 92abc8beec..a65b1f58ee 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -146,9 +146,7 @@ def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]: return ( self._include_fields if self._include_fields - else set(df_columns) - set(self._exclude_fields) - if self._exclude_fields - else df_columns + else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns ) @staticmethod diff --git a/scripts/dump_pit.py b/scripts/dump_pit.py index 34d304ed78..1ca9cfc942 100644 --- a/scripts/dump_pit.py +++ b/scripts/dump_pit.py @@ -132,9 +132,11 @@ def get_dump_fields(self, df: Iterable[str]) -> Iterable[str]: return ( set(self._include_fields) if self._include_fields - else set(df[self.field_column_name]) - set(self._exclude_fields) - if self._exclude_fields - else set(df[self.field_column_name]) + else ( + set(df[self.field_column_name]) - set(self._exclude_fields) + if self._exclude_fields + else set(df[self.field_column_name]) + ) ) def get_filenames(self, symbol, field, interval):