diff --git a/lib/data/data_loader.py b/lib/data/data_loader.py
index b674ed0..ea58b69 100644
--- a/lib/data/data_loader.py
+++ b/lib/data/data_loader.py
@@ -118,7 +118,8 @@ def load_inat_dataset_from_parquet_h3(spatial_data_file, h3_resolution):
     return locs, class_ids, unique_taxa, h3_idx
 
 
-def load_inat_dataset_from_parquet(spatial_data_file):
+
+def load_inat_dataset_from_parquet(spatial_data_file, inner_nodes=False):
     print("inat style dataset")
     print(" reading parquet")
     spatial_data = pd.read_parquet(
@@ -131,7 +132,13 @@ def load_inat_dataset_from_parquet(spatial_data_file):
             "spatial_class_id",
         ],
     )
-    spatial_data = spatial_data.dropna(subset="leaf_class_id")
+
+    if not inner_nodes:
+        spatial_data = spatial_data.dropna(subset="leaf_class_id")
+
+    # we won't need this anymore
+    _ = spatial_data.pop("leaf_class_id")
+
     print(" cleaning dataset")
     spatial_data = clean_dataset(spatial_data)
     print(" shuffling")
diff --git a/train_tf_sinr.py b/train_tf_sinr.py
index 3df0874..513b284 100644
--- a/train_tf_sinr.py
+++ b/train_tf_sinr.py
@@ -68,7 +68,7 @@ def train_model(config_file):
         )
     elif config["dataset_type"] == "inat":
         (locs, class_ids, unique_taxa) = load_inat_dataset_from_parquet(
-            config["inat_dataset"]["spatial_data"]
+            config["inat_dataset"]["spatial_data"], config.get("inner_nodes", False)
         )
 
     if config["inputs"]["covariates"] == "env":
diff --git a/utils/make_ancestor_map.py b/utils/make_ancestor_map.py
new file mode 100644
index 0000000..0d33d0e
--- /dev/null
+++ b/utils/make_ancestor_map.py
@@ -0,0 +1,59 @@
+import click
+import pandas as pd
+from tqdm.auto import tqdm
+
+@click.command()
+@click.option("--taxonomy_file", type=str, required=True)
+@click.option("--output_file", type=str, required=True)
+def make_ancestor_map(taxonomy_file, output_file):
+    """
+    Generates a flat (taxon_id, ancestor_id) mapping from a taxonomy file,
+    including self-links, for use in DuckDB joins.
+    Reads and writes CSV files.
+    """
+    tax = pd.read_csv(taxonomy_file)
+
+    parent_map = tax.set_index("taxon_id")[
+        "parent_taxon_id"
+    ].dropna().astype(int).to_dict()
+
+    ancestor_map = {}
+    def get_ancestors(taxon_id):
+        if taxon_id in ancestor_map:
+            return ancestor_map[taxon_id]
+        ancestors = []
+
+        child_taxon_id = taxon_id
+        while child_taxon_id in parent_map:
+            child_taxon_id = parent_map[child_taxon_id]
+            ancestors.append(child_taxon_id)
+        ancestor_map[taxon_id] = ancestors
+        return ancestors
+
+    all_taxa = tax["taxon_id"].dropna().astype(int).unique()
+    for taxon_id in tqdm(all_taxa):
+        _ = get_ancestors(taxon_id)
+
+    rows = []
+    for taxon_id, ancestors in ancestor_map.items():
+        for ancestor_id in ancestors:
+            rows.append({
+                "taxon_id": taxon_id,
+                "ancestor_id": ancestor_id,
+            })
+    df = pd.DataFrame(rows)
+
+    # include self
+    self_rows = [
+        { "taxon_id": taxon_id, "ancestor_id": taxon_id }
+        for taxon_id in ancestor_map.keys()
+    ]
+    df = pd.concat([
+        df, pd.DataFrame(self_rows)
+    ], ignore_index=True)
+
+    df.to_csv(output_file, index=False)
+
+
+if __name__ == "__main__":
+    make_ancestor_map()
diff --git a/utils/make_inner_node_training_data.sql b/utils/make_inner_node_training_data.sql
new file mode 100644
index 0000000..5aaa988
--- /dev/null
+++ b/utils/make_inner_node_training_data.sql
@@ -0,0 +1,31 @@
+-- $-style placeholders are substituted by the caller (e.g. Python's
+-- string.Template) before this script is run with DuckDB.
+CREATE TABLE taxonomy AS SELECT * FROM read_csv_auto('$taxonomy_file');
+CREATE TABLE ancestor_map AS SELECT * FROM read_csv_auto('$ancestor_map_file');
+CREATE TABLE spatial_data AS SELECT latitude, longitude, taxon_id FROM read_parquet('$spatial_data_file');
+
+-- Pair every observation with each ancestor of its taxon (the ancestor map
+-- includes self links), so inner taxonomy nodes receive training data too.
+CREATE TABLE expanded AS
+SELECT
+    s.latitude,
+    s.longitude,
+    a.ancestor_id AS taxon_id
+FROM spatial_data s
+JOIN ancestor_map a
+ON s.taxon_id = a.taxon_id;
+
+-- Randomly cap the number of samples kept per taxon.
+CREATE TABLE sampled AS
+SELECT latitude, longitude, taxon_id
+FROM (
+    SELECT *,
+        ROW_NUMBER() OVER (
+            PARTITION BY taxon_id
+            ORDER BY RANDOM()
+        ) AS row_num
+    FROM expanded
+)
+WHERE row_num <= $sample_cap;
+
+COPY sampled TO '$output_file' (FORMAT PARQUET);