Skip to content

Commit

Permalink
feat(primitives): ar
Browse files Browse the repository at this point in the history
  • Loading branch information
vladyoslav committed Oct 27, 2024
1 parent cd8095f commit dee20b8
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 9 deletions.
2 changes: 2 additions & 0 deletions internal/domain/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
from internal.domain.task.entities import AfdTask # noqa: F401
from internal.domain.task.entities import AcTask # noqa: F401
from internal.domain.task.entities import IndTask # noqa: F401
from internal.domain.task.entities import AindTask # noqa: F401
from internal.domain.task.entities import ArTask # noqa: F401
3 changes: 3 additions & 0 deletions internal/domain/task/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from internal.domain.task.entities.ac import AcTask
from internal.domain.task.entities.ind import IndTask
from internal.domain.task.entities.aind import AindTask
from internal.domain.task.entities.ar import ArTask
from internal.domain.task.value_objects import PrimitiveName


Expand Down Expand Up @@ -32,4 +33,6 @@ def match_task_by_primitive_name(primitive_name: PrimitiveName):
return IndTask()
case PrimitiveName.aind:
return AindTask()
case PrimitiveName.ar:
return ArTask()
assert_never(primitive_name)
1 change: 1 addition & 0 deletions internal/domain/task/entities/ar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from internal.domain.task.entities.ar.ar_task import ArTask # noqa: F401
58 changes: 58 additions & 0 deletions internal/domain/task/entities/ar/ar_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from desbordante.ar import ArAlgorithm
from desbordante.ar.algorithms import Apriori
from internal.domain.task.entities.task import Task
from internal.domain.task.value_objects import PrimitiveName, IncorrectAlgorithmName
from internal.domain.task.value_objects.ar import ArTaskConfig, ArTaskResult
from internal.domain.task.value_objects.ar import (
ArAlgoName,
ArModel,
ArAlgoResult,
)


class ArTask(Task[ArAlgorithm, ArTaskConfig, ArTaskResult]):
"""
Task class for Association Rule (AR) mining.
This class handles the execution of different AR algorithms and processes
the results into the appropriate format. It implements the abstract methods
defined in the Task base class.
Methods:
- _match_algo_by_name(algo_name: ArAlgoName) -> ArAlgorithm:
Match AR algorithm by its name.
- _collect_result(algo: ArAlgorithm) -> ArTaskResult:
Process the output of the AR algorithm and return the result.
"""

def _collect_result(self, algo: ArAlgorithm) -> ArTaskResult:
"""
Collect and process the AR result.
Args:
algo (ArAlgorithm): AR algorithm to process.
Returns:
ArTaskResult: The processed result containing association rules.
"""
ar_ids = algo.get_ar_ids()
ar_strings = algo.get_ars()
algo_result = ArAlgoResult(
ars=list(map(ArModel.from_ar, ar_strings)),
ar_ids=list(map(ArModel.from_ar_ids, ar_ids)),
)
return ArTaskResult(primitive_name=PrimitiveName.ar, result=algo_result)

def _match_algo_by_name(self, algo_name: str) -> ArAlgorithm:
"""
Match the association rule algorithm by name.
Args:
algo_name (ArAlgoName): The name of the AR algorithm.
Returns:
ArAlgorithm: The corresponding algorithm instance.
"""
match algo_name:
case ArAlgoName.Apriori:
return Apriori()
case _:
raise IncorrectAlgorithmName(algo_name, "AR")
17 changes: 11 additions & 6 deletions internal/domain/task/entities/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import desbordante
import pandas
from internal.domain.task.value_objects import TaskConfig, TaskResult
from internal.domain.task.value_objects import PrimitiveName, TaskConfig, TaskResult


class Task[A: desbordante.Algorithm, C: TaskConfig, R: TaskResult](ABC):
Expand Down Expand Up @@ -60,10 +60,15 @@ def execute(self, table: pandas.DataFrame, task_config: C) -> R:
algo_config = task_config.config
options = algo_config.model_dump(exclude_unset=True, exclude={"algo_name"})
algo = self._match_algo_by_name(algo_config.algo_name)
# TODO: IND, AIND requires multiple tables
try:
algo.load_data(table=table)
except desbordante.ConfigurationError:
algo.load_data(tables=[table])

