diff --git a/nshmdb/nshmdb.py b/nshmdb/nshmdb.py index 42e43c9..1739048 100644 --- a/nshmdb/nshmdb.py +++ b/nshmdb/nshmdb.py @@ -361,7 +361,7 @@ def get_rupture_faults(self, rupture_id: int) -> dict[str, Fault]: (rupture_id,), ) fault_planes = cursor.fetchall() - faults = collections.defaultdict(lambda: Fault([])) + faults = collections.defaultdict(lambda: []) for ( _, top_left_lat, @@ -386,10 +386,10 @@ def get_rupture_faults(self, rupture_id: int) -> dict[str, Fault]: [bottom_left_lat, bottom_left_lon, bottom], ] ) - faults[parent_name].planes.append( + faults[parent_name].append( Plane(coordinates.wgs_depth_to_nztm(corners)) ) - return faults + return {name: Fault(planes) for name, planes in faults.items()} def get_rupture_fault_info(self, rupture_id: int) -> dict[str, FaultInfo]: """Get the rupture fault information for a given rupture.