180 cleanup for 195 #254

Merged
merged 9 commits on Mar 19, 2024
2 changes: 1 addition & 1 deletion requirements.txt
@@ -31,7 +31,7 @@ kaleido
numpy>=1.25.0 # test suite will fail if user installed lower than this
sphinx
sphinx-rtd-theme
dask
dask[dataframe]
pyarrow >= 14.0.1 # 14.0.0 has security vulnerability
osmium # has dependencies on `cmake` and `boost` which require brew install
tqdm
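The pin moves from plain `dask` to the `dask[dataframe]` extra, which pulls in the pandas/pyarrow pieces that `dask.dataframe` needs. A minimal sketch of the usage this enables; the toy DataFrame is illustrative only:

```python
# With only `dask` installed, `import dask.dataframe` can fail and ask for the
# "dataframe" extra; with dask[dataframe] the import and compute both work.
import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)
print(ddf["x"].sum().compute())  # 45
```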
7 changes: 7 additions & 0 deletions src/transport_performance/gtfs/cleaners.py
@@ -1,5 +1,6 @@
"""A set of functions that clean the gtfs data."""
from typing import Union
import warnings

import numpy as np

@@ -45,6 +46,12 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None:
        exp_type=str,
    )

    # warn users if one of the passed trip_ids is not present in the
    # GTFS.
    for _id in trip_id:
        if _id not in gtfs.feed.trips.trip_id.unique():
            warnings.warn(UserWarning(f"trip_id '{_id}' not found in GTFS"))

    # drop relevant records from tables
    gtfs.feed.trips = gtfs.feed.trips[
        ~gtfs.feed.trips["trip_id"].isin(trip_id)
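A short sketch of how the new warning surfaces, assuming a GtfsInstance built from the small Chester fixture used elsewhere in the test suite, a made-up trip_id, and the usual src-layout import path for `drop_trips`:

```python
import warnings

from transport_performance.gtfs.cleaners import drop_trips
from transport_performance.gtfs.validation import GtfsInstance

gtfs = GtfsInstance(gtfs_pth="tests/data/chester-20230816-small_gtfs.zip")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    drop_trips(gtfs, trip_id=["not-a-real-trip-id"])

# the unknown ID is warned about rather than silently ignored
print(caught[0].message)  # trip_id 'not-a-real-trip-id' not found in GTFS
```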
41 changes: 25 additions & 16 deletions src/transport_performance/gtfs/gtfs_utils.py
@@ -5,7 +5,6 @@
from pyprojroot import here
import pandas as pd
import os
import math
import plotly.graph_objects as go
from typing import Union, TYPE_CHECKING
import pathlib
@@ -19,9 +18,12 @@
    _is_expected_filetype,
    _check_iterable,
    _type_defence,
    _check_attribute,
    _gtfs_defence,
    _validate_datestring,
    _enforce_file_extension,
    _gtfs_defence,
    _check_parent_dir_exists,
    _check_item_in_iter,
)
from transport_performance.utils.constants import PKG_PATH

@@ -284,14 +286,20 @@ def _add_validation_row(
    An error is raised if the validity df does not exist

    """
    # TODO: add dtype defences from defence.py once gtfs-html-new is merged
    if "validity_df" not in gtfs.__dict__.keys():
        raise AttributeError(
    _gtfs_defence(gtfs, "gtfs")
    _type_defence(_type, "_type", str)
    _type_defence(message, "message", str)
    _type_defence(rows, "rows", list)
    _check_attribute(
        gtfs,
        "validity_df",
        message=(
            "The validity_df does not exist as an "
            "attribute of your GtfsInstance object, \n"
            "Did you forget to run the .is_valid() method?"
        )

        ),
    )
    _check_item_in_iter(_type, ["warning", "error"], "_type")
    temp_df = pd.DataFrame(
        {
            "type": [_type],
@@ -343,21 +351,22 @@ def filter_gtfs_around_trip(
    An error is raised if a shapeID is not available

    """
    # TODO: Add datatype defences once merged
    # NOTE: No defence for units as it's deleted later on
    _gtfs_defence(gtfs, "gtfs")
    _type_defence(trip_id, "trip_id", str)
    _type_defence(buffer_dist, "buffer_dist", int)
    _type_defence(crs, "crs", str)
    _check_parent_dir_exists(out_pth, "out_pth", create=True)
    trips = gtfs.feed.trips
    shapes = gtfs.feed.shapes

    shape_id = list(trips[trips["trip_id"] == trip_id]["shape_id"])[0]

    # defence
    # try/except for math.isnan() returning TypeError for strings
    try:
        if math.isnan(shape_id):
            raise ValueError(
                "'shape_id' not available for trip with trip_id: " f"{trip_id}"
            )
    except TypeError:
        pass
    if pd.isna(shape_id):
        raise ValueError(
            "'shape_id' not available for trip with trip_id: " f"{trip_id}"
        )

    # create a buffer around the trip
    trip_shape = shapes[shapes["shape_id"] == shape_id]
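The old `math.isnan` guard needed a try/except because it raises `TypeError` for anything that is not a float, including genuine `shape_id` strings; `pd.isna` accepts strings, `None` and `NaN` alike, so the check collapses to a single branch. A small standalone illustration (the shape_id value is invented):

```python
import math

import numpy as np
import pandas as pd

print(pd.isna(np.nan))         # True  -> missing shape_id
print(pd.isna("route_1_shp"))  # False -> a real shape_id string

try:
    math.isnan("route_1_shp")
except TypeError as err:
    print(f"math.isnan rejects strings: {err}")
```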
6 changes: 3 additions & 3 deletions tests/gtfs/test_gtfs_utils.py
@@ -276,7 +276,7 @@ class Test_AddValidationRow(object):
"""Tests for _add_validation_row()."""

def test__add_validation_row_defence(self):
"""Defensive tests for _add_test_validation_row()."""
"""Defensive tests for _add_validation_row()."""
gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
with pytest.raises(
AttributeError,
@@ -291,9 +291,9 @@ def test__add_validation_row_defence(self):
        )

    def test__add_validation_row_on_pass(self):
        """General tests for _add_test_validation_row()."""
        """General tests for _add_validation_row()."""
        gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
        gtfs.is_valid(far_stops=False)
        gtfs.is_valid()

        _add_validation_row(
            gtfs=gtfs, _type="warning", message="test", table="stops"
12 changes: 9 additions & 3 deletions tests/gtfs/test_multi_validation.py
@@ -271,12 +271,16 @@ def test_clean_feeds_on_pass(self, multi_gtfs_fixture):
"""General tests for .clean_feeds()."""
# validate and do quick check on validity_df
valid_df = multi_gtfs_fixture.is_valid()
assert len(valid_df) == 12, "validity_df not as expected"
n = 13
n_out = len(valid_df)
assert n_out == n, f"Expected validity_df of len {n}, found {n_out}"
# clean feed
multi_gtfs_fixture.clean_feeds()
# ensure cleaning has occured
new_valid = multi_gtfs_fixture.is_valid()
assert len(new_valid) == 9
n = 10
n_out = len(new_valid)
assert n_out == n, f"Expected validity_df of len {n}, found {n_out}"
assert np.array_equal(
list(new_valid.iloc[3][["type", "table"]].values),
["error", "routes"],
@@ -290,7 +294,9 @@ def test_is_valid_defences(self, multi_gtfs_fixture):
    def test_is_valid_on_pass(self, multi_gtfs_fixture):
        """General tests for is_valid()."""
        valid_df = multi_gtfs_fixture.is_valid()
        assert len(valid_df) == 12, "Validation df not as expected"
        n = 13
        n_out = len(valid_df)
        assert n_out == n, f"Expected validity_df of len {n}, found {n_out}"
        assert np.array_equal(
            list(valid_df.iloc[3][["type", "message"]].values),
            (["warning", "Fast Travel Between Consecutive Stops"]),
5 changes: 3 additions & 2 deletions tests/gtfs/test_validation.py
@@ -149,9 +149,9 @@ def test_is_valid(self, gtfs_fixture):
), f"Expected DataFrame. Found: {type(gtfs_fixture.validity_df)}"
shp = gtfs_fixture.validity_df.shape
assert shp == (
7,
8,
4,
), f"Attribute `validity_df` expected a shape of (7,4). Found: {shp}"
), f"Attribute `validity_df` expected a shape of (8,4). Found: {shp}"
exp_cols = pd.Index(["type", "message", "table", "rows"])
found_cols = gtfs_fixture.validity_df.columns
assert (
@@ -334,6 +334,7 @@ def test_print_alerts_multi_case(self, mocked_print, gtfs_fixture):
        fun_out = mocked_print.mock_calls
        assert fun_out == [
            call("Unrecognized column agency_noc"),
            call("Feed expired"),
            call("Repeated pair (route_short_name, route_long_name)"),
            call("Unrecognized column stop_direction_name"),
            call("Unrecognized column platform_code"),
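The extra `call("Feed expired")` reflects one more alert being printed; the assertion style works by patching `builtins.print` so each alert becomes a recorded mock call. A generic sketch of that pattern, with an invented `print_alerts` stand-in for the real alert printer:

```python
from unittest.mock import call, patch


def print_alerts():
    # stand-in for the real alert printer used in the test
    print("Unrecognized column agency_noc")
    print("Feed expired")


with patch("builtins.print") as mocked_print:
    print_alerts()

assert mocked_print.mock_calls == [
    call("Unrecognized column agency_noc"),
    call("Feed expired"),
]
```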
6 changes: 6 additions & 0 deletions tests/utils/test_defence.py
@@ -8,6 +8,7 @@
from _pytest.python_api import RaisesContext
import pandas as pd
from pyprojroot import here
from contextlib import nullcontext as does_not_raise

from transport_performance.utils.defence import (
    _check_iterable,
@@ -21,6 +22,7 @@
    _is_expected_filetype,
    _enforce_file_extension,
)
from transport_performance.gtfs.validation import GtfsInstance

class Test_CheckIter(object):
@@ -251,6 +253,10 @@ def test__gtfs_defence():
        ),
    ):
        _gtfs_defence("tester", "test")
    # passing test
    with does_not_raise():
        gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip"))
        _gtfs_defence(gtfs=gtfs, param_nm="gtfs")


class Test_TypeDefence(object):
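The `nullcontext as does_not_raise` import is the usual trick for expressing "this should not raise" in the same with-block style as `pytest.raises`, which also lets both cases share one parametrised test. A generic, self-contained sketch of the pattern; the toy defence helper is invented for illustration:

```python
from contextlib import nullcontext as does_not_raise

import pytest


def _int_defence(value):
    # toy defence helper: reject anything that is not an int
    if not isinstance(value, int):
        raise TypeError("value must be an int")


@pytest.mark.parametrize(
    "value, expectation",
    [
        (1, does_not_raise()),
        ("one", pytest.raises(TypeError)),
    ],
)
def test_int_defence(value, expectation):
    with expectation:
        _int_defence(value)
```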