diff --git a/osm_rawdata/importer.py b/osm_rawdata/importer.py
index 86565eb..6406254 100755
--- a/osm_rawdata/importer.py
+++ b/osm_rawdata/importer.py
@@ -23,7 +23,14 @@
 import logging
 import subprocess
 import sys
+import os
+import concurrent.futures
+import geojson
+from geojson import Feature, FeatureCollection
 from sys import argv
+from pathlib import Path
+from cpuinfo import get_cpu_info
+from shapely.geometry import shape
 
 import pyarrow.parquet as pq
 from codetiming import Timer
@@ -34,6 +41,9 @@
 from sqlalchemy.dialects.postgresql import JSONB, insert
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy_utils import create_database, database_exists
+from sqlalchemy.engine.base import Connection
+from shapely.geometry import Point, LineString, Polygon
+from shapely import wkt, wkb
 
 # Find the other files for this project
 import osm_rawdata as rw
@@ -45,6 +55,73 @@
 # Instantiate logger
 log = logging.getLogger(__name__)
 
+# The number of threads is based on the CPU cores
+info = get_cpu_info()
+cores = info['count']
+
+def importThread(
+    data: list,
+    db: Connection,
+    ):
+    """Thread to handle importing
+
+    Args:
+        data (list): The list of GeoJson features to import
+        db (Connection): The database connection this thread writes to
+    """
+    log.debug("In importThread()")
+    #timer = Timer(text="importThread() took {seconds:.0f}s")
+    #timer.start()
+    ways = table(
+        "ways_poly",
+        column("id"),
+        column("user"),
+        column("geom"),
+        column("tags"),
+    )
+
+    nodes = table(
+        "nodes",
+        column("id"),
+        column("user"),
+        column("geom"),
+        column("tags"),
+    )
+
+    index = 0
+    log.debug(f"DATA:{index} {len(data)}")
+    for feature in data:
+        log.debug(feature)
+        index -= 1
+        entry = dict()
+        tags = feature['properties']
+        tags['building'] = 'yes'
+        entry['id'] = index
+        ewkt = shape(feature["geometry"])
+        geom = wkb.dumps(ewkt)
+        geom_type = ewkt.geom_type
+        scalar = select(cast(tags, JSONB))
+
+        if geom_type == 'Polygon':
+            sql = insert(ways).values(
+                # id = entry['id'],
+                geom=geom,
+                tags=scalar,
+            )
+        elif geom_type == 'Point':
+            sql = insert(nodes).values(
+                # id = entry['id'],
+                geom=geom,
+                tags=scalar,
+            )
+        else:
+            # No table for other geometry types; without this guard, sql
+            # would be undefined or a stale statement from the last feature.
+            log.warning(f"Unsupported geometry type: {geom_type}")
+            continue
+
+        db.execute(sql)
+        # db.commit()
 
 class MapImporter(object):
     def __init__(
@@ -220,17 +297,47 @@ def importGeoJson(
         Returns:
             (bool): Whether the import finished sucessfully
         """
-        engine = create_engine(f"postgresql://{self.dburi}", echo=True)
-        if not database_exists(engine.url):
-            create_database(engine.url)
-        else:
-            engine.connect()
+        # load the GeoJson file
+        with open(infile, "r") as file:
+            data = geojson.load(file)
+
+        index = 0
+        connections = list()
+
+        for thread in range(0, cores + 1):
+            engine = create_engine(f"postgresql://{self.dburi}", echo=False)
+            if not database_exists(engine.url):
+                create_database(engine.url)
+            connections.append(engine.connect())
+            sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+            if thread == 0:
+                meta = MetaData()
+                meta.create_all(engine)
 
-        sessionmaker(autocommit=False, autoflush=False, bind=engine)
+        # A chunk is the group of features handled by one thread. Round up
+        # so a chunk is never empty and there are at most `cores` chunks,
+        # which keeps the index into connections in range.
+        entries = len(data['features'])
+        chunk = -(-entries // cores)
 
-        meta = MetaData()
-        meta.create_all(engine)
+        if entries <= chunk:
+            importThread(data['features'], connections[0])
+            return True
+
+        futures = list()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor:
+            block = 0
+            while block < entries:
+                log.debug("Dispatching Block %d:%d" % (block, block + chunk))
+                futures.append(executor.submit(importThread, data['features'][block : block + chunk], connections[index]))
+                block += chunk
+                index += 1
+        # Re-raise any exception from a worker instead of silently dropping it
+        for future in futures:
+            future.result()
+        return True
 
 
 def main():
     """This main function lets this class be run standalone by a bash script."""
@@ -260,11 +367,19 @@ def main():
     ch.setFormatter(formatter)
     log.addHandler(ch)
 
+    # Create the database
     mi = MapImporter(args.uri)
-    if mi.importOSM(args.infile):
-        #if mi.importParquet(args.infile):
-        log.info(f"Imported {args.infile} into {args.uri}")
+    path = Path(args.infile)
+
+    # And populate it with data
+    if path.suffix == ".osm":
+        mi.importOSM(args.infile)
+    elif path.suffix == ".geojson":
+        mi.importGeoJson(args.infile)
+    else:
+        mi.importParquet(args.infile)
+    log.info(f"Imported {args.infile} into {args.uri}")
 
 
 if __name__ == "__main__":
     """This is just a hook so this file can be run standalone during development."""