-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* commit changes * commit changes * commit changes * commit changes * commit changes * Update workflows.py * Updates * Updates * Updates * Changes * Pinning requirements * Pinning dev requirements * Updated datasource * Added data cleaning and preprocessing logic * Code linting * Code linting * Added black to requirements * Updated gitignore file * Updated preprocessing function * Added training script * Created flyte workflow * workflowoutput Co-authored-by: Ali Abbas Jaffri <[email protected]>
- Loading branch information
1 parent
faaf72a
commit 67ff9a7
Showing
24 changed files
with
975 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,3 +143,6 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
*/gtzan/* | ||
projects/bravemusic/bravemusic/gtzan/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
!.flyte |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
admin: | ||
# For GRPC endpoints you might want to use dns:///flyte.myexample.com | ||
endpoint: dns:///playground.hosted.unionai.cloud | ||
authType: Pkce | ||
# Change insecure flag to ensure that you use the right setting for your environment | ||
insecure: false | ||
storage: | ||
type: stow | ||
stow: | ||
kind: s3 | ||
config: | ||
auth_type: iam | ||
region: us-east-2 | ||
logger: | ||
# Logger settings to control logger output. Useful to debug logger: | ||
show-source: true | ||
level: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[sdk] | ||
# This option specifies the python packages where-in to search for workflows and tasks workflow packages. These workflows and tasks are then serialized during the `make serialize` commands | ||
workflow_packages=bravemusic | ||
|
||
[auth] | ||
raw_output_data_prefix=s3://open-compute-playground |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
admin: | ||
# For GRPC endpoints you might want to use dns:///flyte.myexample.com | ||
endpoint: dns:///localhost:30081 | ||
authType: Pkce | ||
insecure: true | ||
logger: | ||
show-source: true | ||
level: 0 | ||
storage: | ||
connection: | ||
access-key: minio | ||
auth-type: accesskey | ||
disable-ssl: true | ||
endpoint: http://localhost:30084 | ||
region: us-east-1 | ||
secret-key: miniostorage | ||
type: minio | ||
container: "my-s3-bucket" | ||
enable-multicontainer: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[sdk] | ||
# This option specifies the python packages where-in to search for workflows and tasks workflow packages. These workflows and tasks are then serialized during the `make serialize` commands | ||
workflow_packages=bravemusic | ||
|
||
[auth] | ||
raw_output_data_prefix=s3://my-s3-bucket/flytelab |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
tasks: | ||
- init: | | ||
python -m venv ~/venvs/brave | ||
source ~/venvs/brave/bin/activate | ||
pip install -r requirements.txt -r requirements-dev.txt | ||
command: python3 projects/bravemusic/bravemusic/workflows.py | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
FROM python:3.9-slim-buster | ||
|
||
WORKDIR /root | ||
ENV VENV /opt/venv | ||
ENV LANG C.UTF-8 | ||
ENV LC_ALL C.UTF-8 | ||
ENV PYTHONPATH /root | ||
|
||
# e.g. flyte.config or sandbox.config | ||
ARG config | ||
|
||
RUN apt-get update && \ | ||
apt-get install -y \ | ||
libsm6 \ | ||
libxext6 \ | ||
libxrender-dev \ | ||
ffmpeg \ | ||
build-essential | ||
|
||
# Install the AWS cli separately to prevent issues with boto being written over | ||
RUN pip3 install awscli | ||
|
||
ENV VENV /opt/venv | ||
|
||
# Virtual environment | ||
RUN python3 -m venv ${VENV} | ||
ENV PATH="${VENV}/bin:$PATH" | ||
|
||
# Install Python dependencies | ||
COPY requirements.txt /root | ||
RUN pip install -r /root/requirements.txt | ||
|
||
COPY bravemusic /root/bravemusic | ||
COPY $config /root/flyte.config | ||
|
||
# This image is supplied by the build script and will be used to determine the version | ||
# when registering tasks, workflows, and launch plans | ||
ARG image | ||
ENV FLYTE_INTERNAL_IMAGE $image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Design Doc: Brave-Hyenas-2 | ||
## MLOps Community: Engineering labs | ||
|
||
| Team name | brave-hyenas-2 | | ||
|---------------------|:----------------------------------------:| | ||
|Project name | brave-hyenas-2 | | ||
| Project description | Hackathon - brave-hyenas-2 team | | ||
|Using GPUs? (Yes/No) | No | | ||
|
||
|
||
|
||
### Problem Statement | ||
What problem are you solving? | ||
It’s usually hard to identify correctly what kind of music genre is playing thus our team embraced in tackling to classify music genre using deep learning. | ||
|
||
|
||
### ...... | ||
|
||
|
||
|
||
|
||
|
||
![new](https://user-images.githubusercontent.com/85021780/161294904-a4158856-0558-424f-9f07-85aef8f4b423.jpg) | ||
|
||
|
||
|
||
|
||
### Solution (working progress) .... |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import os | ||
import tarfile | ||
import git | ||
|
||
GIT_URL = "https://huggingface.co/datasets/marsyas/gtzan" | ||
GTZAN_PATH = "./gtzan" | ||
GTZAN_ZIP_FILE_PATH = "./gtzan/data" | ||
GTZAN_ZIP_FILE_NAME = "genres.tar.gz" | ||
|
||
|
||
class Progress(git.remote.RemoteProgress): | ||
def update(self, op_code, cur_count, max_count=None, message=""): | ||
print(self._cur_line) | ||
|
||
|
||
def download_gtzan_repo(): | ||
if not os.path.isdir(GTZAN_PATH) or not any(os.scandir(GTZAN_PATH)): | ||
git.Repo.clone_from(url=GIT_URL, to_path=GTZAN_PATH, progress=Progress()) | ||
extract_gtzan_repo_tarball() | ||
else: | ||
print("dataset already exists") | ||
|
||
|
||
def extract_gtzan_repo_tarball(): | ||
# open file | ||
file = tarfile.open(f"{GTZAN_ZIP_FILE_PATH}/{GTZAN_ZIP_FILE_NAME}") | ||
# extracting file | ||
file.extractall(GTZAN_ZIP_FILE_PATH) | ||
file.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
download_gtzan_repo() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import os | ||
import math | ||
import librosa | ||
from datasource import download_gtzan_repo, GTZAN_ZIP_FILE_PATH | ||
|
||
SAMPLE_RATE = 22050 | ||
TRACK_DURATION = 30 # measured in seconds | ||
SAMPLES_PER_TRACK = SAMPLE_RATE * TRACK_DURATION | ||
BAD_FORMATS = ["jazz.00054.wav"] | ||
|
||
|
||
def clean_dataset(): | ||
for (dir_path, dir_names, filenames) in os.walk(f"{GTZAN_ZIP_FILE_PATH}/genres/"): | ||
print(dir_path) | ||
[ | ||
os.remove(f"{dir_path}{filename}") | ||
for filename in filenames | ||
if not filename.endswith(".wav") | ||
] | ||
[ | ||
os.renames( | ||
old=f"{dir_path}/{filename}", | ||
new=f"{dir_path}/{filename}".replace("._", ""), | ||
) | ||
for filename in filenames | ||
if f"{dir_path}/{filename}".startswith("._") | ||
] | ||
[ | ||
os.remove(f"{dir_path}/{filename}") | ||
for filename in filenames | ||
if filename.startswith("._") | ||
] | ||
|
||
|
||
def preprocess( | ||
dataset_path: str, | ||
num_mfcc: int = 13, | ||
n_fft: int = 2048, | ||
hop_length: int = 512, | ||
num_segments: int = 10, | ||
) -> dict: | ||
data = {"mapping": [], "labels": [], "mfcc": []} | ||
|
||
samples_per_segment = int(SAMPLES_PER_TRACK / num_segments) | ||
num_mfcc_vectors_per_segment = math.ceil(samples_per_segment / hop_length) | ||
|
||
# loop through all genre sub-folder | ||
for i, (dir_path, dir_names, filenames) in enumerate( | ||
os.walk(f"{GTZAN_ZIP_FILE_PATH}/genres/") | ||
): | ||
# ensure we're processing a genre sub-folder level | ||
if dir_path is not dataset_path: | ||
# save genre label (i.e., sub-folder name) in the mapping | ||
semantic_label = dir_path.split("/")[-1] | ||
print(semantic_label) | ||
data["mapping"].append(semantic_label) | ||
print("Processing: {}".format(semantic_label)) | ||
|
||
# process all audio files in genre sub-dir | ||
for f in filenames: | ||
if f not in BAD_FORMATS: | ||
# load audio file | ||
file_path = os.path.join(dir_path, f) | ||
signal, sample_rate = librosa.load(path=file_path, sr=SAMPLE_RATE) | ||
|
||
# process all segments of audio file | ||
for d in range(num_segments): | ||
|
||
# calculate start and finish sample for current segment | ||
start = samples_per_segment * d | ||
finish = start + samples_per_segment | ||
|
||
# extract mfcc | ||
mfcc = librosa.feature.mfcc( | ||
y=signal[start:finish], | ||
sr=sample_rate, | ||
n_mfcc=num_mfcc, | ||
n_fft=n_fft, | ||
hop_length=hop_length, | ||
) | ||
mfcc = mfcc.T | ||
|
||
# store only mfcc feature with expected number of vectors | ||
if len(mfcc) == num_mfcc_vectors_per_segment: | ||
data["mfcc"].append(mfcc.tolist()) | ||
data["labels"].append(i - 1) | ||
# print("{}, segment:{}".format(file_path, d + 1)) | ||
return data | ||
|
||
|
||
if __name__ == "__main__": | ||
download_gtzan_repo() | ||
clean_dataset() | ||
data = preprocess(dataset_path=GTZAN_ZIP_FILE_PATH) | ||
print(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import json | ||
import typing | ||
import warnings | ||
import numpy as np | ||
from tensorflow import keras | ||
from dataclasses import dataclass | ||
from preprocess import preprocess | ||
from datasource import GTZAN_ZIP_FILE_PATH | ||
from dataclasses_json import dataclass_json | ||
from flytekit.types.directory import FlyteDirectory | ||
from sklearn.model_selection import train_test_split | ||
|
||
|
||
warnings.filterwarnings("ignore") | ||
MODELSAVE = [typing.TypeVar("str")] | ||
model_file = typing.NamedTuple("Model", model=FlyteDirectory[MODELSAVE]) | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class Hyperparameters(object): | ||
batch_size: int = 32 | ||
metrics: str = "accuracy" | ||
loss = ("sparse_categorical_crossentropy",) | ||
epochs: int = 30 | ||
learning_rate: float = 0.0001 | ||
|
||
|
||
def train( | ||
data: dict, | ||
hp: Hyperparameters | ||
) -> model_file: | ||
# with open("data.json", "r") as fp: | ||
# data = json.load(fp) | ||
|
||
# convert lists to numpy arrays | ||
X = np.array(data["mfcc"]) | ||
y = np.array(data["labels"]) | ||
|
||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) | ||
|
||
model = keras.Sequential( | ||
[ | ||
keras.layers.Flatten(input_shape=(X.shape[1], X.shape[2])), | ||
keras.layers.Dense(512, activation="relu"), | ||
keras.layers.Dense(256, activation="relu"), | ||
keras.layers.Dense(64, activation="relu"), | ||
keras.layers.Dense(10, activation="softmax"), | ||
] | ||
) | ||
optimiser = keras.optimizers.Adam(learning_rate=hp.learning_rate) | ||
model.compile( | ||
optimizer=optimiser, | ||
loss=hp.loss, | ||
metrics=[hp.metrics], | ||
) | ||
# train model | ||
model.fit( | ||
X_train, | ||
y_train, | ||
validation_data=(X_test, y_test), | ||
batch_size=hp.batch_size, | ||
epochs=hp.epochs, | ||
) | ||
|
||
Dir = "model" | ||
model.save(Dir) | ||
return model | ||
|
||
|
||
if __name__ == '__main__': | ||
data = preprocess(dataset_path=GTZAN_ZIP_FILE_PATH) | ||
model = train( | ||
data=data, | ||
hp=Hyperparameters(epochs=1) | ||
) |
Oops, something went wrong.