# TODO: FIX THIS PLS!!!
match task_config.primitive_name:
case PrimitiveName.ind | PrimitiveName.aind:
algo.load_data(tables=[table])
case PrimitiveName.ar:
algo.load_data(table=table, input_format=options["input_format"])
case _:
algo.load_data(table=table)

algo.execute(**options)
return self._collect_result(algo)
19 changes: 17 additions & 2 deletions internal/domain/task/value_objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from internal.domain.task.value_objects.ac import AcTaskConfig, AcTaskResult
from internal.domain.task.value_objects.ind import IndTaskConfig, IndTaskResult
from internal.domain.task.value_objects.aind import AindTaskConfig, AindTaskResult
from internal.domain.task.value_objects.ar import ArTaskConfig, ArTaskResult

from internal.domain.task.value_objects.config import TaskConfig # noqa: F401
from internal.domain.task.value_objects.result import TaskResult # noqa: F401
Expand All @@ -24,11 +25,25 @@
)

OneOfTaskConfig = Annotated[
Union[FdTaskConfig, AfdTaskConfig, AcTaskConfig, IndTaskConfig, AindTaskConfig],
Union[
FdTaskConfig,
AfdTaskConfig,
AcTaskConfig,
IndTaskConfig,
AindTaskConfig,
ArTaskConfig,
],
Field(discriminator="primitive_name"),
]

OneOfTaskResult = Annotated[
Union[FdTaskResult, AfdTaskResult, AcTaskResult, IndTaskResult, AindTaskResult],
Union[
FdTaskResult,
AfdTaskResult,
AcTaskResult,
IndTaskResult,
AindTaskResult,
ArTaskResult,
],
Field(discriminator="primitive_name"),
]
23 changes: 23 additions & 0 deletions internal/domain/task/value_objects/ar/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Literal

from pydantic import BaseModel

from internal.domain.task.value_objects.primitive_name import PrimitiveName
from internal.domain.task.value_objects.ar.algo_config import OneOfArAlgoConfig
from internal.domain.task.value_objects.ar.algo_result import ( # noqa: F401
ArAlgoResult,
ArModel,
)
from internal.domain.task.value_objects.ar.algo_name import ArAlgoName # noqa: F401


# Common base of the AR task models: fixes the ``primitive_name`` field.
class BaseArTaskModel(BaseModel):
    # Only the single literal value PrimitiveName.ar is accepted.
    primitive_name: Literal[PrimitiveName.ar]


# Configuration model of an AR task; inherits ``primitive_name`` from its base.
class ArTaskConfig(BaseArTaskModel):
    # Algorithm-specific options, typed as OneOfArAlgoConfig.
    config: OneOfArAlgoConfig


# Result model of an AR task; inherits ``primitive_name`` from its base.
class ArTaskResult(BaseArTaskModel):
    # Output of the AR algorithm, typed as ArAlgoResult.
    result: ArAlgoResult
36 changes: 36 additions & 0 deletions internal/domain/task/value_objects/ar/algo_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Literal, Annotated
from pydantic import Field
from internal.domain.common import OptionalModel
from internal.domain.task.value_objects.ar.algo_name import ArAlgoName
from internal.domain.task.value_objects.ar.algo_descriptions import descriptions


# Base class for AR algorithm configs, built on the project's OptionalModel.
class BaseArConfig(OptionalModel):
    # NOTE(review): presumably OptionalModel makes every field optional except
    # the names listed here — confirm against OptionalModel's implementation.
    __non_optional_fields__ = {
        "algo_name",
    }


