Skip to content

Commit

Permalink
Merge pull request #17 from dataforgoodfr/feat/model-repository
Browse files Browse the repository at this point in the history
Add model repository
  • Loading branch information
samuelrince authored Mar 19, 2024
2 parents 1316e3a + 2fa220c commit 473a6bf
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 52 deletions.
22 changes: 19 additions & 3 deletions genai_impact/compute_impacts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
from typing import Union

from pydantic import BaseModel

ENERGY_PROFILE = 1.17e-4
GPU_ENERGY_ALPHA = 8.91e-5
GPU_ENERGY_BETA = 1.43e-3


class Impacts(BaseModel):
    """Estimated environmental impacts of one LLM inference."""

    # Average energy estimate; midpoint of `energy_range` when a range is known.
    energy: float
    # (min, max) bounds on the energy estimate; both equal `energy` when the
    # model's exact parameter count is known.
    energy_range: tuple[float, float]
    # Unit for `energy` and `energy_range`.
    energy_unit: str = "Wh"


def compute_llm_impact(
    model_parameter_count: Union[float, tuple[float, float]],
    output_token_count: int,
) -> Impacts:
    """Estimate the energy used to generate a response from an LLM.

    Energy is modeled as linear in model size, per generated token:
    ``GPU_ENERGY_ALPHA * parameter_count + GPU_ENERGY_BETA``.

    Args:
        model_parameter_count: Model size in billions of parameters — a single
            value when the exact size is known, or a (min, max) pair otherwise.
        output_token_count: Number of tokens generated by the model.

    Returns:
        Impacts with the average energy and its (min, max) range, in Wh.
    """
    def _energy(parameter_count: float) -> float:
        # Linear GPU energy model, scaled by the number of generated tokens.
        return output_token_count * (GPU_ENERGY_ALPHA * parameter_count + GPU_ENERGY_BETA)

    if isinstance(model_parameter_count, (tuple, list)):
        # Sort so an unordered or reversed (max, min) pair still yields a
        # valid range (addresses the previous TODO about pair validity).
        energy_min, energy_max = sorted(
            (_energy(model_parameter_count[0]), _energy(model_parameter_count[1]))
        )
        energy_avg = (energy_min + energy_max) / 2
    else:
        energy_avg = energy_min = energy_max = _energy(model_parameter_count)
    return Impacts(
        energy=energy_avg,
        energy_range=(energy_min, energy_max)
    )
35 changes: 35 additions & 0 deletions genai_impact/data/models.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
provider,name,total_parameters,active_parameters,warnings,sources
openai,gpt-3.5-turbo,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-0125,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-0301,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-0613,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-1106,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-16k,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-16k-0613,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-instruct,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-3.5-turbo-instruct-0914,20;70,20;70,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-4,1760,220;880,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-4-0125-preview,1760,220;880,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-4-0613,1760,220;880,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-4-1106-preview,1760,220;880,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-4-turbo-preview,1760,220;880,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
openai,gpt-4-vision-preview,1760,220;880,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit
mistralai,mistral-large-2402,540,135;540,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-large-latest,540,135;540,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-medium,70;180,45;180,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-medium-2312,70;180,45;180,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-medium-latest,70;180,45;180,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-small,46.7,12.9,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-small-2312,46.7,12.9,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-small-2402,46.7,12.9,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-small-latest,46.7,12.9,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
mistralai,mistral-tiny,7.3,7.3,,https://docs.mistral.ai/models/#sizes
mistralai,mistral-tiny-2312,7.3,7.3,,https://docs.mistral.ai/models/#sizes
mistralai,open-mistral-7b,7.3,7.3,,https://docs.mistral.ai/models/#sizes
mistralai,open-mixtral-8x7b,46.7,46.7,,https://docs.mistral.ai/models/#sizes
anthropic,claude-3-opus-20240229,2000,250;1000,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
anthropic,claude-3-sonnet-20240229,800,100;400,model_architecture_not_released,
anthropic,claude-3-haiku-20240307,300,75;150,model_architecture_not_released,
anthropic,claude-2.1,130,130,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
anthropic,claude-2.0,130,130,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
anthropic,claude-instant-1.2,20;70,20;70,model_architecture_not_released,
84 changes: 84 additions & 0 deletions genai_impact/model_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
from csv import DictReader
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Providers(Enum):
    """Supported LLM API providers.

    Values match the `provider` column of data/models.csv; constructing
    `Providers(value)` therefore validates CSV rows at load time.
    """

    anthropic = "anthropic"
    mistralai = "mistralai"
    openai = "openai"


class Warnings(Enum):
    """Known caveat flags for model metadata.

    Values match entries in the `warnings` column of data/models.csv;
    constructing `Warnings(value)` validates each flag at load time.
    """

    model_architecture_not_released = "model_architecture_not_released"


