
Commit

Rollback again
shing100 committed Aug 4, 2024
1 parent 18a8a16 commit afa820f
Showing 2 changed files with 97 additions and 196 deletions.
183 changes: 90 additions & 93 deletions transformers_openai_api/app.py
@@ -2,24 +2,27 @@
import re
import time
import uuid
import logging
import torch
from typing import Any, Callable, Dict, List, Mapping, Optional
from typing import Any, Callable, Mapping, Optional
from flask import Flask, make_response, request, abort, jsonify
from functools import wraps
from .models import CausalLM, Model, Seq2Seq
from .metrics import Metrics
from .utils import apply_chat_template, load_chat_template

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = Flask(__name__)
models: Dict[str, Model] = {}
metrics: Optional[Metrics] = None
models = {}
metrics: Optional[Metrics] = None

def check_token(f: Callable) -> Callable:
def extract_assistant_response(text):
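    # The decoded model output typically echoes the rendered chat prompt; keep
    # only the text that follows the first "assistant" line (case-insensitive).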
match = re.search(r'assistant\s*\n([\s\S]*)', text, re.IGNORECASE)
if match:
return match.group(1).strip()
return "Assistant's response not found."


def check_token(f: Callable):
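    """Verify the request's Authorization bearer token against the configured BEARER_TOKENS; unauthorized requests get a 401."""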
@wraps(f)
def decorator(*args, **kwargs):
bearer_tokens = app.config.get('BEARER_TOKENS')
@@ -31,63 +34,61 @@ def decorator(*args, **kwargs):
token = authorization[7:]
if token in bearer_tokens:
return f(*args, **kwargs)
logger.warning("Invalid token attempt")
return make_response(jsonify({'message': 'Invalid token'}), 401)
return make_response(jsonify({
'message': 'Invalid token'
}), 401)

return decorator


@app.route('/v1/chat/completions', methods=['POST'])
@check_token
def chat_completion():
try:
data = request.json
model_name = data.get('model')
messages = data.get('messages', [])

if not model_name or not messages:
logger.error("Missing required parameters: model or messages")
return jsonify({"error": "model and messages are required"}), 400

model: Model = models.get(model_name)
if not model:
logger.error(f"Model {model_name} not found")
return jsonify({"error": f"Model {model_name} not found"}), 404

if not isinstance(model, CausalLM):
logger.error(f"Model {model_name} does not support chat completions")
return jsonify({"error": f"Model {model_name} does not support chat completions"}), 400

# Apply chat template
prompt = apply_chat_template(model.chat_template, messages)

# Generate response
response = model.generate(prompt)

result = {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response['choices'][0]['text'],
},
"finish_reason": "stop"
}],
"usage": response['usage']
}

# Update metrics if enabled
global metrics
if metrics is not None:
metrics.update(result)
data = request.json
model_name = data.get('model')
messages = data.get('messages', [])

if not model_name or not messages:
return jsonify({"error": "model and messages are required"}), 400

model: Model = models.get(model_name)
if not model:
return jsonify({"error": f"Model {model_name} not found"}), 404

if not isinstance(model, CausalLM):
return jsonify({"error": f"Model {model_name} does not support chat completions"}), 400

# Apply chat template
prompt = apply_chat_template(model.chat_template, messages)

# Generate response
response = model.generate(prompt)

# Extract only the assistant's response
assistant_response = extract_assistant_response(response['text'])

result = {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model_name,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": assistant_response,
},
"finish_reason": "stop"
}],
"usage": response['usage']
}

# Update metrics if enabled
global metrics
if metrics is not None:
metrics.update(result)

return jsonify(result)
except Exception as e:
logger.exception("An error occurred during chat completion")
return jsonify({"error": str(e)}), 500
return jsonify(result)
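
For reference, a call against this endpoint follows the OpenAI chat-completions request shape. Below is a minimal client sketch; the host, port, token, and model name are placeholders, not values defined anywhere in this repository:

import requests

resp = requests.post(
    "http://localhost:5000/v1/chat/completions",  # assumed host/port for a local Flask run
    headers={"Authorization": "Bearer my-token"},  # must be listed in BEARER_TOKENS
    json={
        "model": "my-causal-model",  # a key under MODELS in the app config
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.json()["choices"][0]["message"]["content"])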


def convert_model_config(val: Optional[Mapping[str, Any]]) -> Mapping[str, Any]:
@@ -186,40 +187,36 @@ def metrics_():


def make_transformers_openai_api(config_path: str) -> Flask:
try:
app.config.from_file(config_path, load=json.load)

if app.config.get('METRICS', 1) != 0:
global metrics
metrics = Metrics()

for mapping, config in app.config['MODELS'].items():
if not config.get('ENABLED', True):
continue
model_config = convert_model_config(config.get('MODEL_CONFIG', {}))
model_device = config.get('MODEL_DEVICE', 'cuda')
tokenizer_config = convert_tokenizer_config(config.get('TOKENIZER_CONFIG', {}))
tokenizer_device = config.get('TOKENIZER_DEVICE', 'cuda')
generate_config = convert_generate_config(config.get('GENERATE_CONFIG', {}))
decode_config = convert_decode_config(config.get('DECODE_CONFIG', {}))

if config['TYPE'] == 'Seq2Seq':
models[mapping] = Seq2Seq(
config['NAME'], model_config, model_device, tokenizer_config,
tokenizer_device, generate_config, decode_config
)
elif config['TYPE'] == 'CausalLM':
chat_template_name = config.get('CHAT_TEMPLATE')
chat_template = load_chat_template(chat_template_name) if chat_template_name else ''
models[mapping] = CausalLM(
config['NAME'], model_config, model_device, tokenizer_config,
tokenizer_device, generate_config, decode_config, chat_template
            )
            else:
                raise ValueError(f'Unknown model type {config["TYPE"]}')
app.config.from_file(config_path, load=json.load)

if app.config.get('METRICS', 1) != 0:
global metrics
metrics = Metrics()

for mapping, config in app.config['MODELS'].items():
if config.get('ENABLED', True) == False:
continue
model_config = convert_model_config(config.get('MODEL_CONFIG'))
model_device = config.get('MODEL_DEVICE', 'cuda')
tokenizer_config = convert_tokenizer_config(
config.get('TOKENIZER_CONFIG'))
tokenizer_device = config.get('TOKENIZER_DEVICE', 'cuda')
generate_config = convert_generate_config(
config.get('GENERATE_CONFIG'))
decode_config = convert_decode_config(
config.get('DECODE_CONFIG'))
if config['TYPE'] == 'Seq2Seq':
models[mapping] = Seq2Seq(
config['NAME'], model_config, model_device, tokenizer_config, tokenizer_device, generate_config, decode_config)
elif config['TYPE'] == 'CausalLM':
chat_template_name = config.get('CHAT_TEMPLATE')
if chat_template_name:
chat_template = load_chat_template(chat_template_name)
            else:
                chat_template = ''
models[mapping] = CausalLM(
config['NAME'], model_config, model_device, tokenizer_config, tokenizer_device, generate_config, decode_config, chat_template)
else:
raise RuntimeError(f'Unknown model type {config["TYPE"]}')

logger.info(f"Loaded {len(models)} models successfully")
return app
except Exception as e:
logger.exception("Failed to initialize the application")
raise
return app
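
For orientation, a minimal config file for this loader (read with json.load above) might look like the following; every key mirrors a config.get(...) call in this function, but the model name, token, device, and generation settings are illustrative placeholders only:

{
    "BEARER_TOKENS": ["my-token"],
    "METRICS": 1,
    "MODELS": {
        "my-causal-model": {
            "TYPE": "CausalLM",
            "NAME": "some-org/some-hf-model",
            "ENABLED": true,
            "MODEL_DEVICE": "cuda",
            "TOKENIZER_DEVICE": "cuda",
            "MODEL_CONFIG": {},
            "TOKENIZER_CONFIG": {},
            "GENERATE_CONFIG": {"max_new_tokens": 256},
            "DECODE_CONFIG": {"skip_special_tokens": true},
            "CHAT_TEMPLATE": "path/to/template"
        }
    }
}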
110 changes: 7 additions & 103 deletions transformers_openai_api/models.py
@@ -1,10 +1,8 @@
import os
from abc import ABC, abstractmethod
from typing import Any, List, Mapping, Optional, Tuple, Dict
from typing import Any, List, Mapping, Optional
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
from .utils import apply_chat_template
import torch
import torch.distributed as dist
import logging

logging.basicConfig(level=logging.INFO)
@@ -135,110 +133,20 @@ def __init__(self, pretrained_model_name_or_path: str,
generate_config: Mapping[str, Any],
decode_config: Mapping[str, Any],
chat_template: str) -> None:

if torch.cuda.is_available():
self.n_gpu = torch.cuda.device_count()
self.model_device = f"cuda:{torch.cuda.current_device()}"
else:
self.n_gpu = 0
self.model_device = "cpu"

logger.info(f"Using device: {self.model_device}")

self.model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, **model_config)

if self.n_gpu > 1:
logger.info(f"Using {self.n_gpu} GPUs")
self.model = torch.nn.DataParallel(self.model)
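            # DataParallel wraps the model, so generation later unwraps it via
            # self.model.module (see the isinstance check in generate() below).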

self.model.to(self.model_device)

if model_device is not None:
self.model = self.model.to(model_device)
self.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, **tokenizer_config)
self.generate_config = generate_config
self.decode_config = decode_config
self.tokenizer_device = tokenizer_device
self.chat_template = chat_template

if self.tokenizer.pad_token is None:
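            # Many causal-LM tokenizers ship without a pad token; fall back to the
            # EOS token so padded inputs and generation have a valid pad_token_id.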
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.config.pad_token_id = self.model.config.eos_token_id

logger.info(f"Model device: {self.model_device}")
logger.info(f"Tokenizer pad_token: {self.tokenizer.pad_token}")
logger.info(f"Model pad_token_id: {self.model.config.pad_token_id}")

        # Initialize the KV cache
self.kv_cache = None

def _get_kv_cache(self, input_ids: torch.Tensor) -> Tuple[
torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
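        # If a cache exists, drop the already-processed prefix from input_ids and
        # return the cached past_key_values so generation resumes from the cache.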
if self.kv_cache is None:
return input_ids, None

cache_len = self.kv_cache[0].shape[2]
if cache_len > 0:
input_ids = input_ids[:, cache_len:]

return input_ids, self.kv_cache

def _update_kv_cache(self, past_key_values: Tuple[torch.Tensor, torch.Tensor]) -> None:
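        # The first call stores past_key_values as-is; later calls concatenate the
        # new key/value tensors onto the cache along the sequence dimension (dim=2).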
if self.kv_cache is None:
self.kv_cache = past_key_values
else:
self.kv_cache = tuple(torch.cat([c.to(self.model_device), p.to(self.model_device)], dim=2) for c, p in
zip(self.kv_cache, past_key_values))

@torch.no_grad()
def generate(self, input_text: str) -> Mapping[str, Any]:
tokenized_input = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

input_ids = tokenized_input.input_ids.to(self.model_device)
attention_mask = tokenized_input.attention_mask.to(self.model_device)

device_generate_config = {
k: v.to(self.model_device) if isinstance(v, torch.Tensor) else v
for k, v in self.generate_config.items()
}

logger.info(f"Input device: {input_ids.device}")
logger.info(f"Model device: {next(self.model.parameters()).device}")

        # Apply the KV cache
input_ids, past_key_values = self._get_kv_cache(input_ids)
if past_key_values is not None:
past_key_values = tuple(p.to(self.model_device) for p in past_key_values)

if isinstance(self.model, torch.nn.DataParallel):
output = self.model.module.generate(
input_ids=input_ids,
attention_mask=attention_mask,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
past_key_values=past_key_values,
**device_generate_config
)
else:
output = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
past_key_values=past_key_values,
**device_generate_config
)

        # Update the KV cache
if hasattr(self.model, 'module'):
model = self.model.module
else:
model = self.model

if hasattr(model, 'get_encoder'):
self._update_kv_cache(model.get_encoder().past_key_values)
else:
self._update_kv_cache(output.past_key_values)

input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.tokenizer_device)
output = self.model.generate(input_ids, **self.generate_config)
response = self.tokenizer.decode(output[0], **self.decode_config)
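        # For causal LMs, generate() returns the prompt tokens followed by the
        # completion, so the decoded string echoes the rendered chat prompt;
        # app.py strips it with extract_assistant_response() before responding.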

return {
@@ -252,8 +160,4 @@ def generate(self, input_text: str) -> Mapping[str, Any]:

def chat_completions(self, messages: List[Mapping[str, str]]) -> Mapping[str, Any]:
prompt = apply_chat_template(self.chat_template, messages)
return self.generate(prompt)

def reset_kv_cache(self) -> None:
"""KV 캐시를 리셋합니다."""
self.kv_cache = None
return self.generate(prompt)
