Skip to content

Commit

Permalink
chore: update import script to handle format a bit nicer (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
Volubyl committed May 3, 2023
1 parent 1469be5 commit ddeddac
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 31 deletions.
12 changes: 4 additions & 8 deletions tests/tools/test_import_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ async def test_import_catalog_example(tmp_path: Path) -> None:
csv_path.write_text(
dedent(
"""\
titre;description;mots_cles;nom_orga;siret_orga;id_alt_orga;service;si;contact_service;contact_personne;date_pub;date_maj;freq_maj;couv_geo;url;formats;licence;donnees_geoloc
Titre1;Description1;"Tag 1,Tag 2";Ministère 1;11004601800013;;Direction1;SI1;[email protected];[email protected];;2022-10-06;annuelle;aquitaine;;geojson, xls, oracle et shp;etalab-2.0;oui
Titre2;Description2;"Tag 1,Tag 3";Ministère 1;11004601800013;;Direction1;SI2;[email protected];[email protected];;;Invalid;NSP;;Information manquante;etalab-2.0;oui
titre;description;mots_cles;nom_orga;siret_orga;id_alt_orga;service;si;contact_service;contact_personne;date_pub;date_maj;freq_maj;couv_geo;url;formats;licence;donnees_geoloc;publication_restriction
Titre1;Description1;"Tag 1,Tag 2";Ministère 1;11004601800013;;Direction1;SI1;[email protected];[email protected];;2022-10-06;annuelle;aquitaine;;geojson, xls, oracle et shp;etalab-2.0;oui;
Titre2;Description2;"Tag 1,Tag 3";Ministère 1;11004601800013;;Direction1;SI2;[email protected];[email protected];;;Invalid;NSP;;Information manquante;etalab-2.0;oui;
""" # noqa
)
)
Expand Down Expand Up @@ -78,11 +78,7 @@ async def test_import_catalog_example(tmp_path: Path) -> None:
assert all(d["params"]["organization_siret"] == siret for d in initdata["datasets"])

d0 = initdata["datasets"][0]["params"]
assert sorted(d0["formats"]) == [
"fichier SIG (Shapefile, ...)",
"fichier tabulaire (Excell, CSV,...)",
"oracle et shp",
]
assert sorted(d0["formats"]) == ["geojson", "oracle et shp", "xls"]
assert d0["geographical_coverage"] == "aquitaine"
assert d0["update_frequency"] == "yearly"
assert "[[ Notes d'import automatique ]]" not in d0["description"]
Expand Down
72 changes: 49 additions & 23 deletions tools/import_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from server.application.catalogs.queries import GetCatalogBySiret
from server.application.catalogs.views import ExtraFieldView
from server.application.dataformats.queries import GetAllDataFormat
from server.application.organizations.queries import GetOrganizationBySiret
from server.application.tags.queries import GetAllTags
from server.config.di import bootstrap, resolve
Expand Down Expand Up @@ -61,24 +62,27 @@ def _map_geographical_coverage(value: Optional[str], config: Config) -> str:


def _map_formats(
value: Optional[str], import_notes: TextIO, config: Config
value: Optional[str],
existing_format_names: List[str],
formats_to_create: List[str],
) -> List[str]:

if not value:
return []

def _map_format(value: str) -> List[str]:
format_names = []

dataformat = config.formats.map.get(value, value)
# Split and normalize tag names. For example:
# "périmètre délimité des abords (PDA), urbanisme; géolocalisation"
# -> {"périmètre délimité des abords (PDA)", "urbanisme", "géolocalisation"}
cleaned_names = set(name.strip() for name in value.replace(";", ",").split(","))

try:
return [dataformat]
except ValueError:
return [value]
for name in cleaned_names:
if name not in existing_format_names and name not in formats_to_create:
formats_to_create.append(name)

result = list(set(f for val in value.split(",") for f in _map_format(val.strip())))
format_names.append(name)

return result
return format_names


def _map_contact_emails(value: Optional[str]) -> List[str]:
Expand Down Expand Up @@ -128,14 +132,21 @@ def _map_tag_ids(
cleaned_names = set(name.strip() for name in value.replace(";", ",").split(","))

for name in cleaned_names:
try:
tag_id = existing_tag_ids_by_name[name]
except KeyError:
tag_id = id_factory()
tag = {"id": str(tag_id), "params": {"name": name}}
tags_to_create.append(tag)
if name in existing_tag_ids_by_name:
format_id = existing_tag_ids_by_name[name]
else:
res = next(
(sub for sub in tags_to_create if sub["params"]["name"] == name),
None,
)
if res is not None:
format_id = res["id"]
else:
format_id = id_factory()
tag = {"id": str(format_id), "params": {"name": name}}
tags_to_create.append(tag)

tag_ids.append(str(tag_id))
tag_ids.append(str(format_id))

return tag_ids

Expand Down Expand Up @@ -178,11 +189,18 @@ async def main(config_path: Path, out_path: Path) -> int:
tags = await bus.execute(GetAllTags())
existing_tag_ids_by_name = {tag.name: tag.id for tag in tags}

with config.input_csv.path.open(encoding=config.input_csv.encoding) as f:
formats = await bus.execute(GetAllDataFormat())
existing_format_name = [format.name for format in formats]

with config.input_csv.path.open(mode="r", encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter=config.input_csv.delimiter)
fieldnames = list(reader.fieldnames or [])
rows = list(reader)

if reader.fieldnames is None:
raise ValueError("No field name found: is your CSV file well formatted ?")

fieldnames = set(reader.fieldnames)

common_fields = {
"titre",
"description",
Expand All @@ -199,9 +217,10 @@ async def main(config_path: Path, out_path: Path) -> int:
"url",
"licence",
"mots_cles",
"publication_restriction",
}

actual_extra_fields = set(fieldnames) - common_fields - config.ignore_fields
actual_extra_fields = fieldnames - common_fields - config.ignore_fields

if actual_extra_fields != expected_extra_fields:
raise ValueError(
Expand All @@ -210,9 +229,11 @@ async def main(config_path: Path, out_path: Path) -> int:
)

tags_to_create: List[dict] = []
formats_to_create: List[str] = []
datasets = []

for k, row in enumerate(rows):

if (siret_orga := row["siret_orga"]) != organization.siret:
raise ValueError(
f"at row {k}: {siret_orga=!r} does not match {organization.siret=!r}"
Expand All @@ -228,13 +249,14 @@ async def main(config_path: Path, out_path: Path) -> int:
params: dict = {}

params["organization_siret"] = organization.siret
params["publication_restriction"] = row["publication_restriction"]
params["title"] = row["titre"]
params["description"] = row["description"]
params["service"] = row["service"] or None
params["geographical_coverage"] = _map_geographical_coverage(
row["couv_geo"] or None, config
)
params["formats"] = _map_formats(row["formats"] or None, import_notes, config)

params["technical_source"] = row["si"] or None
params["producer_email"] = row["contact_service"] or None
params["contact_emails"] = _map_contact_emails(row["contact_personne"] or None)
Expand All @@ -250,6 +272,10 @@ async def main(config_path: Path, out_path: Path) -> int:
row["mots_cles"] or None, existing_tag_ids_by_name, tags_to_create
)

params["formats"] = _map_formats(
row["formats"] or None, existing_format_name, formats_to_create
)

params["extra_field_values"] = _map_extra_field_values(
row, catalog.extra_fields
)
Expand All @@ -267,10 +293,10 @@ async def main(config_path: Path, out_path: Path) -> int:
users=[],
tags=tags_to_create,
datasets=datasets,
formats=[],
formats=formats_to_create,
).dict()

out_path.write_text(yaml.safe_dump(initdata))
out_path.write_text(yaml.safe_dump(initdata, allow_unicode=True), encoding="utf-8")

return 0

Expand Down

0 comments on commit ddeddac

Please sign in to comment.