@dataclass
class Model:
    """Size and provenance metadata for a single LLM.

    Parameter counts are expressed in billions of parameters. For each count,
    either the scalar field or its `*_range` counterpart is populated (or
    neither), depending on whether the exact size is known — see
    `ModelRepository.from_csv`.
    """

    # Provider serving the model (a `Providers` member name).
    provider: str
    # Model identifier as reported by the provider's API.
    name: str
    # Exact total parameter count, when known.
    total_parameters: Optional[float] = None
    # Exact active parameter count, when known.
    active_parameters: Optional[float] = None
    # (min, max) estimate of total parameters when the exact count is unknown.
    total_parameters_range: Optional[tuple[float, float]] = None
    # (min, max) estimate of active parameters when the exact count is unknown.
    active_parameters_range: Optional[tuple[float, float]] = None
    # Caveat flags (`Warnings` member names) parsed from the CSV.
    warnings: Optional[list[str]] = None
    # Source URLs backing the figures.
    sources: Optional[list[str]] = None


class ModelRepository:
    """In-memory registry of known LLM models and their size metadata."""

    def __init__(self, models: list["Model"]) -> None:
        self.__models = models

    def find_model(self, provider: str, model_name: str) -> Optional["Model"]:
        """Return the model matching `provider` and `model_name`, or None."""
        for model in self.__models:
            if model.provider == provider and model_name == model.name:
                return model
        return None

    @staticmethod
    def _parse_parameters(value: str) -> tuple[Optional[float], Optional[tuple[float, ...]]]:
        """Parse a CSV parameter field into a (scalar, range) pair.

        "46.7" -> (46.7, None); "20;70" -> (None, (20.0, 70.0)); "" -> (None, None).
        Exactly one of the two results is non-None for a non-empty field.
        """
        if ";" in value:
            # Return a tuple (not a list) to match the Model field annotations.
            return None, tuple(float(p) for p in value.split(";"))
        if value != "":
            return float(value), None
        return None, None

    @classmethod
    def from_csv(cls, filepath: Optional[str] = None) -> "ModelRepository":
        """Load a repository from a CSV file (defaults to bundled data/models.csv).

        Raises ValueError if a row has an unknown provider or warning flag.
        """
        if filepath is None:
            filepath = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "models.csv")
        models = []
        # newline="" is the csv-module-recommended way to open CSV files.
        with open(filepath, newline="") as fd:
            for row in DictReader(fd):
                total_parameters, total_parameters_range = cls._parse_parameters(row["total_parameters"])
                active_parameters, active_parameters_range = cls._parse_parameters(row["active_parameters"])

                warnings = None
                if row["warnings"] != "":
                    # Warnings(w) validates each flag against the known enum.
                    warnings = [Warnings(w).name for w in row["warnings"].split(";")]

                sources = None
                if row["sources"] != "":
                    sources = row["sources"].split(";")

                models.append(Model(
                    provider=Providers(row["provider"]).name,
                    name=row["name"],
                    total_parameters=total_parameters,
                    active_parameters=active_parameters,
                    total_parameters_range=total_parameters_range,
                    active_parameters_range=active_parameters_range,
                    warnings=warnings,
                    sources=sources
                ))
        return cls(models)


models = ModelRepository.from_csv()
28 changes: 12 additions & 16 deletions genai_impact/tracers/anthropic_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import Impacts, compute_llm_impact
from genai_impact.model_repository import models

try:
from anthropic import Anthropic as _Anthropic
Expand All @@ -11,31 +12,26 @@
_Anthropic = object()
_Message = object()

_MODEL_SIZES = {
"claude-3-haiku-20240307": 10,
"claude-3-sonnet-20240229": 10, # fake data
"claude-3-opus-20240229": 440, # fake data
}


class Message(_Message):
    # Anthropic Message response extended with the computed environmental impacts.
    impacts: Impacts


def _set_impacts(response: Message) -> Impacts:
model_size = _MODEL_SIZES.get(response.model)
output_tokens = response.usage.output_tokens
impacts = compute_llm_impact(
model_parameter_count=model_size, output_token_count=output_tokens
)
return impacts


def anthropic_chat_wrapper(
    wrapped: Callable, instance: _Anthropic, args: Any, kwargs: Any  # noqa: ARG001
) -> Message:
    """Wrap Anthropic message creation to attach estimated impacts.

    Returns the raw response unchanged when the model is not in the repository.
    """
    response = wrapped(*args, **kwargs)
    model = models.find_model(provider="anthropic", model_name=response.model)
    if model is None:
        # TODO: Replace with proper logging
        print(f"Could not find model `{response.model}` for anthropic provider.")
        return response
    output_tokens = response.usage.output_tokens
    # Prefer the exact active-parameter count; fall back to the (min, max)
    # range. NOTE(review): `or` treats 0 as missing — fine while parameter
    # counts are always positive.
    model_size = model.active_parameters or model.active_parameters_range
    impacts = compute_llm_impact(
        model_parameter_count=model_size,
        output_token_count=output_tokens
    )
    return Message(**response.model_dump(), impacts=impacts)


