11 changes: 9 additions & 2 deletions lib/data/data_loader.py
@@ -118,7 +118,8 @@ def load_inat_dataset_from_parquet_h3(spatial_data_file, h3_resolution):
    return locs, class_ids, unique_taxa, h3_idx


def load_inat_dataset_from_parquet(spatial_data_file):

def load_inat_dataset_from_parquet(spatial_data_file, inner_nodes=False):
print("inat style dataset")
print(" reading parquet")
spatial_data = pd.read_parquet(
@@ -131,7 +132,13 @@ def load_inat_dataset_from_parquet(spatial_data_file):
            "spatial_class_id",
        ],
    )
    spatial_data = spatial_data.dropna(subset="leaf_class_id")

    if not inner_nodes:
        spatial_data = spatial_data.dropna(subset="leaf_class_id")

    # we won't need this anymore
    _ = spatial_data.pop("leaf_class_id")

    print(" cleaning dataset")
    spatial_data = clean_dataset(spatial_data)
    print(" shuffling")
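The new inner_nodes flag only changes whether rows without a leaf_class_id (inner taxonomy nodes) are dropped before that column is discarded. A minimal sketch of that behaviour on toy data; latitude/longitude values here are illustrative, and only leaf_class_id and spatial_class_id are actually visible in this diff (the real loader also reads parquet, cleans, and shuffles):

import numpy as np
import pandas as pd

# stand-in for the parquet contents: NaN leaf_class_id marks an inner node
spatial_data = pd.DataFrame({
    "latitude": [10.0, 20.0, 30.0],
    "longitude": [1.0, 2.0, 3.0],
    "leaf_class_id": [0.0, np.nan, 1.0],
    "spatial_class_id": [0, 1, 2],
})

inner_nodes = False  # default: behaves exactly as before this change
if not inner_nodes:
    spatial_data = spatial_data.dropna(subset="leaf_class_id")

# the column is popped either way; class ids come from spatial_class_id
_ = spatial_data.pop("leaf_class_id")
print(spatial_data)  # 2 rows with inner_nodes=False, 3 rows with inner_nodes=True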
2 changes: 1 addition & 1 deletion train_tf_sinr.py
@@ -68,7 +68,7 @@ def train_model(config_file):
        )
    elif config["dataset_type"] == "inat":
        (locs, class_ids, unique_taxa) = load_inat_dataset_from_parquet(
            config["inat_dataset"]["spatial_data"]
            config["inat_dataset"]["spatial_data"], config["inner_nodes"]
        )

    if config["inputs"]["covariates"] == "env":
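train_model now reads a top-level inner_nodes key from the config, and it is indexed directly rather than read with .get(), so existing configs without the key will raise a KeyError. A sketch of the relevant part of the config, written as a Python dict since the on-disk format is not shown in this diff; the path is a placeholder:

config = {
    "dataset_type": "inat",
    "inner_nodes": True,  # keep inner taxonomy nodes when loading the dataset
    "inat_dataset": {
        "spatial_data": "path/to/spatial_data.parquet",  # placeholder
    },
}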
59 changes: 59 additions & 0 deletions utils/make_ancestor_map.py
@@ -0,0 +1,59 @@
import click
import pandas as pd
from tqdm.auto import tqdm

@click.command()
@click.option("--taxonomy_file", type=str, required=True)
@click.option("--output_file", type=str, required=True)
def make_ancestor_map(taxonomy_file, output_file):
"""
Generates a flat (taxon_id, ancestor_id) mapping from a taxonomy file,
including self-links, for use in DuckDB joins.
Reads and writes CSV files.
"""
tax = pd.read_csv(taxonomy_file)

parent_map = tax.set_index("taxon_id")[
"parent_taxon_id"
].dropna().astype(int).to_dict()

ancestor_map = {}
def get_ancestors(taxon_id):
if taxon_id in ancestor_map:
return ancestor_map[taxon_id]
ancestors = []

child_taxon_id = taxon_id
while child_taxon_id in parent_map:
child_taxon_id = parent_map[child_taxon_id]
ancestors.append(child_taxon_id)
ancestor_map[taxon_id] = ancestors
return ancestors

all_taxa = tax["taxon_id"].dropna().astype(int).unique()
for taxon_id in tqdm(all_taxa):
_ = get_ancestors(taxon_id)

rows = []
for taxon_id, ancestors in ancestor_map.items():
for ancestor_id in ancestors:
rows.append({
"taxon_id": taxon_id,
"ancestor_id": ancestor_id,
})
df = pd.DataFrame(rows)

# include self
self_rows = [
{ "taxon_id": taxon_id, "ancestor_id": taxon_id }
for taxon_id in ancestor_map.keys()
]
df = pd.concat([
df, pd.DataFrame(self_rows)
], ignore_index=True)

df.to_csv(output_file, index=False)


if __name__ == "__main__":
    make_ancestor_map()
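A small worked example of the table this script emits, assuming the taxonomy CSV has taxon_id and parent_taxon_id columns; the file names below are made up:

import pandas as pd

# toy taxonomy: 1 is the root, 2 its child, 3 a grandchild
pd.DataFrame({
    "taxon_id": [1, 2, 3],
    "parent_taxon_id": [None, 1, 2],
}).to_csv("toy_taxonomy.csv", index=False)

# python utils/make_ancestor_map.py \
#     --taxonomy_file toy_taxonomy.csv --output_file toy_ancestors.csv
#
# toy_ancestors.csv then holds (order aside):
#   taxon_id,ancestor_id
#   2,1
#   3,2
#   3,1
#   1,1   <- self-links
#   2,2
#   3,3
# so joining observations to it on taxon_id repeats each observation once
# per ancestor, plus once for the taxon itself.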
32 changes: 32 additions & 0 deletions utils/make_inner_node_training_data.sql
@@ -0,0 +1,32 @@
-- Template for building the inner-node training set. The dollar-sign
-- placeholders (taxonomy_file, ancestor_map_file, spatial_data_file,
-- output_file, sample_cap) are assumed to be substituted before execution,
-- e.g. with Python's string.Template or envsubst.

CREATE TABLE taxonomy AS SELECT * FROM read_csv_auto('$taxonomy_file');
CREATE TABLE ancestor_map AS SELECT * FROM read_csv_auto('$ancestor_map_file');
CREATE TABLE spatial_data AS SELECT latitude, longitude, taxon_id FROM read_parquet('$spatial_data_file');

-- replicate every observation once per ancestor of its taxon (the ancestor map includes self-links)
CREATE TABLE expanded AS
SELECT
    s.latitude,
    s.longitude,
    a.ancestor_id AS taxon_id
FROM spatial_data s
JOIN ancestor_map a
    ON s.taxon_id = a.taxon_id;

-- keep at most sample_cap randomly chosen rows per (expanded) taxon
CREATE TABLE sampled AS
SELECT latitude, longitude, taxon_id
FROM (
    SELECT *,
        ROW_NUMBER() OVER (
            PARTITION BY taxon_id
            ORDER BY RANDOM()
        ) AS row_num
    FROM expanded
)
WHERE row_num <= $sample_cap;

COPY sampled TO '$output_file' (FORMAT PARQUET);
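One way this template might be driven, assuming the placeholders are filled with Python's string.Template and the statements are executed one at a time through the duckdb Python API; the paths and the invocation itself are illustrative and not part of this PR:

from string import Template

import duckdb

params = {
    "taxonomy_file": "taxonomy.csv",
    "ancestor_map_file": "ancestor_map.csv",  # output of make_ancestor_map.py
    "spatial_data_file": "observations.parquet",
    "output_file": "inner_node_training_data.parquet",
    "sample_cap": 100000,
}

with open("utils/make_inner_node_training_data.sql") as f:
    sql = Template(f.read()).substitute(params)

con = duckdb.connect()
for statement in sql.split(";"):
    if statement.strip():
        con.execute(statement)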