Skip to content

Commit

Permalink
feat: added write data to bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
mbsantiago committed Dec 19, 2023
1 parent e5fee3f commit 3926cab
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 9 deletions.
21 changes: 20 additions & 1 deletion src/soundevent/audio/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Currently only supports reading and writing of .wav files.
"""

from io import BytesIO
from pathlib import Path
from typing import Dict, Optional, Tuple

Expand Down Expand Up @@ -213,3 +213,22 @@ def load_clip(
"samplerate": recording.samplerate,
},
)


def audio_to_bytes(
data: np.ndarray,
samplerate: int,
bit_depth: int = 16,
) -> bytes:
"""Convert audio data to bytes."""
buffer = BytesIO()
with sf.SoundFile(
buffer,
mode="w",
samplerate=samplerate,
channels=data.shape[1],
format="RAW",
subtype=f"PCM_{bit_depth}",
) as fp:
fp.write(data)
return buffer.getvalue()
54 changes: 54 additions & 0 deletions src/soundevent/audio/media_info.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Functions for getting media information from WAV files."""
import hashlib
import struct
from dataclasses import dataclass
from typing import IO

Expand Down Expand Up @@ -192,3 +193,56 @@ def compute_md5_checksum(path: PathLike) -> str:
md5.update(buffer)
buffer = fp.read(BUFFER_SIZE)
return md5.hexdigest()


def generate_wav_header(
samplerate: int,
channels: int,
samples: int,
bit_depth: int = 16,
) -> bytes:
"""Generate the data of a WAV header.
This function generates the data of a WAV header according to the
given parameters. The WAV header is a 44-byte string that contains
information about the audio data, such as the sample rate, the
number of channels, and the number of samples. The WAV header
assumes that the audio data is PCM encoded.
Parameters
----------
samplerate
Sample rate in Hz.
channels
Number of channels.
samples
Number of samples.
bit_depth
The number of bits per sample. By default, it is 16 bits.
Notes
-----
The structure of the WAV header is described in
(WAV PCM soundfile format)[http://soundfile.sapp.org/doc/WaveFormat/].
"""

data_size = samples * channels * bit_depth // 8
byte_rate = samplerate * channels * bit_depth // 8
block_align = channels * bit_depth // 8

return struct.pack(
"<4si4s4sihhiihh4si", # Format string
b"RIFF", # RIFF chunk id
data_size + 36, # Size of the entire file minus 8 bytes
b"WAVE", # RIFF chunk id
b"fmt ", # fmt chunk id
16, # Size of the fmt chunk
1, # Audio format (1 corresponds to PCM)
channels, # Number of channels
samplerate, # Sample rate in Hz
byte_rate, # Byte rate
block_align, # Block align
bit_depth, # Number of bits per sample
b"data", # data chunk id
data_size, # Size of the data chunk
)
16 changes: 8 additions & 8 deletions src/soundevent/io/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["recording_set"] = "recording_set", # type: ignore
) -> data.RecordingSet: # type: ignore
...
... # pragma: no cover


@overload
Expand All @@ -33,7 +33,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["dataset"] = "dataset", # type: ignore
) -> data.Dataset: # type: ignore
...
... # pragma: no cover


@overload
Expand All @@ -43,7 +43,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["annotation_set"] = "annotation_set", # type: ignore
) -> data.AnnotationSet: # type: ignore
...
... # pragma: no cover


@overload
Expand All @@ -53,7 +53,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["annotation_project"] = "annotation_project", # type: ignore
) -> data.AnnotationProject: # type: ignore
...
... # pragma: no cover


@overload
Expand All @@ -63,7 +63,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["prediction_set"] = "prediction_set", # type: ignore
) -> data.PredictionSet: # type: ignore
...
... # pragma: no cover


@overload
Expand All @@ -73,7 +73,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["model_run"] = "model_run", # type: ignore
) -> data.ModelRun: # type: ignore
...
... # pragma: no cover


@overload
Expand All @@ -83,7 +83,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["evaluation_set"] = "evaluation_set", # type: ignore
) -> data.EvaluationSet: # type: ignore
...
... # pragma: no cover


@overload
Expand All @@ -93,7 +93,7 @@ def load(
format: Optional[str] = "aoef",
type: Literal["evaluation"] = "evaluation", # type: ignore
) -> data.Evaluation: # type: ignore
...
... # pragma: no cover


def load(
Expand Down
12 changes: 12 additions & 0 deletions tests/test_audio/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Common fixtures for audio tests."""
from pathlib import Path

import pytest

SAMPLE_AUDIO = Path(__file__).parent / "24bitdepth.wav"


@pytest.fixture
def sample_24_bit_audio() -> Path:
"""Return a Path object to a 24-bit WAV file."""
return SAMPLE_AUDIO
10 changes: 10 additions & 0 deletions tests/test_audio/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pathlib import Path

from soundevent.audio.io import audio_to_bytes, load_audio


def test_audio_to_bytes(sample_24_bit_audio: Path):
original_bytes = sample_24_bit_audio.read_bytes()[44:] # ignore header
audio, sr = load_audio(sample_24_bit_audio)
result_bytes = audio_to_bytes(audio, sr, bit_depth=24)
assert len(result_bytes) == len(original_bytes)
35 changes: 35 additions & 0 deletions tests/test_audio/test_media_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Test suite for soundevent.audio.media_info module."""
from pathlib import Path

from soundevent.audio.media_info import (
MediaInfo,
generate_wav_header,
get_media_info,
)


def test_can_read_media_info(sample_24_bit_audio: Path):
media_info = get_media_info(sample_24_bit_audio)
assert isinstance(media_info, MediaInfo)
assert media_info.duration_s == 19.4015625
assert media_info.bit_depth == 24
assert media_info.samplerate_hz == 96000
assert media_info.channels == 1
assert media_info.samples == 1862550
assert media_info.audio_format == 1


def test_can_generate_wav_header():
"""Example WAV header from http://soundfile.sapp.org/doc/WaveFormat/."""
header = bytes.fromhex(
"52 49 46 46 24 08 00 00 57 41 56 45 66 6d 74 20 10 00 "
"00 00 01 00 02 00 22 56 00 00 88 58 01 00 04 00 10 00 "
"64 61 74 61 00 08 00 00 00 00 00 00 24 17 1e f3 3c 13 "
"3c 14 16 f9 18 f9 34 e7 23 a6 3c f2 24 f2 11 ce 1a 0d"
)
channels = 2
samplerate = 22050
bit_depth = 16
samples = 2048 // (channels * bit_depth // 8)
generated = generate_wav_header(samplerate, channels, samples, bit_depth)
assert header[:44] == generated

0 comments on commit 3926cab

Please sign in to comment.