From a4cd5501cd00399b520e42cb6bdaa487a8287e95 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Sat, 15 Jul 2023 20:45:32 +0100 Subject: [PATCH] Created load and save model run functions --- src/soundevent/data/processed_clip.py | 8 + src/soundevent/io/__init__.py | 9 + src/soundevent/io/annotation_projects.py | 123 +++--- src/soundevent/io/datasets.py | 105 +---- src/soundevent/io/formats/__init__.py | 39 ++ .../io/{format.py => formats/aoef.py} | 362 +++++++++++++++++- src/soundevent/io/model_runs.py | 151 ++++++++ src/soundevent/io/types.py | 42 ++ 8 files changed, 690 insertions(+), 149 deletions(-) create mode 100644 src/soundevent/io/formats/__init__.py rename src/soundevent/io/{format.py => formats/aoef.py} (66%) create mode 100644 src/soundevent/io/types.py diff --git a/src/soundevent/data/processed_clip.py b/src/soundevent/data/processed_clip.py index 6737ee4..0fe842b 100644 --- a/src/soundevent/data/processed_clip.py +++ b/src/soundevent/data/processed_clip.py @@ -35,6 +35,7 @@ into the predicted events and aid in subsequent analysis and interpretation. """ from typing import List +from uuid import UUID, uuid4 from pydantic import BaseModel, Field @@ -47,6 +48,9 @@ class ProcessedClip(BaseModel): """Processed clip.""" + uuid: UUID = Field(default_factory=uuid4, repr=False) + """Unique identifier for the processed clip.""" + clip: Clip """The clip that was processed.""" @@ -58,3 +62,7 @@ class ProcessedClip(BaseModel): features: List[Feature] = Field(default_factory=list) """List of features associated with the clip.""" + + def __hash__(self): + """Hash function for the processed clip.""" + return hash(self.uuid) diff --git a/src/soundevent/io/__init__.py b/src/soundevent/io/__init__.py index 4170a94..ba95aa8 100644 --- a/src/soundevent/io/__init__.py +++ b/src/soundevent/io/__init__.py @@ -4,9 +4,18 @@ sound event data. """ +from soundevent.io.annotation_projects import ( + load_annotation_project, + save_annotation_project, +) from soundevent.io.datasets import load_dataset, save_dataset +from soundevent.io.model_runs import load_model_run, save_model_run __all__ = [ + "load_annotation_project", "load_dataset", + "load_model_run", + "save_annotation_project", "save_dataset", + "save_model_run", ] diff --git a/src/soundevent/io/annotation_projects.py b/src/soundevent/io/annotation_projects.py index 48fc180..9404d2b 100644 --- a/src/soundevent/io/annotation_projects.py +++ b/src/soundevent/io/annotation_projects.py @@ -1,73 +1,63 @@ """Save and loading functions for annotation projects.""" - -import os -import sys from pathlib import Path -from typing import Callable, Dict, Union +from typing import Dict from soundevent import data -from soundevent.io.format import AnnotationProjectObject, is_json - -if sys.version_info < (3, 8): - from typing_extensions import Protocol -else: - from typing import Protocol - - -PathLike = Union[str, os.PathLike] - - -class Saver(Protocol): - """Protocol for saving annotation projects.""" +from soundevent.io.formats import aoef, infer_format +from soundevent.io.types import Loader, PathLike, Saver - def __call__( - self, - project: data.AnnotationProject, - path: PathLike, - audio_dir: PathLike = ".", - ) -> None: - """Save annotation project to path.""" - ... +SAVE_FORMATS: Dict[str, Saver[data.AnnotationProject]] = {} +LOAD_FORMATS: Dict[str, Loader[data.AnnotationProject]] = {} -class Loader(Protocol): - """Protocol for loading annotation projects.""" +def load_annotation_project( + path: PathLike, + audio_dir: PathLike = ".", +) -> data.AnnotationProject: + """Load annotation project from path. - def __call__( - self, path: PathLike, audio_dir: PathLike = "." - ) -> data.AnnotationProject: - """Load annotation project from path.""" - ... + Parameters + ---------- + path: PathLike + Path to the file with the annotation project. + audio_dir: PathLike, optional + Path to the directory containing the audio files, by default ".". The + audio file paths in the annotation project will be relative to this + directory. -SAVE_FORMATS: Dict[str, Saver] = {} -LOAD_FORMATS: Dict[str, Loader] = {} -FORMATS: Dict[str, Callable[[PathLike], bool]] = {} + Returns + ------- + annotation_project: data.AnnotationProject + The loaded annotation project. + Raises + ------ + FileNotFoundError + If the path does not exist. -def load_annotation_project(path: PathLike) -> data.AnnotationProject: - """Load annotation project from path.""" + NotImplementedError + If the format of the file is not supported. + """ path = Path(path) if not path.exists(): raise FileNotFoundError(f"Path {path} does not exist.") - for format_name, is_format in FORMATS.items(): - if not is_format(path): - continue - - return LOAD_FORMATS[format_name](path) + try: + format_ = infer_format(path) + except ValueError as e: + raise NotImplementedError(f"File {path} format not supported.") from e - raise NotImplementedError( - f"Could not find a loader for {path}. " - f"Supported formats are: {list(FORMATS.keys())}" - ) + loader = LOAD_FORMATS[format_] + return loader(path, audio_dir=audio_dir) def save_annotation_project( project: data.AnnotationProject, path: PathLike, audio_dir: PathLike = ".", + format: str = "aoef", ) -> None: """Save annotation project to path. @@ -75,26 +65,36 @@ def save_annotation_project( ---------- project: data.AnnotationProject Annotation project to save. + path: PathLike Path to save annotation project to. + + audio_dir: PathLike, optional + Path to the directory containing the audio files, by default ".". The + audio file paths in the annotation project will be relative to this + directory. + + format: str, optional + Format to save the annotation project in, by default "aoef". + + Raises + ------ + NotImplementedError + If the format is not supported. + """ path = Path(path) - for format_name, is_format in FORMATS.items(): - if not is_format(path): - continue - - SAVE_FORMATS[format_name](project, path, audio_dir=audio_dir) - return + try: + saver = SAVE_FORMATS[format] + except KeyError as e: + raise NotImplementedError(f"Format {format} not supported.") from e - raise NotImplementedError( - f"Could not find a saver for {path}. " - f"Supported formats are: {list(FORMATS.keys())}" - ) + saver(project, path, audio_dir=audio_dir) def save_annotation_project_in_aoef_format( - project: data.AnnotationProject, + obj: data.AnnotationProject, path: PathLike, audio_dir: PathLike = ".", ) -> None: @@ -102,8 +102,8 @@ def save_annotation_project_in_aoef_format( path = Path(path) audio_dir = Path(audio_dir).resolve() annotation_project_object = ( - AnnotationProjectObject.from_annotation_project( - project, + aoef.AnnotationProjectObject.from_annotation_project( + obj, audio_dir=audio_dir, ) ) @@ -122,12 +122,11 @@ def load_annotation_project_in_aoef_format( """Load annotation project from path in AOEF format.""" path = Path(path) audio_dir = Path(audio_dir).resolve() - annotation_project_object = AnnotationProjectObject.model_validate_json( - path.read_text() + annotation_project_object = ( + aoef.AnnotationProjectObject.model_validate_json(path.read_text()) ) return annotation_project_object.to_annotation_project(audio_dir=audio_dir) SAVE_FORMATS["aoef"] = save_annotation_project_in_aoef_format LOAD_FORMATS["aoef"] = load_annotation_project_in_aoef_format -FORMATS["aoef"] = is_json diff --git a/src/soundevent/io/datasets.py b/src/soundevent/io/datasets.py index d862976..fe85ace 100644 --- a/src/soundevent/io/datasets.py +++ b/src/soundevent/io/datasets.py @@ -4,89 +4,21 @@ Datasets of recordings. """ -import os -import sys from pathlib import Path -from typing import Callable, Dict, Union +from typing import Dict from soundevent import data -from soundevent.io.format import DatasetObject, is_json - -if sys.version_info < (3, 6): - from typing_extensions import Literal -else: - from typing import Literal - -if sys.version_info < (3, 8): - from typing_extensions import Protocol -else: - from typing import Protocol +from soundevent.io.formats import aoef, infer_format +from soundevent.io.types import Loader, PathLike, Saver __all__ = [ "load_dataset", "save_dataset", ] -PathLike = Union[str, os.PathLike] - - -DatasetFormat = Literal["csv", "json"] - - -class Saver(Protocol): - def __call__( - self, - dataset: data.Dataset, - path: PathLike, - audio_dir: PathLike = ".", - ) -> None: - ... - - -class Loader(Protocol): - def __call__( - self, path: PathLike, audio_dir: PathLike = "." - ) -> data.Dataset: - ... - -Inferrer = Callable[[PathLike], bool] - -SAVE_FORMATS: Dict[DatasetFormat, Saver] = {} - -LOAD_FORMATS: Dict[DatasetFormat, Loader] = {} - -INFER_FORMATS: Dict[DatasetFormat, Inferrer] = { - "json": is_json, -} - - -def infer_format(path: PathLike) -> DatasetFormat: - """Infer the format of a file. - - Parameters - ---------- - path : Path - Path to the file to infer the format of. - - Returns - ------- - format : DatasetFormat - The inferred format of the file. - - Raises - ------ - ValueError - If the format of the file cannot be inferred. - - """ - for format_, inferrer in INFER_FORMATS.items(): - if inferrer(path): - return format_ - - raise ValueError( - f"Cannot infer format of file {path}, or format not supported." - ) +SAVE_FORMATS: Dict[str, Saver[data.Dataset]] = {} +LOAD_FORMATS: Dict[str, Loader[data.Dataset]] = {} def load_dataset(path: PathLike, audio_dir: PathLike = ".") -> data.Dataset: @@ -112,6 +44,11 @@ def load_dataset(path: PathLike, audio_dir: PathLike = ".") -> data.Dataset: If the format of the file is not supported. """ + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Path {path} does not exist.") + try: format_ = infer_format(path) except ValueError as e: @@ -125,7 +62,7 @@ def save_dataset( dataset: data.Dataset, path: PathLike, audio_dir: PathLike = ".", - format: DatasetFormat = "json", + format: str = "aoef", ) -> None: """Save a Dataset to a file. @@ -138,7 +75,7 @@ def save_dataset( audio_dir : Path, optional Path to the directory containing the audio files, by default ".". format : DatasetFormat, optional - The format to save the dataset in, by default "json". + The format to save the dataset in, by default "aoef". Raises ------ @@ -154,7 +91,7 @@ def save_dataset( saver(dataset, path, audio_dir=audio_dir) -def load_dataset_json_format( +def load_dataset_aoef_format( path: PathLike, audio_dir: PathLike = ".", ) -> data.Dataset: @@ -177,16 +114,13 @@ def load_dataset_json_format( audio_dir = Path(audio_dir).resolve() with open(path, "r") as f: - dataset = DatasetObject.model_validate_json(f.read()) + dataset = aoef.DatasetObject.model_validate_json(f.read()) return dataset.to_dataset(audio_dir=audio_dir) -LOAD_FORMATS["json"] = load_dataset_json_format - - -def save_dataset_json_format( - dataset: data.Dataset, +def save_dataset_aoef_format( + obj: data.Dataset, path: PathLike, audio_dir: PathLike = ".", ) -> None: @@ -205,8 +139,8 @@ def save_dataset_json_format( """ audio_dir = Path(audio_dir).resolve() - dataset_object = DatasetObject.from_dataset( - dataset, + dataset_object = aoef.DatasetObject.from_dataset( + obj, audio_dir=audio_dir, ) @@ -214,4 +148,5 @@ def save_dataset_json_format( f.write(dataset_object.model_dump_json(indent=None, exclude_none=True)) -SAVE_FORMATS["json"] = save_dataset_json_format +SAVE_FORMATS["aoef"] = save_dataset_aoef_format +LOAD_FORMATS["aoef"] = load_dataset_aoef_format diff --git a/src/soundevent/io/formats/__init__.py b/src/soundevent/io/formats/__init__.py new file mode 100644 index 0000000..6b64375 --- /dev/null +++ b/src/soundevent/io/formats/__init__.py @@ -0,0 +1,39 @@ +"""Storage formats for soundevent objects.""" +from soundevent.io.formats.aoef import is_json +from soundevent.io.types import PathLike + +__all__ = [ + "infer_format", +] + +FORMATS = { + "aoef": is_json, +} + + +def infer_format(path: PathLike) -> str: + """Infer the format of a file. + + Parameters + ---------- + path : Path + Path to the file to infer the format of. + + Returns + ------- + format : str + The inferred format of the file. + + Raises + ------ + ValueError + If the format of the file cannot be inferred. + + """ + for format_, inferrer in FORMATS.items(): + if inferrer(path): + return format_ + + raise ValueError( + f"Cannot infer format of file {path}, or format not supported." + ) diff --git a/src/soundevent/io/format.py b/src/soundevent/io/formats/aoef.py similarity index 66% rename from src/soundevent/io/format.py rename to src/soundevent/io/formats/aoef.py index 1422245..164a116 100644 --- a/src/soundevent/io/format.py +++ b/src/soundevent/io/formats/aoef.py @@ -28,10 +28,10 @@ import os import sys from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from uuid import UUID, uuid4 -from pydantic import BaseModel +from pydantic import BaseModel, Field from soundevent import data @@ -795,6 +795,364 @@ def to_annotation_project( ) +class PredictedSoundEventObject(BaseModel): + id: int + + uuid: UUID + + sound_event: int + + score: float = Field(..., ge=0.0, le=1.0) + + tags: Optional[List[Tuple[int, float]]] = None + + features: Optional[Dict[str, float]] = None + + @classmethod + def from_predicted_sound_event( + cls, + predicted_sound_event: data.PredictedSoundEvent, + predicted_sound_events: Dict[data.PredictedSoundEvent, Self], + recordings: Dict[data.Recording, RecordingObject], + sound_events: Dict[data.SoundEvent, SoundEventObject], + tags: Dict[data.Tag, TagObject], + audio_dir: Path = Path("."), + ) -> Self: + """Convert a predicted sound event to a predicted sound event object.""" + if predicted_sound_event in predicted_sound_events: + return predicted_sound_events[predicted_sound_event] + + predicted_tags = [ + (TagObject.from_tag(tag.tag, tags).id, tag.score) + for tag in predicted_sound_event.tags + ] + + return cls( + id=len(predicted_sound_events), + uuid=predicted_sound_event.id, + sound_event=SoundEventObject.from_sound_event( + predicted_sound_event.sound_event, + sound_events=sound_events, + recordings=recordings, + tags=tags, + audio_dir=audio_dir, + ).id, + score=predicted_sound_event.score, + tags=predicted_tags if predicted_tags else None, + features={ + feature.name: feature.value + for feature in predicted_sound_event.features + } + if predicted_sound_event.features + else None, + ) + + def to_predicted_sound_event( + self, + recordings: Optional[Dict[int, data.Recording]] = None, + sound_events: Optional[Dict[int, data.SoundEvent]] = None, + tags: Optional[Dict[int, data.Tag]] = None, + ) -> data.PredictedSoundEvent: + """Convert a predicted sound event object to a predicted sound event.""" + if recordings is None: + recordings = {} + + if sound_events is None: + sound_events = {} + + if tags is None: + tags = {} + + if self.sound_event not in sound_events: + raise ValueError( + f"Sound event with ID {self.sound_event} not found " + "in sound events." + ) + + return data.PredictedSoundEvent( + id=self.uuid, + sound_event=sound_events[self.sound_event], + score=self.score, + tags=[ + data.PredictedTag( + tag=tags[tag[0]], + score=tag[1], + ) + for tag in self.tags or [] + if tag[0] in tags + ], + features=[ + data.Feature( + name=feature[0], + value=feature[1], + ) + for feature in (self.features or {}).items() + ], + ) + + +class ProcessedClipObject(BaseModel): + id: int + + uuid: UUID + + clip: int + + sound_events: Optional[List[int]] = None + + tags: Optional[List[Tuple[int, float]]] = None + + features: Optional[Dict[str, float]] = None + + @classmethod + def from_processed_clip( + cls, + processed_clip: data.ProcessedClip, + processed_clips: Dict[data.ProcessedClip, Self], + recordings: Dict[data.Recording, RecordingObject], + clips: Dict[data.Clip, ClipObject], + tags: Dict[data.Tag, TagObject], + sound_events: Dict[data.SoundEvent, SoundEventObject], + predicted_sound_events: Dict[ + data.PredictedSoundEvent, PredictedSoundEventObject + ], + audio_dir: Path = Path("."), + ) -> Self: + """Convert a processed clip to a processed clip object.""" + if processed_clip in processed_clips: + return processed_clips[processed_clip] + + predicted_sound_events_ids = [ + PredictedSoundEventObject.from_predicted_sound_event( + sound_event, + predicted_sound_events=predicted_sound_events, + recordings=recordings, + sound_events=sound_events, + tags=tags, + audio_dir=audio_dir, + ).id + for sound_event in processed_clip.sound_events + ] + + predicted_tags = [ + (TagObject.from_tag(tag.tag, tags).id, tag.score) + for tag in processed_clip.tags + ] + + return cls( + id=len(processed_clips), + uuid=processed_clip.uuid, + clip=ClipObject.from_clip( + processed_clip.clip, + clips=clips, + recordings=recordings, + tags=tags, + audio_dir=audio_dir, + ).id, + sound_events=predicted_sound_events_ids + if predicted_sound_events_ids + else None, + tags=predicted_tags if predicted_tags else None, + features={ + feature.name: feature.value + for feature in processed_clip.features + } + if processed_clip.features + else None, + ) + + def to_processed_clip( + self, + clips: Optional[Dict[int, data.Clip]] = None, + sound_events: Optional[Dict[int, data.PredictedSoundEvent]] = None, + tags: Optional[Dict[int, data.Tag]] = None, + ) -> data.ProcessedClip: + """Convert a processed clip object to a processed clip.""" + if clips is None: + clips = {} + + if sound_events is None: + sound_events = {} + + if tags is None: + tags = {} + + if self.clip not in clips: + raise ValueError(f"Clip with ID {self.clip} not found in clips.") + + return data.ProcessedClip( + uuid=self.uuid, + clip=clips[self.clip], + sound_events=[ + sound_events[sound_event] + for sound_event in (self.sound_events or []) + if sound_event in sound_events + ], + tags=[ + data.PredictedTag( + tag=tags[tag[0]], + score=tag[1], + ) + for tag in self.tags or [] + if tag[0] in tags + ], + features=[ + data.Feature( + name=feature[0], + value=feature[1], + ) + for feature in (self.features or {}).items() + ], + ) + + +class ModelRunInfo(BaseModel): + uuid: UUID + + model: str + + description: Optional[str] = None + + date_created: datetime.datetime + + @classmethod + def from_model_run( + cls, + model_run: data.ModelRun, + date_created: Optional[datetime.datetime] = None, + ) -> Self: + """Convert a model run to a model run object.""" + if date_created is None: + date_created = datetime.datetime.now() + return cls( + uuid=model_run.id, + model=model_run.model, + date_created=date_created, + ) + + +class ModelRunObject(BaseModel): + info: ModelRunInfo + + tags: Optional[List[TagObject]] = None + + recordings: Optional[List[RecordingObject]] = None + + clips: Optional[List[ClipObject]] = None + + processed_clips: Optional[List[ProcessedClipObject]] = None + + sound_events: Optional[List[SoundEventObject]] = None + + predicted_sound_events: Optional[List[PredictedSoundEventObject]] = None + + @classmethod + def from_model_run( + cls, + model_run: data.ModelRun, + audio_dir: Path = Path("."), + date_created: Optional[datetime.datetime] = None, + ) -> Self: + """Convert a model run to a model run object.""" + if date_created is None: + date_created = datetime.datetime.now() + + processed_clips: Dict[data.ProcessedClip, ProcessedClipObject] = {} + recordings: Dict[data.Recording, RecordingObject] = {} + clips: Dict[data.Clip, ClipObject] = {} + tags: Dict[data.Tag, TagObject] = {} + sound_events: Dict[data.SoundEvent, SoundEventObject] = {} + predicted_sound_events: Dict[ + data.PredictedSoundEvent, PredictedSoundEventObject + ] = {} + + processed_clips_list = [ + ProcessedClipObject.from_processed_clip( + processed_clip, + processed_clips=processed_clips, + recordings=recordings, + clips=clips, + tags=tags, + sound_events=sound_events, + predicted_sound_events=predicted_sound_events, + audio_dir=audio_dir, + ) + for processed_clip in model_run.clips + ] + + return cls( + info=ModelRunInfo.from_model_run( + model_run, + date_created=date_created, + ), + clips=list(clips.values()) if clips else None, + tags=list(tags.values()) if tags else None, + recordings=list(recordings.values()) if recordings else None, + sound_events=list(sound_events.values()) if sound_events else None, + predicted_sound_events=list(predicted_sound_events.values()) + if predicted_sound_events + else None, + processed_clips=processed_clips_list, + ) + + def to_model_run( + self, + audio_dir: Path = Path("."), + ) -> data.ModelRun: + """Convert a model run object to a model run.""" + tags: Dict[int, data.Tag] = {} + recordings: Dict[int, data.Recording] = {} + clips: Dict[int, data.Clip] = {} + sound_events: Dict[int, data.SoundEvent] = {} + predicted_sound_events: Dict[int, data.PredictedSoundEvent] = {} + processed_clips: Dict[int, data.ProcessedClip] = {} + + for tag in self.tags or []: + tags[tag.id] = tag.to_tag() + + for recording in self.recordings or []: + recordings[recording.id] = recording.to_recording( + tags=tags, audio_dir=audio_dir + ) + + for clip in self.clips or []: + clips[clip.id] = clip.to_clip( + tags=tags, + recordings=recordings, + ) + + for sound_event in self.sound_events or []: + sound_events[sound_event.id] = sound_event.to_sound_event( + recordings=recordings, + tags=tags, + ) + + for predicted_sound_event in self.predicted_sound_events or []: + predicted_sound_events[ + predicted_sound_event.id + ] = predicted_sound_event.to_predicted_sound_event( + recordings=recordings, + sound_events=sound_events, + tags=tags, + ) + + for processed_clip in self.processed_clips or []: + processed_clips[ + processed_clip.id + ] = processed_clip.to_processed_clip( + clips=clips, + sound_events=predicted_sound_events, + tags=tags, + ) + + return data.ModelRun( + id=self.info.uuid, + model=self.info.model, + created_on=self.info.date_created, + clips=list(processed_clips.values()), + ) + + def is_json(path: Union[str, os.PathLike]) -> bool: """Check if a file is a JSON file.""" path = Path(path) diff --git a/src/soundevent/io/model_runs.py b/src/soundevent/io/model_runs.py index e69de29..5392516 100644 --- a/src/soundevent/io/model_runs.py +++ b/src/soundevent/io/model_runs.py @@ -0,0 +1,151 @@ +"""Model Runs IO module of the soundevent package. + +Here you can find the classes and functions for reading and writing model runs. +""" + +from pathlib import Path +from typing import Dict + +from soundevent import data +from soundevent.io.formats import aoef, infer_format +from soundevent.io.types import Loader, PathLike, Saver + +__all__ = [ + "load_model_run", + "save_model_run", +] + + +SAVE_FORMATS: Dict[str, Saver[data.ModelRun]] = {} +LOAD_FORMATS: Dict[str, Loader[data.ModelRun]] = {} + + +def load_model_run(path: PathLike, audio_dir: PathLike = ".") -> data.ModelRun: + """Load a ModelRun from a file. + + Parameters + ---------- + path : PathLike + Path to the file to load. + + audio_dir : PathLike, optional + Path to the directory containing the audio files, by default ".". + The audio file paths in the dataset will be relative to this directory. + + Returns + ------- + model_run : ModelRun + The loaded model run. + + Raises + ------ + NotImplementedError + If the format of the file is not supported. + + """ + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Path {path} does not exist.") + + try: + format_ = infer_format(path) + except ValueError as e: + raise NotImplementedError(f"File {path} format not supported.") from e + + loader = LOAD_FORMATS[format_] + return loader(path, audio_dir=audio_dir) + + +def save_model_run( + model_run: data.ModelRun, + path: PathLike, + audio_dir: PathLike = ".", + format: str = "aoef", +) -> None: + """Save a ModelRun to a file. + + Parameters + ---------- + model_run : ModelRun + The model run to save. + path : Path + Path to the file to save the dataset to. + audio_dir : Path, optional + Path to the directory containing the audio files, by default ".". + format : DatasetFormat, optional + The format to save the dataset in, by default "aoef". + + Raises + ------ + NotImplementedError + If the format of the file is not supported. + + """ + try: + saver = SAVE_FORMATS[format] + except KeyError as e: + raise NotImplementedError(f"Format {format} not supported.") from e + + saver(model_run, path, audio_dir=audio_dir) + + +def load_model_run_aoef_format( + path: PathLike, + audio_dir: PathLike = ".", +) -> data.ModelRun: + """Load a ModelRun from a JSON file in AOEF format. + + Parameters + ---------- + path : Path + Path to the file to load. + + audio_dir : Path, optional + Path to the directory containing the audio files, by default ".". + The audio file paths in the dataset will be relative to this directory. + + Returns + ------- + model_run : ModelRun + The loaded model run. + """ + audio_dir = Path(audio_dir).resolve() + + with open(path, "r") as f: + dataset = aoef.ModelRunObject.model_validate_json(f.read()) + + return dataset.to_model_run(audio_dir=audio_dir) + + +def save_model_run_aoef_format( + obj: data.ModelRun, + path: PathLike, + audio_dir: PathLike = ".", +) -> None: + """Save a ModelRun to a JSON file in AOEF format. + + Parameters + ---------- + obj : ModelRun + The model run to save. + path : PathLike + Path to the file to save the dataset to. + audio_dir : PathLike, optional + Path to the directory containing the audio files, by default ".". + The audio file paths in the dataset will be relative to this directory. + + """ + audio_dir = Path(audio_dir).resolve() + + dataset_object = aoef.ModelRunObject.from_model_run( + obj, + audio_dir=audio_dir, + ) + + with open(path, "w") as f: + f.write(dataset_object.model_dump_json(indent=None, exclude_none=True)) + + +SAVE_FORMATS["aoef"] = save_model_run_aoef_format +LOAD_FORMATS["aoef"] = load_model_run_aoef_format diff --git a/src/soundevent/io/types.py b/src/soundevent/io/types.py new file mode 100644 index 0000000..7b07c6a --- /dev/null +++ b/src/soundevent/io/types.py @@ -0,0 +1,42 @@ +"""Submodule of io module containing type definitions.""" +import os +import sys +from typing import Generic, TypeVar, Union + +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + +__all__ = [ + "PathLike", + "Saver", + "Loader", +] + +PathLike = Union[str, os.PathLike] + + +D = TypeVar("D", contravariant=True) +T = TypeVar("T", covariant=True) + + +class Saver(Protocol, Generic[D]): + """Protocol for saving functions.""" + + def __call__( + self, + obj: D, + path: PathLike, + audio_dir: PathLike = ".", + ) -> None: + """Save object to path.""" + ... + + +class Loader(Protocol, Generic[T]): + """Protocol for loading functions.""" + + def __call__(self, path: PathLike, audio_dir: PathLike = ".") -> T: + """Load object from path.""" + ...