Skip to content

Commit

Permalink
180 cleanup for 195 (#254)
Browse files Browse the repository at this point in the history
* feat: add defences to _add_validation_row()

* feat: better null checking in filter_gtfs_around_trip()

* fix: remove non-implemented param from test

* test: add passing test for _gtfs_instance

* feat: warning in drop_trips when trip_id isn't in the GTFS

* fix: Update failing tests due to GTFS fixture feed expiring

* fix: Ensure dask complete is installed

* chore: remove old TODO comment

* chore: add info

---------

Co-authored-by: r-leyshon <[email protected]>
  • Loading branch information
CBROWN-ONS and r-leyshon authored Mar 19, 2024
1 parent 7bb4863 commit 51ee214
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 25 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ kaleido
numpy>=1.25.0 # test suite will fail if user installed lower than this
sphinx
sphinx-rtd-theme
dask
dask[dataframe]
pyarrow >= 14.0.1 # 14.0.0 has security vulnerability
osmium # has dependencies on `cmake` and `boost` which require brew install
tqdm
Expand Down
7 changes: 7 additions & 0 deletions src/transport_performance/gtfs/cleaners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A set of functions that clean the gtfs data."""
from typing import Union
import warnings

import numpy as np

Expand Down Expand Up @@ -45,6 +46,12 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None:
exp_type=str,
)

# warn users if any of the passed trip_ids is not present in the
# GTFS.
for _id in trip_id:
if _id not in gtfs.feed.trips.trip_id.unique():
warnings.warn(UserWarning(f"trip_id '{_id}' not found in GTFS"))

# drop relevant records from tables
gtfs.feed.trips = gtfs.feed.trips[
~gtfs.feed.trips["trip_id"].isin(trip_id)
Expand Down
40 changes: 24 additions & 16 deletions src/transport_performance/gtfs/gtfs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pyprojroot import here
import pandas as pd
import os
import math
import plotly.graph_objects as go
from typing import Union, TYPE_CHECKING
import pathlib
Expand All @@ -19,9 +18,12 @@
_is_expected_filetype,
_check_iterable,
_type_defence,
_check_attribute,
_gtfs_defence,
_validate_datestring,
_enforce_file_extension,
_gtfs_defence,
_check_parent_dir_exists,
_check_item_in_iter,
)
from transport_performance.utils.constants import PKG_PATH

Expand Down Expand Up @@ -284,14 +286,20 @@ def _add_validation_row(
An error is raised if the validity df does not exist
"""
# TODO: add dtype defences from defence.py once gtfs-html-new is merged
if "validity_df" not in gtfs.__dict__.keys():
raise AttributeError(
_gtfs_defence(gtfs, "gtfs")
_type_defence(_type, "_type", str)
_type_defence(message, "message", str)
_type_defence(rows, "rows", list)
_check_attribute(
gtfs,
"validity_df",
message=(
"The validity_df does not exist as an "
"attribute of your GtfsInstance object, \n"
"Did you forget to run the .is_valid() method?"
)

),
)
_check_item_in_iter(_type, ["warning", "error"], "_type")
temp_df = pd.DataFrame(
{
"type": [_type],
Expand Down Expand Up @@ -343,21 +351,21 @@ def filter_gtfs_around_trip(
An error is raised if a shapeID is not available
"""
# TODO: Add datatype defences once merged
_gtfs_defence(gtfs, "gtfs")
_type_defence(trip_id, "trip_id", str)
_type_defence(buffer_dist, "buffer_dist", int)
_type_defence(crs, "crs", str)
_check_parent_dir_exists(out_pth, "out_pth", create=True)
trips = gtfs.feed.trips
shapes = gtfs.feed.shapes

shape_id = list(trips[trips["trip_id"] == trip_id]["shape_id"])[0]

# defence
# try/except for math.isnan() returning TypeError for strings
try:
if math.isnan(shape_id):
raise ValueError(
"'shape_id' not available for trip with trip_id: " f"{trip_id}"
)
except TypeError:
pass
if pd.isna(shape_id):
raise ValueError(
"'shape_id' not available for trip with trip_id: " f"{trip_id}"
)

# create a buffer around the trip
trip_shape = shapes[shapes["shape_id"] == shape_id]
Expand Down
6 changes: 3 additions & 3 deletions tests/gtfs/test_gtfs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class Test_AddValidationRow(object):
"""Tests for _add_validation_row()."""

def test__add_validation_row_defence(self):
"""Defensive tests for _add_test_validation_row()."""
"""Defensive tests for _add_validation_row()."""
gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
with pytest.raises(
AttributeError,
Expand All @@ -291,9 +291,9 @@ def test__add_validation_row_defence(self):
)

def test__add_validation_row_on_pass(self):
"""General tests for _add_test_validation_row()."""
"""General tests for _add_validation_row()."""
gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
gtfs.is_valid(far_stops=False)
gtfs.is_valid()

_add_validation_row(
gtfs=gtfs, _type="warning", message="test", table="stops"
Expand Down
12 changes: 9 additions & 3 deletions tests/gtfs/test_multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,16 @@ def test_clean_feeds_on_pass(self, multi_gtfs_fixture):
"""General tests for .clean_feeds()."""
# validate and do quick check on validity_df
valid_df = multi_gtfs_fixture.is_valid()
assert len(valid_df) == 12, "validity_df not as expected"
n = 13
n_out = len(valid_df)
assert n_out == n, f"Expected validity_df of len {n}, found {n_out}"
# clean feed
multi_gtfs_fixture.clean_feeds()
# ensure cleaning has occurred
new_valid = multi_gtfs_fixture.is_valid()
assert len(new_valid) == 9
n = 10
n_out = len(new_valid)
assert n_out == n, f"Expected validity_df of len {n}, found {n_out}"
assert np.array_equal(
list(new_valid.iloc[3][["type", "table"]].values),
["error", "routes"],
Expand All @@ -290,7 +294,9 @@ def test_is_valid_defences(self, multi_gtfs_fixture):
def test_is_valid_on_pass(self, multi_gtfs_fixture):
"""General tests for is_valid()."""
valid_df = multi_gtfs_fixture.is_valid()
assert len(valid_df) == 12, "Validation df not as expected"
n = 13
n_out = len(valid_df)
assert n_out == n, f"Expected validity_df of len {n}, found {n_out}"
assert np.array_equal(
list(valid_df.iloc[3][["type", "message"]].values),
(["warning", "Fast Travel Between Consecutive Stops"]),
Expand Down
5 changes: 3 additions & 2 deletions tests/gtfs/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def test_is_valid(self, gtfs_fixture):
), f"Expected DataFrame. Found: {type(gtfs_fixture.validity_df)}"
shp = gtfs_fixture.validity_df.shape
assert shp == (
7,
8,
4,
), f"Attribute `validity_df` expected a shape of (7,4). Found: {shp}"
), f"Attribute `validity_df` expected a shape of (8,4). Found: {shp}"
exp_cols = pd.Index(["type", "message", "table", "rows"])
found_cols = gtfs_fixture.validity_df.columns
assert (
Expand Down Expand Up @@ -334,6 +334,7 @@ def test_print_alerts_multi_case(self, mocked_print, gtfs_fixture):
fun_out = mocked_print.mock_calls
assert fun_out == [
call("Unrecognized column agency_noc"),
call("Feed expired"),
call("Repeated pair (route_short_name, route_long_name)"),
call("Unrecognized column stop_direction_name"),
call("Unrecognized column platform_code"),
Expand Down
11 changes: 11 additions & 0 deletions tests/utils/test_defence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from _pytest.python_api import RaisesContext
import pandas as pd
from pyprojroot import here
from contextlib import nullcontext as does_not_raise

# INFO on the use of 'does_not_raise': see "Parametrizing conditional
# raising" in the pytest documentation:
# https://docs.pytest.org/en/6.2.x/example/parametrize.html#parametrizing-conditional-raising  # noqa: E501
#

from transport_performance.utils.defence import (
_check_iterable,
Expand All @@ -21,6 +27,7 @@
_is_expected_filetype,
_enforce_file_extension,
)
from transport_performance.gtfs.validation import GtfsInstance


class Test_CheckIter(object):
Expand Down Expand Up @@ -251,6 +258,10 @@ def test__gtfs_defence():
),
):
_gtfs_defence("tester", "test")
# passing test
with does_not_raise():
gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip"))
_gtfs_defence(gtfs=gtfs, param_nm="gtfs")


class Test_TypeDefence(object):
Expand Down

0 comments on commit 51ee214

Please sign in to comment.