Skip to content

Commit 8afe05b

Browse files
updated modal scripts, added serving for non-optimized but more accurate model
1 parent 81a75f6 commit 8afe05b

File tree

3 files changed

+196
-95
lines changed

3 files changed

+196
-95
lines changed

inference/ephemeral.py

Lines changed: 0 additions & 95 deletions
This file was deleted.

inference/serving-non-optimized.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import modal
2+
3+
MODEL_ALIAS = "gemma2"
4+
MODELS_DIR = f"/{MODEL_ALIAS}"
5+
VOLUME_NAME = f"{MODEL_ALIAS}"
6+
7+
MODEL_NAME = "google/shieldgemma-2b"
8+
9+
image = (
10+
modal.Image.debian_slim(python_version="3.10")
11+
.pip_install(
12+
"vllm",
13+
"torch",
14+
"transformers",
15+
"ray",
16+
"huggingface_hub",
17+
"hf-transfer",
18+
"accelerate"
19+
)
20+
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
21+
)
22+
23+
app = modal.App(f"{MODEL_ALIAS}-non-optimized", image=image)
24+
25+
26+
try:
27+
volume = modal.Volume.lookup(VOLUME_NAME, create_if_missing=False)
28+
except modal.exception.NotFoundError:
29+
raise Exception("Download models first with modal run download_model.py")
30+
31+
32+
GPU_CONFIG = modal.gpu.A10G(count=1)
33+
34+
35+
@app.cls(gpu=GPU_CONFIG, secrets=[modal.Secret.from_name("huggingface-secret")], volumes={MODELS_DIR: volume},)
36+
class Model:
37+
@modal.enter()
38+
def load(self):
39+
import torch
40+
from transformers import AutoTokenizer, AutoModelForCausalLM
41+
42+
model_directory = MODELS_DIR + "/" + MODEL_NAME
43+
self.model = AutoModelForCausalLM.from_pretrained(
44+
model_directory,
45+
device_map="auto",
46+
torch_dtype=torch.bfloat16,
47+
)
48+
self.tokenizer = AutoTokenizer.from_pretrained(model_directory)
49+
50+
@modal.method()
51+
def generate(self, chat, is_user_prompt=True, enforce_policies=None):
52+
import torch
53+
from torch.nn.functional import softmax
54+
55+
tokenizer = self.tokenizer
56+
model = self.model
57+
58+
print(f"Model: Loaded on device")
59+
print(f"Model: Chat {chat}")
60+
61+
INPUT_POLICIES = {
62+
"NO_DANGEROUS_CONTENT": "\"No Dangerous Content\": The prompt shall not contain or seek generation of content that harming oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).",
63+
"NO_HARASSMENT": "\"No Harassment\": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).",
64+
"NO_HATE_SPEECH": "\"No Hate Speech\": The prompt shall not contain or seek generation of content that targeting identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying on the bases of race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status. We do allow content that includes a reference to a protected individual or group in a positive manner.",
65+
"NO_SEXUAL_CONTENT": "\"No Sexually Explicit Information\": The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted."
66+
}
67+
68+
OUTPUT_POLICIES = {
69+
"NO_DANGEROUS_CONTENT": "\"No Dangerous Content\": The chatbot shall not generate content that harming oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).",
70+
"NO_HARASSMENT": "\"No Harassment\": The chatbot shall not generate content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).",
71+
"NO_HATE_SPEECH": "\"No Hate Speech\": The chatbot shall not generate content that targeting identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying on the bases of race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status. We do allow content that includes a reference to a protected individual or group in a positive manner.",
72+
"NO_SEXUAL_CONTENT": "\"No Sexually Explicit Information\": The chatbot shall not generate content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted."
73+
}
74+
75+
contructed_guideline = ""
76+
selected_policy_set = INPUT_POLICIES if is_user_prompt else OUTPUT_POLICIES
77+
78+
if is_user_prompt:
79+
enforce_policies = enforce_policies or ["NO_DANGEROUS_CONTENT", "NO_HARASSMENT", "NO_HATE_SPEECH", "NO_SEXUAL_CONTENT"]
80+
for policy in enforce_policies:
81+
if contructed_guideline == "":
82+
contructed_guideline = selected_policy_set[policy]
83+
else:
84+
contructed_guideline = contructed_guideline + "\n * " + selected_policy_set[policy]
85+
86+
87+
inputs = tokenizer.apply_chat_template(chat, guideline=contructed_guideline, return_tensors="pt", return_dict=True).to(model.device)
88+
89+
chat_template_display = tokenizer.apply_chat_template(chat, tokenize=False, guideline=contructed_guideline)
90+
print(f"Model: Chat Template: {chat_template_display}")
91+
92+
with torch.no_grad():
93+
logits = model(**inputs).logits
94+
95+
# Extract the logits for the Yes and No tokens
96+
vocab = tokenizer.get_vocab()
97+
selected_logits = logits[0, -1, [vocab['Yes'], vocab['No']]]
98+
99+
# Convert these logits to a probability with softmax
100+
probabilities = softmax(selected_logits, dim=0)
101+
102+
# Return probability of 'Yes'
103+
score = probabilities[0].item()
104+
print(score) # 0.7310585379600525
105+
106+
print(f"Model: Score: {score}")
107+
108+
return score
109+
110+
@modal.exit()
111+
def stop_engine(self):
112+
if GPU_CONFIG.count > 1:
113+
import ray
114+
115+
ray.shutdown()
116+
117+
118+
@app.function(
119+
keep_warm=1,
120+
allow_concurrent_inputs=10,
121+
timeout=60 * 10,
122+
secrets=[modal.Secret.from_dotenv()],
123+
volumes={MODELS_DIR: volume}
124+
)
125+
@modal.asgi_app(label="fa-hg-sg2b")
126+
def tgi_app():
127+
import os
128+
129+
import fastapi
130+
from fastapi.middleware.cors import CORSMiddleware
131+
132+
from typing import List
133+
from pydantic import BaseModel
134+
135+
TOKEN = os.getenv("TOKEN")
136+
if TOKEN is None:
137+
raise ValueError("Please set the TOKEN environment variable")
138+
139+
volume.reload() # ensure we have the latest version of the weights
140+
141+
app = fastapi.FastAPI()
142+
143+
http_bearer = fastapi.security.HTTPBearer(
144+
scheme_name="Bearer Token",
145+
description="See code for authentication details.",
146+
)
147+
app.add_middleware(
148+
CORSMiddleware,
149+
allow_origins=["*"],
150+
allow_credentials=True,
151+
allow_methods=["*"],
152+
allow_headers=["*"],
153+
)
154+
155+
async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
156+
if api_key.credentials != TOKEN:
157+
raise fastapi.HTTPException(
158+
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
159+
detail="Invalid authentication credentials",
160+
)
161+
return {"username": "authenticated_user"}
162+
163+
router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])
164+
165+
class ChatMessages(BaseModel):
166+
role: str
167+
content: str
168+
169+
class ChatClassificationRequestBody(BaseModel):
170+
chat: List[ChatMessages]
171+
172+
173+
@router.post("/v1/chat/classification")
174+
async def chat_classification_response(body: ChatClassificationRequestBody):
175+
chat = body.model_dump().get("chat",[])
176+
177+
print("Serving request for chat classification...")
178+
print(f"Chat: {chat}")
179+
score = Model().generate.remote(chat)
180+
181+
is_unsafe = score > 0.5
182+
183+
return {
184+
"class": "unsafe" if is_unsafe else "safe",
185+
"score": score,
186+
}
187+
188+
189+
app.include_router(router)
190+
return app
191+
192+
193+
# @app.local_entrypoint()
194+
# def main():
195+
# model = Model()
196+
# model.generate.remote()
File renamed without changes.

0 commit comments

Comments
 (0)