Skip to content
Merged

Lint #62

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
poetry run isort -c .
- name: Lint with pylint
run: |
poetry run pylint orm_importer
poetry run pylint orm_importer --fail-under 8

test:

Expand Down
51 changes: 30 additions & 21 deletions orm_importer/importer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import List
from typing import Any, List, Optional

import networkx as nx
import overpy
Expand All @@ -16,11 +16,11 @@
get_additional_signals,
get_opposite_edge_pairs,
get_signal_classification_number,
get_signal_direction,
get_signal_function,
get_signal_kind,
get_signal_name,
get_signal_states,
getSignalDirection,
is_end_node,
is_same_edge,
is_signal,
Expand All @@ -35,15 +35,18 @@ def __init__(self):
self.top_nodes: list[OverpyNode] = []
self.node_data: dict[str, OverpyNode] = {}
self.ways: dict[str, List[overpy.Way]] = defaultdict(list)
self.paths: dict[str, List[List]] = defaultdict(list)
self.paths: dict[tuple[Optional[Any], Optional[Any]], List[List]] = defaultdict(list)
self.api = overpy.Overpass(url="https://osm.hpi.de/overpass/api/interpreter")
self.topology = Topology()

def _get_track_objects(self, polygon: str, railway_option_types: list[str]):
query_parts = ""
for type in railway_option_types:
query_parts = query_parts + f'way["railway"="{type}"](poly: "{polygon}");node(w)(poly: "{polygon}");'
query = f'({query_parts});out body;'
for _type in railway_option_types:
query_parts = (
query_parts
+ f'way["railway"="{_type}"](poly: "{polygon}");node(w)(poly: "{polygon}");'
)
query = f"({query_parts});out body;"
print(query)
return self._query_api(query)

Expand All @@ -52,21 +55,21 @@ def _query_api(self, query):
return result

def _build_graph(self, track_objects):
G = nx.Graph()
graph = nx.Graph()
for way in track_objects.ways:
previous_node = None
for idx, node_id in enumerate(way._node_ids):
try:
node = track_objects.get_node(node_id)
self.node_data[node_id] = node
self.ways[str(node_id)].append(way)
G.add_node(node.id)
graph.add_node(node.id)
if previous_node:
G.add_edge(previous_node.id, node.id)
graph.add_edge(previous_node.id, node.id)
previous_node = node
except overpy.exception.DataIncomplete:
continue
return G
return graph

