diff --git a/nshmdb/nshmdb.py b/nshmdb/nshmdb.py index bae1a1a..a85ecea 100644 --- a/nshmdb/nshmdb.py +++ b/nshmdb/nshmdb.py @@ -440,7 +440,21 @@ def get_fault_names(self) -> set[str]: name for (name,) in conn.execute("SELECT name FROM parent_fault").fetchall() } + + def get_fault_ids(self) -> set[int]: + """Get the list of fault ids in the database. + Returns + ------- + set[int] + The list of fault ids. + """ + with self.connection() as conn: + return { + fault_id + for (fault_id,) in conn.execute("SELECT fault_id FROM fault").fetchall() + } + def query( self, query_str: str, diff --git a/nshmdb/scripts/nshm_db_generator.py b/nshmdb/scripts/nshm_db_generator.py index 84ce0bd..c760112 100644 --- a/nshmdb/scripts/nshm_db_generator.py +++ b/nshmdb/scripts/nshm_db_generator.py @@ -302,3 +302,6 @@ def main( rupture_fault_join_df.to_sql( "rupture_faults", conn, index=False, if_exists="append" ) + +if __name__ == "__main__": + app()