Expand Down
19 changes: 9 additions & 10 deletions genai_impact/tracers/mistralai_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import Impacts, compute_llm_impact
from genai_impact.model_repository import models

try:
from mistralai.client import MistralClient as _MistralClient
Expand All @@ -14,14 +15,6 @@
_ChatCompletionResponse = object()


_MODEL_SIZES = {
"mistral-tiny": 7.3,
"mistral-small": 12.9, # mixtral active parameters count
"mistral-medium": 70,
"mistral-large": 440,
}


class ChatCompletionResponse(_ChatCompletionResponse):
    # Mistral chat response extended with the computed environmental impacts.
    impacts: Impacts

def mistralai_chat_wrapper(
    wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletionResponse:
    """Wrap MistralClient chat calls to attach estimated impacts.

    Returns the raw response unchanged when the model is not in the repository.
    """
    response = wrapped(*args, **kwargs)
    model = models.find_model(provider="mistralai", model_name=response.model)
    if model is None:
        # TODO: Replace with proper logging
        print(f"Could not find model `{response.model}` for mistralai provider.")
        return response
    output_tokens = response.usage.completion_tokens
    # Prefer the exact active-parameter count; fall back to the (min, max)
    # range. NOTE(review): `or` treats 0 as missing — fine while parameter
    # counts are always positive.
    model_size = model.active_parameters or model.active_parameters_range
    impacts = compute_llm_impact(
        model_parameter_count=model_size,
        output_token_count=output_tokens
    )
    return ChatCompletionResponse(**response.model_dump(), impacts=impacts)

Expand Down
31 changes: 9 additions & 22 deletions genai_impact/tracers/openai_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,7 @@
from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import Impacts, compute_llm_impact

_MODEL_SIZES = {
"gpt-4-0125-preview": None,
"gpt-4-turbo-preview": None,
"gpt-4-1106-preview": None,
"gpt-4-vision-preview": None,
"gpt-4": 440,
"gpt-4-0314": 440,
"gpt-4-0613": 440,
"gpt-4-32k": 440,
"gpt-4-32k-0314": 440,
"gpt-4-32k-0613": 440,
"gpt-3.5-turbo": 70,
"gpt-3.5-turbo-16k": 70,
"gpt-3.5-turbo-0301": 70,
"gpt-3.5-turbo-0613": 70,
"gpt-3.5-turbo-1106": 70,
"gpt-3.5-turbo-0125": 70,
"gpt-3.5-turbo-16k-0613": 70,
}
from genai_impact.model_repository import models


class ChatCompletion(_ChatCompletion):
def openai_chat_wrapper(
    wrapped: Callable, instance: Completions, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletion:
    """Wrap OpenAI chat completion calls to attach estimated impacts.

    Returns the raw response unchanged when the model is not in the repository.
    """
    response = wrapped(*args, **kwargs)
    model = models.find_model(provider="openai", model_name=response.model)
    if model is None:
        # TODO: Replace with proper logging
        print(f"Could not find model `{response.model}` for openai provider.")
        return response
    output_tokens = response.usage.completion_tokens
    # Prefer the exact active-parameter count; fall back to the (min, max)
    # range. NOTE(review): `or` treats 0 as missing — fine while parameter
    # counts are always positive.
    model_size = model.active_parameters or model.active_parameters_range
    impacts = compute_llm_impact(
        model_parameter_count=model_size,
        output_token_count=output_tokens
    )
    return ChatCompletion(**response.model_dump(), impacts=impacts)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ ignore = [
"PLR0913",
"RET504",
"RET505",
"COM812"
"COM812",
"PTH"
]

fixable = [
Expand Down
24 changes: 24 additions & 0 deletions tests/test_model_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from genai_impact.model_repository import ModelRepository, Model


def test_create_model_repository_default():
models = ModelRepository.from_csv()
assert isinstance(models, ModelRepository)
assert models.find_model(provider="openai", model_name="gpt-3.5-turbo") is not None


def test_create_model_repository_from_scratch():
models = ModelRepository([
Model(provider="provider-test", name="model-test")
])
assert models.find_model(provider="provider-test", model_name="model-test")


def test_find_unknown_provider():
models = ModelRepository.from_csv()
assert models.find_model(provider="provider-test", model_name="gpt-3.5-turbo") is None


def test_find_unknown_model_name():
models = ModelRepository.from_csv()
assert models.find_model(provider="openai", model_name="model-test") is None

0 comments on commit 473a6bf

Please sign in to comment.