Skip to content

Commit

Permalink
feat: add new Calibration core object (#1799)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Nov 28, 2024
1 parent 9a2368b commit b86d380
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Clipping and speaker/source alignment issues in speech separation pipeline have
- feat(task): add support for manual optimization in `SpeakerDiarization` task
- feat(utils): add `hidden` option to `ProgressHook`
- feat(utils): add `FilterByNumberOfSpeakers` protocol files filter
- feat(core): add `Calibration` class to calibrate logits/distances into probabilities

### Improvements

Expand Down
151 changes: 151 additions & 0 deletions pyannote/audio/core/calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# MIT License
#
# Copyright (c) 2024- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import os
from pathlib import Path
from typing import Optional, Text, Union

import numpy as np
import safetensors.numpy
import scipy.interpolate
from sklearn.isotonic import IsotonicRegression
from sklearn.utils.validation import NotFittedError, check_is_fitted

from pyannote.audio.utils.hf_hub import AssetFileName, download_from_hf_hub


class Calibration(IsotonicRegression):
"""Logit/distance calibration"""

def __init__(self):
super().__init__(y_min=0.0, y_max=1.0, increasing="auto", out_of_bounds="clip")

def save(self, path: str):
"""Save fitted calibration to disk
Parameters
----------
path : str
Path to the file where the calibration should be saved
Raises
------
NotFittedError
If the calibration is not fitted yet
"""
try:
check_is_fitted(self)
except NotFittedError:
raise NotFittedError("Cannot save an unfitted model.")

tensor_dict = {
"X_min_": self.X_min_,
"X_max_": self.X_max_,
"X_thresholds_": self.X_thresholds_,
"y_thresholds_": self.y_thresholds_,
"increasing_": self.increasing_,
}

safetensors.numpy.save_file(tensor_dict, path)

@classmethod
def from_file(cls, path: str) -> "Calibration":
"""Load calibration from disk
Parameters
----------
path : str
Path to the file where the calibration is saved
Returns
-------
calibration : Calibration
Fitted calibration
"""
calibration = cls()

tensor_dict = safetensors.numpy.load_file(path)
for key, value in tensor_dict.items():
setattr(calibration, key, value)

calibration.f_ = scipy.interpolate.interp1d(
np.hstack(
[
[np.min(calibration.X_thresholds_) - 1.0],
calibration.X_thresholds_,
[np.max(calibration.X_thresholds_) + 1.0],
]
),
np.hstack(
[
[1.0 - calibration.increasing_],
calibration.y_thresholds_,
[1.0 * calibration.increasing_],
]
),
kind="linear",
bounds_error=False,
)

return calibration

@classmethod
def from_pretrained(
cls,
checkpoint: str,
subfolder: Optional[str] = None,
token: Optional[Text] = None,
cache_dir: Optional[Union[str, Path]] = None,
**kwargs,
) -> Optional["Calibration"]:
"""Load calibration from disk or Huggingface Hub
Parameters
----------
checkpoint : Path or str
Path to checkpoint or a model identifier from the hf.co model hub.
subfolder : str, optional
Folder inside the hf.co model repo.
token : str, optional
When loading a private hf.co model, set `token`
to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional
Path to model cache directory.
"""
if os.path.isfile(checkpoint):
return cls.from_file(checkpoint)

path = download_from_hf_hub(
checkpoint,
AssetFileName.Calibration,
subfolder=subfolder,
cache_dir=cache_dir,
token=token,
)

if path is None:
return None

return cls.from_file(path)
93 changes: 93 additions & 0 deletions pyannote/audio/utils/hf_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# MIT License
#
# Copyright (c) 2024- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from enum import Enum
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError


# Correspondence between asset_file type
# and their filename on Huggingface
class AssetFileName(Enum):
Calibration = "calibration.safetensors"
Model = "pytorch_model.bin"
Pipeline = "config.yaml"


def download_from_hf_hub(
checkpoint: str,
asset_file: AssetFileName,
subfolder: Optional[str] = None,
cache_dir: Optional[Union[str, Path]] = None,
token: Union[bool, str, None] = None,
) -> Optional[str]:
"""Download file from Huggingface Hub
Parameters
----------
checkpoint : Path or str
Model identifier from the hf.co model hub.
asset_file : AssetFileName
Type of asset file to download.
subfolder : str, optional
Folder inside the hf.co model repo.
token : str, optional
When loading a private hf.co model, set `use_auth_token`
to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
"""

if "@" in checkpoint:
model_id, revision = checkpoint.split("@")
else:
model_id, revision = checkpoint, None

try:
return hf_hub_download(
model_id,
asset_file.value,
subfolder=subfolder,
repo_type="model",
revision=revision,
library_name="pyannote",
cache_dir=cache_dir,
token=token,
)
except HfHubHTTPError:
print(
f"""
Could not download {asset_file.name.lower()} from {model_id}.
It might be because the repository is private or gated:
* visit https://hf.co/{model_id} to accept user conditions
* visit https://hf.co/settings/tokens to create an authentication token
* load the {asset_file.name.lower()} with the `token` argument:
>>> {asset_file.name}.from_pretrained('{model_id}', use_auth_token=...)
"""
)
return
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pyannote.metrics >= 3.2
pyannote.pipeline >= 3.0.1
pytorch_metric_learning >= 2.1.0
rich >= 12.0.0
safetensors >= 0.4.5
semver >= 3.0.0
soundfile >= 0.12.1
speechbrain >= 1.0.0
Expand Down

0 comments on commit b86d380

Please sign in to comment.