Skip to content

Commit

Permalink
refactor: Separate train_xgboost_model into two functions for better …
Browse files Browse the repository at this point in the history
…modularity
  • Loading branch information
janezlapajne committed Sep 1, 2024
1 parent 97340da commit f6183cf
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
4 changes: 3 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from source.misc import check_spectral_images, display_spectral_image
from source.processing import (
convert_selected_areas_to_train_data,
find_transformation_between_images,
select_areas_on_images,
train_xgboost_model,
Expand Down Expand Up @@ -51,7 +52,8 @@ def select_areas(label: str, category: str):
@app.command()
def train_model():
selected_areas = load_all_selected_areas()
encoder, model = train_xgboost_model(selected_areas)
X, y = convert_selected_areas_to_train_data(selected_areas)
encoder, model = train_xgboost_model(X, y)
save_model(encoder, model)


Expand Down
3 changes: 2 additions & 1 deletion source/processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .model import train_xgboost_model
from .model import convert_selected_areas_to_train_data, train_xgboost_model
from .selector import select_areas_on_images
from .transformator import find_transformation_between_images

__all__ = [
"find_transformation_between_images",
"select_areas_on_images",
"train_xgboost_model",
"convert_selected_areas_to_train_data",
]
10 changes: 8 additions & 2 deletions source/processing/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from siapy.entities import Pixels
from sklearn.preprocessing import LabelEncoder
from xgboost import XGBClassifier
Expand All @@ -10,9 +11,9 @@
)


def train_xgboost_model(
def convert_selected_areas_to_train_data(
selected_areas: dict[str, dict[str, list[Pixels]]],
) -> tuple[LabelEncoder,XGBClassifier]:
) -> tuple[list[np.ndarray], list[str]]:
image_set_cam1, image_set_cam2 = read_spectral_images()
labels_cam1, labels_cam2 = extract_labels_from_spectral_images(
image_set_cam1, image_set_cam2
Expand All @@ -29,7 +30,12 @@ def train_xgboost_model(
signal_mean = siagnatures.signals.mean()
X.append(signal_mean)
y.append(category)
return X, y


def train_xgboost_model(
X: list[np.ndarray], y: list[str]
) -> tuple[LabelEncoder, XGBClassifier]:
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(y)
model = XGBClassifier().fit(X, y_encoded)
Expand Down

0 comments on commit f6183cf

Please sign in to comment.