def _get_next_top_node(self, node, edge: "tuple[str, str]", path):
node_to_id = edge[1]
Expand Down Expand Up @@ -115,7 +118,7 @@ def _add_signals(self, path, edge: model.Edge, node_before, node_after):
signal_geo_point
),
side_distance=dist_edge(node_before, node_after, node),
direction=getSignalDirection(
direction=get_signal_direction(
edge, self.ways, path, node.tags["railway:signal:direction"]
),
function=get_signal_function(node),
Expand All @@ -134,8 +137,8 @@ def _get_edge_speed(self, edge: Edge):
common_ways = ways_a.intersection(ways_b)
if len(common_ways) != 1:
return None
maxspeed = common_ways.pop().tags.get("maxspeed", None)
return int(maxspeed) if maxspeed else None
max_speed = common_ways.pop().tags.get("maxspeed", None)
return int(max_speed) if max_speed else None

def _should_add_edge(self, node_a: model.Node, node_b: model.Node, path: list[int]):
edge_not_present = not self.topology.get_edge_by_nodes(node_a, node_b)
Expand All @@ -145,11 +148,14 @@ def _should_add_edge(self, node_a: model.Node, node_b: model.Node, path: list[in
present_paths = self.paths[(node_a, node_b)] + self.paths[(node_b, node_a)]
return path not in present_paths and reversed_path not in present_paths

def run(self, polygon, railway_option_types):
def run(self, polygon, railway_option_types: list[str] = None):
if railway_option_types is None:
railway_option_types = ["rail"]
track_objects = self._get_track_objects(polygon, railway_option_types)
self.graph = self._build_graph(track_objects)

# ToDo: Check whether all edges really link to each other in ORM or if there might be edges missing for nodes that are just a few cm from each other
# ToDo: Check whether all edges really link to each other in ORM or if there might be
# edges missing for nodes that are just a few cm from each other
# Only nodes with max 1 edge or that are a switch can be top nodes
for node_id in self.graph.nodes:
node = self.node_data[node_id]
Expand Down Expand Up @@ -203,7 +209,8 @@ def run(self, polygon, railway_option_types):
e for e in self.topology.edges.values() if e.node_a == node or e.node_b == node
]

# merge edges, this means removing the switch and allowing only one path for each origin
# merge edges, this means removing the switch and
# allowing only one path for each origin
edge_pair_1, edge_pair_2 = get_opposite_edge_pairs(connected_edges, node)
new_edge_1 = merge_edges(*edge_pair_1, node)
new_edge_2 = merge_edges(*edge_pair_2, node)
Expand All @@ -228,8 +235,10 @@ def run(self, polygon, railway_option_types):
except DataIncomplete:
nodes = way.get_nodes(resolve_missing=True)
for candidate in nodes:
# we are only interested in nodes outside the bounding box as every node
# that has been previously known was already visited as part of the graph
# we are only interested in nodes outside the
# bounding box as every node that has been
# previously known was already visited as
# part of the graph
if (
candidate.id != int(node.name)
and candidate.id not in self.node_data.keys()
Expand All @@ -248,9 +257,9 @@ def run(self, polygon, railway_option_types):
break
if not substitute_found:
# if no substitute was found, the third node seems to be inside the bounding box
# this can happen when a node is connected to the same node twice (e.g. station on
# lines with only one track). WARNING: this produced weird results in the past.
# It should be okay to do it after the check above.
# this can happen when a node is connected to the same node twice (e.g.
# station on lines with only one track). WARNING: this produced weird
# results in the past. It should be okay to do it after the check above.
connected_edges = [
e
for e in self.topology.edges.values()
Expand Down
40 changes: 18 additions & 22 deletions orm_importer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
def dist_edge(node_before, node_after, signal):
# Calculate distance from point(signal) to edge between node before and after
# TODO: Validate that this is really correct!
return 3.95
p1 = np.array((node_before.lat, node_before.lon))
p2 = np.array((node_after.lat, node_after.lon))
p3 = np.array((signal.lat, signal.lon))
Expand All @@ -41,7 +40,8 @@ def is_end_node(node, graph):

def is_signal(node):
# we cannot use railway=signal as a condition, as buffer stops violate this assumption.
# Instead, we check for the signal direction as we cannot create a signal without a direction anyway
# Instead, we check for the signal direction as we cannot create a
# signal without a direction anyway
return "railway:signal:direction" in node.tags.keys()


Expand All @@ -61,7 +61,7 @@ def is_same_edge(e1: tuple, e2: tuple):
return False


def getSignalDirection(edge: Edge, ways: dict[str, List[Way]], path, signal_direction_tag: str):
def get_signal_direction(edge: Edge, ways: dict[str, List[Way]], path, signal_direction_tag: str):
edge_is_forward = None
for way in ways[edge.node_a.name]:
node_a = int(edge.node_a.name)
Expand Down Expand Up @@ -90,12 +90,12 @@ def getSignalDirection(edge: Edge, ways: dict[str, List[Way]], path, signal_dire
not edge_is_forward and signal_direction_tag == "backward"
):
return "in"
else:
return "gegen"
return "gegen"


def get_signal_states(signal_tags: dict):
# Sh0 is tagged as Hp0 in OSM since a few years, but not all tags have been replaced so we convert them
# Sh0 is tagged as Hp0 in OSM since a few years, but not all tags
# have been replaced so we convert them
raw_states = []
raw_states += signal_tags.get("railway:signal:main:states", "").split(";")
raw_states += signal_tags.get("railway:signal:combined:states", "").split(";")
Expand Down Expand Up @@ -160,12 +160,11 @@ def get_signal_function(signal: Node) -> str:
tag = next(t for t in signal.tags.keys() if t.endswith(":function"))
if signal.tags[tag] == "entry":
return "Einfahr_Signal"
elif signal.tags[tag] == "exit":
if signal.tags[tag] == "exit":
return "Ausfahr_Signal"
elif signal.tags[tag] == "block":
if signal.tags[tag] == "block":
return "Block_Signal"
else:
return "andere"
return "andere"
except StopIteration:
return "andere"

Expand All @@ -176,36 +175,33 @@ def get_signal_kind(signal: Node) -> str:
# ORM Reference: https://wiki.openstreetmap.org/wiki/OpenRailwayMap/Tagging/Signal
if "railway:signal:main" in signal.tags.keys():
return "Hauptsignal"
elif "railway:signal:distant" in signal.tags.keys():
if "railway:signal:distant" in signal.tags.keys():
return "Vorsignal"
elif "railway:signal:combined" in signal.tags.keys():
if "railway:signal:combined" in signal.tags.keys():
return "Mehrabschnittssignal"
elif "railway:signal:shunting" in signal.tags.keys() or (
if "railway:signal:shunting" in signal.tags.keys() or (
"railway:signal:minor" in signal.tags.keys()
and (
signal.tags["railway:signal:minor"] == "DE-ESO:sh0"
or signal.tags["railway:signal:minor"] == "DE-ESO:sh2"
)
):
return "Sperrsignal"
elif (
"railway:signal:main" in signal.tags.keys() and "railway:signal:minor" in signal.tags.keys()
):
if "railway:signal:main" in signal.tags.keys() and "railway:signal:minor" in signal.tags.keys():
return "Hauptsperrsignal"
# Names in comment are not yet supported by PlanPro generator
elif "railway:signal:main_repeated" in signal.tags.keys():
if "railway:signal:main_repeated" in signal.tags.keys():
return "andere" # 'Vorsignalwiederholer'
elif "railway:signal:minor" in signal.tags.keys():
if "railway:signal:minor" in signal.tags.keys():
return "andere" # 'Zugdeckungssignal'
elif "railway:signal:crossing" in signal.tags.keys():
if "railway:signal:crossing" in signal.tags.keys():
return "andere" # 'Überwachungssignal'
elif (
if (
"railway:signal:combined" in signal.tags.keys()
and "railway:signal:minor" in signal.tags.keys()
):
return "andere" # 'Mehrabschnittssperrsignal'
else:
return "andere"
return "andere"


def get_signal_name(node: Node):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ def test_query_griebnitzsee(mock_converter):

assert len(res.nodes) == 10
assert len(res.edges) == 9
assert len(res.signals) == 9
assert len(res.signals) == 17