-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cd8095f
commit dee20b8
Showing
12 changed files
with
186 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from internal.domain.task.entities.ar.ar_task import ArTask # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from desbordante.ar import ArAlgorithm | ||
from desbordante.ar.algorithms import Apriori | ||
from internal.domain.task.entities.task import Task | ||
from internal.domain.task.value_objects import PrimitiveName, IncorrectAlgorithmName | ||
from internal.domain.task.value_objects.ar import ArTaskConfig, ArTaskResult | ||
from internal.domain.task.value_objects.ar import ( | ||
ArAlgoName, | ||
ArModel, | ||
ArAlgoResult, | ||
) | ||
|
||
|
||
class ArTask(Task[ArAlgorithm, ArTaskConfig, ArTaskResult]): | ||
""" | ||
Task class for Association Rule (AR) mining. | ||
This class handles the execution of different AR algorithms and processes | ||
the results into the appropriate format. It implements the abstract methods | ||
defined in the Task base class. | ||
Methods: | ||
- _match_algo_by_name(algo_name: ArAlgoName) -> ArAlgorithm: | ||
Match AR algorithm by its name. | ||
- _collect_result(algo: ArAlgorithm) -> ArTaskResult: | ||
Process the output of the AR algorithm and return the result. | ||
""" | ||
|
||
def _collect_result(self, algo: ArAlgorithm) -> ArTaskResult: | ||
""" | ||
Collect and process the AR result. | ||
Args: | ||
algo (ArAlgorithm): AR algorithm to process. | ||
Returns: | ||
ArTaskResult: The processed result containing association rules. | ||
""" | ||
ar_ids = algo.get_ar_ids() | ||
ar_strings = algo.get_ars() | ||
algo_result = ArAlgoResult( | ||
ars=list(map(ArModel.from_ar, ar_strings)), | ||
ar_ids=list(map(ArModel.from_ar_ids, ar_ids)), | ||
) | ||
return ArTaskResult(primitive_name=PrimitiveName.ar, result=algo_result) | ||
|
||
def _match_algo_by_name(self, algo_name: str) -> ArAlgorithm: | ||
""" | ||
Match the association rule algorithm by name. | ||
Args: | ||
algo_name (ArAlgoName): The name of the AR algorithm. | ||
Returns: | ||
ArAlgorithm: The corresponding algorithm instance. | ||
""" | ||
match algo_name: | ||
case ArAlgoName.Apriori: | ||
return Apriori() | ||
case _: | ||
raise IncorrectAlgorithmName(algo_name, "AR") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Literal | ||
|
||
from pydantic import BaseModel | ||
|
||
from internal.domain.task.value_objects.primitive_name import PrimitiveName | ||
from internal.domain.task.value_objects.ar.algo_config import OneOfArAlgoConfig | ||
from internal.domain.task.value_objects.ar.algo_result import ( # noqa: F401 | ||
ArAlgoResult, | ||
ArModel, | ||
) | ||
from internal.domain.task.value_objects.ar.algo_name import ArAlgoName # noqa: F401 | ||
|
||
|
||
class BaseArTaskModel(BaseModel): | ||
primitive_name: Literal[PrimitiveName.ar] | ||
|
||
|
||
class ArTaskConfig(BaseArTaskModel): | ||
config: OneOfArAlgoConfig | ||
|
||
|
||
class ArTaskResult(BaseArTaskModel): | ||
result: ArAlgoResult |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Literal, Annotated | ||
from pydantic import Field | ||
from internal.domain.common import OptionalModel | ||
from internal.domain.task.value_objects.ar.algo_name import ArAlgoName | ||
from internal.domain.task.value_objects.ar.algo_descriptions import descriptions | ||
|
||
|
||
class BaseArConfig(OptionalModel): | ||
__non_optional_fields__ = { | ||
"algo_name", | ||
} | ||
|
||
|
||
class AprioriConfig(BaseArConfig): | ||
algo_name: Literal[ArAlgoName.Apriori] | ||
|
||
has_tid: Annotated[bool, Field(description=descriptions["has_tid"])] | ||
minconf: Annotated[float, Field(ge=0, le=1, description=descriptions["minconf"])] | ||
minsup: Annotated[float, Field(ge=0, le=1, description=descriptions["minsup"])] | ||
input_format: Annotated[ | ||
str, | ||
Literal["singular", "tabular"], | ||
Field(description=descriptions["input_format"]), | ||
] | ||
item_column_index: Annotated[ | ||
int, Field(ge=0, description=descriptions["item_column_index"]) | ||
] | ||
tid_column_index: Annotated[ | ||
int, Field(ge=0, description=descriptions["tid_column_index"]) | ||
] | ||
|
||
|
||
OneOfArAlgoConfig = Annotated[ | ||
AprioriConfig, | ||
Field(discriminator="algo_name"), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
descriptions = { | ||
"has_tid": "Indicates that the first column contains the transaction IDs", | ||
"minconf": "Minimum confidence value (between 0 and 1)", | ||
"input_format": "Format of the input dataset for AR mining", | ||
"item_column_index": "Index of the column where an item name is stored", | ||
"minsup": "Minimum support value (between 0 and 1)", | ||
"tid_column_index": "Index of the column where a TID is stored", | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from enum import StrEnum, auto | ||
|
||
|
||
class ArAlgoName(StrEnum): | ||
Apriori = auto() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from pydantic import BaseModel | ||
from desbordante.ar import ARStrings, ArIDs | ||
|
||
|
||
class ArModel(BaseModel): | ||
@classmethod | ||
def from_ar(cls, ar: ARStrings): | ||
return cls(confidence=ar.confidence, left=ar.left, right=ar.right) | ||
|
||
@classmethod | ||
def from_ar_ids(cls, ar_id: ArIDs): | ||
return cls(confidence=ar_id.confidence, left=ar_id.left, right=ar_id.right) | ||
|
||
confidence: float | ||
left: list[str] | ||
right: list[str] | ||
|
||
|
||
class ArAlgoResult(BaseModel): | ||
ars: list[ArModel] | ||
ar_ids: list[ArModel] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters