diff --git a/.gitignore b/.gitignore index a15111e..db0ef56 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,9 @@ coverage.xml .pytest_cache/ cover/ +nfl_data_py/tests + + # Translations *.mo *.pot diff --git a/nfl_data_py/__init__.py b/nfl_data_py/__init__.py index 604ba7c..15d5adc 100644 --- a/nfl_data_py/__init__.py +++ b/nfl_data_py/__init__.py @@ -764,16 +764,17 @@ def import_ids(columns=None, ids=None): rem_cols = [x for x in df.columns if x not in avail_ids] tgt_ids = [x + '_id' for x in ids] - + + ret_columns = set(rem_cols + tgt_ids) # filter df to just specified columns if len(columns) > 0 and len(ids) > 0: - df = df[set(tgt_ids + columns)] + ret_columns = set(tgt_ids + columns) elif len(columns) > 0 and len(ids) == 0: - df = df[set(avail_ids + columns)] + ret_columns = set(avail_ids + columns) elif len(columns) == 0 and len(ids) > 0: - df = df[set(tgt_ids + rem_cols)] + ret_columns = set(tgt_ids + rem_cols) - return df + return df[list(ret_columns)] def import_contracts(): diff --git a/nfl_data_py/tests/nfl_test.py b/nfl_data_py/tests/nfl_test.py index 290b968..2c7eb19 100644 --- a/nfl_data_py/tests/nfl_test.py +++ b/nfl_data_py/tests/nfl_test.py @@ -167,6 +167,29 @@ def test_is_df_with_data(self): self.assertEqual(True, isinstance(s, pd.DataFrame)) self.assertTrue(len(s) > 0) + def test_import_using_ids(self): + ids = ["espn", "yahoo", "gsis"] + s = nfl.import_ids(ids=ids) + self.assertTrue(all([f"{id}_id" in s.columns for id in ids])) + + def test_import_using_columns(self): + ret_columns = ["name", "birthdate", "college"] + not_ret_columns = ["draft_year", "db_season", "team"] + s = nfl.import_ids(columns=ret_columns) + self.assertTrue(all([column in s.columns for column in ret_columns])) + self.assertTrue(all([column not in s.columns for column in not_ret_columns])) + + def test_import_using_ids_and_columns(self): + ret_ids = ["espn", "yahoo", "gsis"] + ret_columns = ["name", "birthdate", "college"] + not_ret_ids = ["cfbref_id", "pff_id", "prf_id"] + not_ret_columns = ["draft_year", "db_season", "team"] + s = nfl.import_ids(columns=ret_columns, ids=ret_ids) + self.assertTrue(all([column in s.columns for column in ret_columns])) + self.assertTrue(all([column not in s.columns for column in not_ret_columns])) + self.assertTrue(all([f"{id}_id" in s.columns for id in ret_ids])) + self.assertTrue(all([f"{id}_id" not in s.columns for id in not_ret_ids])) + class test_ngs(TestCase): def test_is_df_with_data(self): s = nfl.import_ngs_data('passing')