# Options for the Apriori association rule algorithm.
class AprioriConfig(BaseArConfig):
    # Discriminator value selecting this config.
    algo_name: Literal[ArAlgoName.Apriori]

    has_tid: Annotated[bool, Field(description=descriptions["has_tid"])]
    # Both thresholds are constrained to the closed interval [0, 1].
    minconf: Annotated[float, Field(ge=0, le=1, description=descriptions["minconf"])]
    minsup: Annotated[float, Field(ge=0, le=1, description=descriptions["minsup"])]
    # ``Literal`` must be the annotated type itself: listed after ``str`` as
    # extra ``Annotated`` metadata it does not restrict the accepted values,
    # so any string would pass validation.
    input_format: Annotated[
        Literal["singular", "tabular"],
        Field(description=descriptions["input_format"]),
    ]
    # Column indices must be non-negative.
    item_column_index: Annotated[
        int, Field(ge=0, description=descriptions["item_column_index"])
    ]
    tid_column_index: Annotated[
        int, Field(ge=0, description=descriptions["tid_column_index"])
    ]


# Discriminated (by ``algo_name``) alias for the AR algorithm configs.
# Apriori is currently the only member.
OneOfArAlgoConfig = Annotated[
    AprioriConfig,
    Field(discriminator="algo_name"),
]
8 changes: 8 additions & 0 deletions internal/domain/task/value_objects/ar/algo_descriptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Human-readable descriptions of the AR algorithm options, keyed by option name.
descriptions: dict[str, str] = {
    # Dataset layout options.
    "has_tid": "Indicates that the first column contains the transaction IDs",
    # Mining thresholds and input options.
    "minconf": "Minimum confidence value (between 0 and 1)",
    "input_format": "Format of the input dataset for AR mining",
    "item_column_index": "Index of the column where an item name is stored",
    "minsup": "Minimum support value (between 0 and 1)",
    "tid_column_index": "Index of the column where a TID is stored",
}
5 changes: 5 additions & 0 deletions internal/domain/task/value_objects/ar/algo_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from enum import StrEnum, auto


class ArAlgoName(StrEnum):
    """Names of the supported association rule (AR) algorithms."""

    # With StrEnum, auto() yields the lower-cased member name ("apriori").
    Apriori = auto()
21 changes: 21 additions & 0 deletions internal/domain/task/value_objects/ar/algo_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pydantic import BaseModel
from desbordante.ar import ARStrings, ArIDs


# A single association rule: its two sides and its confidence.
class ArModel(BaseModel):
    confidence: float
    # Items on each side of the rule, always stored as strings.
    left: list[str]
    right: list[str]

    @classmethod
    def from_ar(cls, ar: ARStrings) -> "ArModel":
        """
        Build a model from a rule whose sides are given as strings.

        Args:
            ar (ARStrings): Rule exposing ``confidence``, ``left`` and ``right``.

        Returns:
            ArModel: Model carrying the same confidence and sides.
        """
        return cls(confidence=ar.confidence, left=ar.left, right=ar.right)

    @classmethod
    def from_ar_ids(cls, ar_id: ArIDs) -> "ArModel":
        """
        Build a model from a rule whose sides are given as item ids.

        Args:
            ar_id (ArIDs): Rule exposing ``confidence``, ``left`` and ``right``.

        Returns:
            ArModel: Model carrying the same confidence, with every id
                rendered as a string.
        """
        # NOTE(review): ArIDs presumably holds integer item ids — confirm.
        # The fields are list[str], and pydantic does not coerce int to str by
        # default, so ids are stringified here (a no-op for strings).
        return cls(
            confidence=ar_id.confidence,
            left=[str(item) for item in ar_id.left],
            right=[str(item) for item in ar_id.right],
        )


# Result of an AR algorithm run: the mined rules in two representations.
class ArAlgoResult(BaseModel):
    # Rules built with ArModel.from_ar.
    ars: list[ArModel]
    # Rules built with ArModel.from_ar_ids.
    ar_ids: list[ArModel]
2 changes: 1 addition & 1 deletion internal/domain/task/value_objects/primitive_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class PrimitiveName(StrEnum):
fd = auto()
afd = auto()
# ar = auto()
ar = auto()
ac = auto()
ind = auto()
aind = auto()
Expand Down

0 comments on commit dee20b8

Please sign in to comment.