diff --git a/modeling/llm_post_training/data_process/README.md b/modeling/llm_post_training/data_process/README.md new file mode 100644 index 0000000..15e986b --- /dev/null +++ b/modeling/llm_post_training/data_process/README.md @@ -0,0 +1,341 @@ +# Knowledge Distillation System + +A generic, pluggable system for knowledge distillation using different LLM APIs and configurable prompts for various data generation tasks. + +## Features + +- **Generic LLM API Interface**: Easy switching between different LLM providers (OpenAI, Anthropic, etc.) +- **Batch Processing**: Efficient processing of large datasets with the OpenAI batch inference API +- **Configurable Prompts**: YAML-based prompt templates for different tasks +- **Error Handling**: Robust error handling with retry logic and progress tracking +- **Flexible Output**: Support for JSON and JSONL output formats (CSV is supported for input only) +- **Intermediate Results**: Save progress during long-running tasks + +## Architecture + +``` +data_process/ +├── llm_api_interface.py # Abstract base class for LLM providers +├── prompt_manager.py # YAML-based prompt management +├── knowledge_distillation.py # Main orchestrator +├── providers/ # LLM provider implementations +│ ├── __init__.py +│ └── openai_provider.py # OpenAI API implementation +├── prompt_configs/ # YAML prompt configurations +│ ├── instruction_generation.yaml +│ ├── qa_generation.yaml +│ └── summarization.yaml +├── example_config.json # Sample configuration +├── requirements.txt # Dependencies +└── README.md # This file +``` + +## Quick Start + +### 1. Install Dependencies + +```bash +pip install openai pyyaml jinja2 pandas +``` + +### 2. Set Up API Keys + +Create a `.env` file in the parent directory: + +```bash +OPENAI_API_KEY=your-openai-api-key-here +``` + +### 3. Create Input Data + +Create a JSON file with your input data: + +```json +[ + { + "text": "Your source text here", + "domain": "optional domain context", + "difficulty": "optional difficulty level" + } +] +``` + +### 4. 
Configure the System + +Create a configuration file (e.g., `my_config.json`): + +```json +{ + "distillation": { + "provider_type": "openai", + "provider_config": { + "api_key": "your-openai-api-key-here", + "model": "gpt-3.5-turbo", + "use_batch_api": true, + "batch_timeout": 3600 + }, + "task_name": "instruction_generation", + "prompt_name": "basic_instruction", + "input_file": "input_data.json", + "output_file": "output_data.json", + "batch_size": 10, + "max_concurrent": 5, + "retry_attempts": 3, + "delay_between_batches": 1.0, + "save_intermediate": true, + "intermediate_dir": "intermediate_results" + }, + "prompt_config": { + "config_dir": "prompt_configs" + } +} +``` + +### 5. Run Knowledge Distillation + +```bash +python knowledge_distillation.py --config my_config.json +``` + +## Available Tasks + +### Instruction Generation + +Generate instruction-following training data from raw text. + +**Available Prompts:** + +- `basic_instruction`: Generate basic instruction-following examples +- `creative_task`: Generate creative tasks from text content +- `analysis_prompt`: Generate analytical tasks from text + +**Input Format:** + +```json +{ + "text": "Source text to generate instructions from", + "domain": "Optional domain context", + "difficulty": "Optional difficulty level" +} +``` + +### Q&A Generation + +Generate question-answer pairs from text content. + +**Available Prompts:** + +- `factual_qa`: Generate factual question-answer pairs +- `analytical_qa`: Generate analytical question-answer pairs +- `application_qa`: Generate application-based question-answer pairs + +**Input Format:** + +```json +{ + "text": "Source text to generate Q&A from", + "context": "Optional additional context", + "question_type": "Type of questions to generate" +} +``` + +### Summarization + +Generate summaries and abstractive content from source text. 
+ +**Available Prompts:** + +- `extractive_summary`: Generate extractive summary by selecting key sentences +- `abstractive_summary`: Generate abstractive summary by paraphrasing +- `bullet_point_summary`: Generate structured bullet-point summary + +**Input Format:** + +```json +{ + "text": "Source text to summarize", + "length": "Desired summary length (short/medium/long)", + "focus": "Optional focus area for summary" +} +``` + +## Creating Custom Prompts + +### 1. Create YAML Configuration + +Create a new YAML file in `prompt_configs/`: + +```yaml +task_name: "my_custom_task" +description: "Description of your custom task" + +input_format: + field1: "string" + field2: "string" + +output_format: + output_field: "string" + +prompts: + my_prompt: + description: "Description of your prompt" + variables: ["field1", "field2"] + template: | + Your prompt template here using Jinja2 syntax. + + Field 1: {{ field1 }} + Field 2: {{ field2 }} + + Please provide your response in the following format: + Output: [your response here] + +metadata: + version: "1.0" + author: "Your Name" + created: "2024-01-01" +``` + +### 2. Use Custom Prompt + +Update your configuration to use the new task and prompt: + +```json +{ + "distillation": { + "task_name": "my_custom_task", + "prompt_name": "my_prompt" + // ... other configuration + } +} +``` + +## Adding New LLM Providers + +### 1. 
Create Provider Class + +Create a new file in `providers/` (e.g., `anthropic_provider.py`): + +```python +from llm_api_interface import LLMAPIProvider, LLMRequest, LLMResponse, BatchLLMRequest, BatchLLMResponse + +class AnthropicProvider(LLMAPIProvider): + def __init__(self, config): + super().__init__(config) + # Initialize Anthropic client + + async def generate_single(self, request: LLMRequest) -> LLMResponse: + ... # Implement single request generation + + async def generate_batch(self, batch_request: BatchLLMRequest) -> BatchLLMResponse: + ... # Implement batch request generation + + def validate_config(self) -> bool: + # Validate configuration + return True +``` + +### 2. Register Provider + +Update `llm_api_interface.py` to include your new provider: + +```python +def create_llm_provider(provider_type: str, config: Dict[str, Any]) -> LLMAPIProvider: + # ... existing code ... + elif provider_type == 'anthropic': + from providers.anthropic_provider import AnthropicProvider + return AnthropicProvider(config) +``` + +## Configuration Options + +### Distillation Configuration + +| Option | Type | Default | Description | +| ----------------------- | ------ | ---------------------- | -------------------------------------------- | +| `provider_type` | string | - | LLM provider type (openai, anthropic, etc.) 
| +| `provider_config` | dict | - | Provider-specific configuration | +| `task_name` | string | - | Name of the task to use | +| `prompt_name` | string | - | Name of the prompt to use | +| `input_file` | string | - | Path to input data file | +| `output_file` | string | - | Path to output data file | +| `batch_size` | int | 10 | Number of requests per batch | +| `max_concurrent` | int | 5 | Maximum concurrent requests | +| `retry_attempts` | int | 3 | Number of retry attempts for failed requests | +| `delay_between_batches` | float | 1.0 | Delay between batches (seconds) | +| `save_intermediate` | bool | true | Save intermediate results | +| `intermediate_dir` | string | "intermediate_results" | Directory for intermediate results | + +### OpenAI Provider Configuration + +| Option | Type | Default | Description | +| --------------- | ------ | --------------- | -------------------------------------- | +| `api_key` | string | - | OpenAI API key | +| `model` | string | "gpt-3.5-turbo" | Model to use | +| `base_url` | string | None | Custom base URL | +| `use_batch_api` | bool | true | Use batch inference API | +| `batch_timeout` | int | 3600 | Timeout for batch operations (seconds) | + +## Error Handling + +The system includes comprehensive error handling: + +- **Retry Logic**: Automatic retry for failed requests with exponential backoff +- **Progress Tracking**: Save intermediate results to resume from failures +- **Error Logging**: Detailed logging of errors and warnings +- **Graceful Degradation**: Fallback to single requests if batch API fails + +## Performance Optimization + +- **Batch Processing**: Use OpenAI batch inference API for cost efficiency +- **Concurrent Requests**: Configurable concurrency limits +- **Rate Limiting**: Built-in rate limiting to respect API limits +- **Intermediate Results**: Save progress to avoid reprocessing + +## Cost Estimation + +The OpenAI provider includes cost estimation: + +```python +# Get cost estimate for a list of requests 
+cost_estimate = provider.get_cost_estimate(requests) +print(f"Estimated cost: ${cost_estimate['total_cost']:.4f}") +``` + +## Examples + +See `example_usage.py` for comprehensive examples of: + +- Instruction generation +- Q&A generation +- Custom prompt creation +- Error handling +- Progress tracking + +## Troubleshooting + +### Common Issues + +1. **API Key Not Found**: Ensure your API key is set in the configuration +2. **Invalid Prompt Variables**: Check that all required variables are provided in input data +3. **Batch API Timeout**: Increase `batch_timeout` or use single requests +4. **Rate Limiting**: Reduce `max_concurrent` or increase `delay_between_batches` + +### Debug Mode + +Enable debug logging: + +```bash +python knowledge_distillation.py --config my_config.json --log-level DEBUG +``` + +## Contributing + +1. Fork the repository +2. Create a feature branch +3. Add your provider or prompt configuration +4. Add tests and documentation +5. Submit a pull request + +## License + +This project is licensed under the MIT License. diff --git a/modeling/llm_post_training/data_process/convert_silk_road_to_alpaca.py b/modeling/llm_post_training/data_process/convert_silk_road_to_alpaca.py deleted file mode 100644 index 0b2db01..0000000 --- a/modeling/llm_post_training/data_process/convert_silk_road_to_alpaca.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -""" -Convert silk-road/alpaca-data-gpt4-chinese dataset to standard alpaca format. - -Original format (6 columns): -- instruction_zh, input_zh, output_zh (Chinese) -- instruction, input, output (English) - -Output format (3 columns): -- instruction, input, output - -Each original row becomes 2 rows: one English, one Chinese. 
- -Usage: - python convert_silk_road_to_alpaca.py -""" - -import os -from pathlib import Path -import pandas as pd -from datasets import Dataset, load_dataset -from huggingface_hub import HfApi, login -from dotenv import load_dotenv - - -def load_hf_token() -> str: - """Load Hugging Face token from environment variables.""" - env_path = Path(__file__).parent.parent / ".env" - if env_path.exists(): - load_dotenv(env_path) - - hf_token = os.getenv("HF_TOKEN") - if not hf_token: - raise ValueError("HF_TOKEN not found. Please set it in .env file.") - return hf_token - - -def convert_to_alpaca_format(dataset: Dataset) -> Dataset: - """ - Convert 6-column dataset to 3-column alpaca format. - Each original row becomes 2 rows (English + Chinese). - """ - print("Converting dataset to alpaca format...") - - # Convert to pandas - df = pd.DataFrame(dataset) - - # Create English rows - english_df = pd.DataFrame( - {"instruction": df["instruction"], "input": df["input"], "output": df["output"]} - ) - - # Create Chinese rows - chinese_df = pd.DataFrame( - { - "instruction": df["instruction_zh"], - "input": df["input_zh"], - "output": df["output_zh"], - } - ) - - # Combine both datasets - alpaca_df = pd.concat([english_df, chinese_df], ignore_index=True) - - # Remove empty rows - alpaca_df = alpaca_df[ - (alpaca_df["instruction"].str.strip() != "") - & (alpaca_df["output"].str.strip() != "") - ].reset_index(drop=True) - - print(f"Converted {len(df)} rows to {len(alpaca_df)} rows") - return Dataset.from_pandas(alpaca_df) - - -def upload_to_huggingface(dataset: Dataset, repo_name: str, hf_token: str) -> str: - """Upload dataset to Hugging Face Hub.""" - print(f"Uploading to Hugging Face: {repo_name}") - - login(token=hf_token) - api = HfApi() - - # Get username - user_info = api.whoami(token=hf_token) - username = user_info["name"] - full_repo_id = f"{username}/{repo_name}" - - # Create and upload - api.create_repo( - repo_id=full_repo_id, repo_type="dataset", token=hf_token, 
exist_ok=True - ) - - dataset.push_to_hub(repo_id=full_repo_id, token=hf_token) - print(f"Uploaded to: https://huggingface.co/datasets/{full_repo_id}") - - return full_repo_id - - -def main(): - """Main conversion function.""" - try: - # Load token and dataset - hf_token = load_hf_token() - original_dataset = load_dataset("silk-road/alpaca-data-gpt4-chinese")["train"] - - print(f"Original dataset: {len(original_dataset)} rows") - - # Convert to alpaca format - alpaca_dataset = convert_to_alpaca_format(original_dataset) - - # Upload to Hugging Face - repo_id = upload_to_huggingface(alpaca_dataset, "alpaca-bilingual", hf_token) - - print(f"\n✅ Success!") - print(f"📊 Original: {len(original_dataset)} rows") - print(f"📊 Converted: {len(alpaca_dataset)} rows") - print(f"🔗 Dataset: https://huggingface.co/datasets/{repo_id}") - - except Exception as e: - print(f"❌ Error: {e}") - raise - - -if __name__ == "__main__": - main() diff --git a/modeling/llm_post_training/data_process/knowledge_distillation.py b/modeling/llm_post_training/data_process/knowledge_distillation.py new file mode 100644 index 0000000..32e0461 --- /dev/null +++ b/modeling/llm_post_training/data_process/knowledge_distillation.py @@ -0,0 +1,432 @@ +#!/usr/bin/env python3 +""" +Knowledge distillation orchestrator for LLM-based data generation. + +This script provides a generic interface for knowledge distillation using +different LLM APIs and configurable prompts for various data generation tasks. 
+""" + +import asyncio +import json +import logging +import argparse +from pathlib import Path +from typing import Dict, Any, List, Optional, Union +import pandas as pd +from dataclasses import dataclass, asdict +import time +from datetime import datetime + +from llm_api_interface import ( + LLMRequest, + LLMResponse, + BatchLLMRequest, + create_llm_provider, + LLMAPIError, +) +from prompt_manager import PromptManager + + +@dataclass +class DistillationConfig: + """Configuration for knowledge distillation process.""" + + provider_type: str + provider_config: Dict[str, Any] + task_name: str + prompt_name: str + input_file: str + output_file: str + batch_size: int = 10 + max_concurrent: int = 5 # NOTE(review): not read by the orchestrator code in knowledge_distillation.py - confirm it is consumed elsewhere + retry_attempts: int = 3 + delay_between_batches: float = 1.0 + save_intermediate: bool = True + intermediate_dir: str = "intermediate_results" + + +@dataclass +class DistillationResult: + """Result of knowledge distillation process.""" + + total_processed: int + successful: int + failed: int + total_cost: Optional[float] = None # not populated by _create_result in knowledge_distillation.py; stays None there + processing_time: Optional[float] = None + output_file: Optional[str] = None + errors: Optional[List[str]] = None # None when no error list is supplied + + +class KnowledgeDistillationOrchestrator: + """ + Orchestrates the knowledge distillation process using LLM APIs. + + Handles batch processing, error recovery, progress tracking, and + result management for large-scale data generation tasks. + """ + + def __init__(self, config: DistillationConfig, prompt_manager: PromptManager): + """ + Initialize the orchestrator. 
+ + Args: + config: Distillation configuration + prompt_manager: Prompt manager instance + """ + self.config = config + self.prompt_manager = prompt_manager + self.logger = logging.getLogger(self.__class__.__name__) + + # Initialize LLM provider + self.provider = create_llm_provider( + config.provider_type, config.provider_config + ) + + if not self.provider.validate_config(): + raise ValueError( + f"Invalid configuration for {config.provider_type} provider" + ) + + # Create output directories + self.output_path = Path(config.output_file) + self.output_path.parent.mkdir(parents=True, exist_ok=True) + + if config.save_intermediate: + self.intermediate_dir = Path(config.intermediate_dir) + self.intermediate_dir.mkdir(parents=True, exist_ok=True) + + # Statistics + self.stats = { + "total_processed": 0, + "successful": 0, + "failed": 0, + "errors": [], + "start_time": None, + "end_time": None, + } + + async def process_dataset(self) -> DistillationResult: + """ + Process the entire dataset through knowledge distillation. 
+ + Returns: + Distillation result with statistics + """ + self.logger.info( + f"Starting knowledge distillation for task: {self.config.task_name}" + ) + self.stats["start_time"] = time.time() + + try: + # Load input data + input_data = self._load_input_data() + self.logger.info(f"Loaded {len(input_data)} input records") + + # Process in batches + results = [] + total_batches = ( + len(input_data) + self.config.batch_size - 1 + ) // self.config.batch_size + + for batch_idx in range(0, len(input_data), self.config.batch_size): + batch_data = input_data[batch_idx : batch_idx + self.config.batch_size] + batch_num = batch_idx // self.config.batch_size + 1 + + self.logger.info( + f"Processing batch {batch_num}/{total_batches} ({len(batch_data)} records)" + ) + + # Process batch + batch_results = await self._process_batch(batch_data, batch_num) + results.extend(batch_results) + + # Save intermediate results + if self.config.save_intermediate: + self._save_intermediate_results(results, batch_num) + + # Delay between batches + if batch_num < total_batches and self.config.delay_between_batches > 0: + await asyncio.sleep(self.config.delay_between_batches) + + # Save final results + self._save_final_results(results) + + # Calculate final statistics + self.stats["end_time"] = time.time() + result = self._create_result() + + self.logger.info(f"Knowledge distillation completed: {result}") + return result + + except Exception as e: + self.logger.error(f"Knowledge distillation failed: {e}") + self.stats["errors"].append(str(e)) + raise + + def _load_input_data(self) -> List[Dict[str, Any]]: + """ + Load input data from file. 
+ + Returns: + List of input records + """ + input_path = Path(self.config.input_file) + + if not input_path.exists(): + raise FileNotFoundError(f"Input file not found: {input_path}") + + if input_path.suffix.lower() == ".json": + with open(input_path, "r", encoding="utf-8") as f: + return json.load(f) + elif input_path.suffix.lower() in [".csv", ".tsv"]: + df = pd.read_csv(input_path, sep="\t" if input_path.suffix.lower() == ".tsv" else ",") # .tsv is tab-delimited; the default sep="," would load it as a single column + return df.to_dict("records") + elif input_path.suffix.lower() == ".jsonl": + data = [] + with open(input_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + data.append(json.loads(line)) + return data + else: + raise ValueError(f"Unsupported file format: {input_path.suffix}") + + async def _process_batch( + self, batch_data: List[Dict[str, Any]], batch_num: int + ) -> List[Dict[str, Any]]: + """ + Process a single batch of data. + + Args: + batch_data: List of input records + batch_num: Batch number for logging + + Returns: + List of processed results + """ + # Prepare requests + requests = [] + for i, record in enumerate(batch_data): + try: + # Render prompt with record data + prompt = self.prompt_manager.render_prompt( + self.config.task_name, self.config.prompt_name, record + ) + + if not prompt: + self.logger.error( + f"Failed to render prompt for record {i} in batch {batch_num}" + ) + self.stats["failed"] += 1 + continue + + # Create LLM request + request = LLMRequest( + prompt=prompt, + metadata={ + "batch_num": batch_num, + "record_index": i, + "original_record": record, + }, + ) + requests.append(request) + + except Exception as e: + self.logger.error( + f"Error preparing request for record {i} in batch {batch_num}: {e}" + ) + self.stats["failed"] += 1 + continue + + if not requests: + self.logger.warning(f"No valid requests in batch {batch_num}") + return [] + + # Process requests with retry logic + responses = await self._process_requests_with_retry(requests, batch_num) + + # Process responses + results = [] + for i, (request, response) in 
enumerate(zip(requests, responses)): + try: + result = self._process_response(request, response, request.metadata["original_record"]) # i indexes requests, which is shorter than batch_data when a record failed to render, so batch_data[i] could pair the wrong input + results.append(result) + self.stats["successful"] += 1 + except Exception as e: + self.logger.error( + f"Error processing response {i} in batch {batch_num}: {e}" + ) + self.stats["failed"] += 1 + self.stats["errors"].append(str(e)) + + self.stats["total_processed"] += len(batch_data) + return results + + async def _process_requests_with_retry( + self, requests: List[LLMRequest], batch_num: int + ) -> List[LLMResponse]: + """ + Process requests with retry logic. + + Args: + requests: List of LLM requests + batch_num: Batch number for logging + + Returns: + List of LLM responses + """ + for attempt in range(self.config.retry_attempts): + try: + if len(requests) == 1: + # Single request + response = await self.provider.generate_single(requests[0]) + return [response] + else: + # Batch request + batch_request = BatchLLMRequest( + requests=requests, batch_id=f"batch_{batch_num}_{attempt}" + ) + batch_response = await self.provider.generate_batch(batch_request) + return batch_response.responses + + except LLMAPIError as e: + self.logger.warning( + f"Attempt {attempt + 1} failed for batch {batch_num}: {e}" + ) + if attempt == self.config.retry_attempts - 1: + raise + await asyncio.sleep(2**attempt) # Exponential backoff + + return [] + + def _process_response( + self, + request: LLMRequest, + response: LLMResponse, + original_record: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Process a single response into the final result format. 
+ + Args: + request: Original LLM request + response: LLM response + original_record: Original input record + + Returns: + Processed result record + """ + result = { + "original_input": original_record, + "generated_output": response.content, + "metadata": { + "prompt_used": self.config.prompt_name, + "task_name": self.config.task_name, + "provider": self.provider.get_provider_name(), + "request_metadata": request.metadata, + "response_metadata": response.metadata, + "usage": response.usage, + "timestamp": datetime.now().isoformat(), + }, + } + + # Add any errors + if response.metadata and "error" in response.metadata: + result["metadata"]["error"] = response.metadata["error"] + + return result + + def _save_intermediate_results( + self, results: List[Dict[str, Any]], batch_num: int + ) -> None: + """Save intermediate results to file.""" + if not self.config.save_intermediate: + return + + intermediate_file = self.intermediate_dir / f"batch_{batch_num:04d}.json" + with open(intermediate_file, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + self.logger.info(f"Saved intermediate results: {intermediate_file}") + + def _save_final_results(self, results: List[Dict[str, Any]]) -> None: + """Save final results to output file.""" + output_path = Path(self.config.output_file) + + if output_path.suffix.lower() == ".json": + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + elif output_path.suffix.lower() == ".jsonl": + with open(output_path, "w", encoding="utf-8") as f: + for result in results: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + else: + # Default to JSON + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + self.logger.info(f"Saved final results: {output_path}") + + def _create_result(self) -> DistillationResult: + """Create distillation result from statistics.""" + processing_time = None + if 
self.stats["start_time"] and self.stats["end_time"]: + processing_time = self.stats["end_time"] - self.stats["start_time"] + + return DistillationResult( + total_processed=self.stats["total_processed"], + successful=self.stats["successful"], + failed=self.stats["failed"], + processing_time=processing_time, + output_file=self.config.output_file, + errors=self.stats["errors"], + ) + + +def setup_logging(level: str = "INFO") -> None: + """Setup logging configuration.""" + logging.basicConfig( + level=getattr(logging, level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.StreamHandler(), + logging.FileHandler("knowledge_distillation.log"), + ], + ) + + +async def main(): + """Main function for command-line usage.""" + parser = argparse.ArgumentParser(description="Knowledge Distillation with LLM APIs") + parser.add_argument("--config", required=True, help="Path to configuration file") + parser.add_argument("--log-level", default="INFO", help="Logging level") + + args = parser.parse_args() + setup_logging(args.log_level) + + # Load configuration + with open(args.config, "r") as f: + config_data = json.load(f) + + # Create configuration objects + distillation_config = DistillationConfig(**config_data["distillation"]) + prompt_manager = PromptManager(config_data["prompt_config"]["config_dir"]) + + # Create orchestrator and run + orchestrator = KnowledgeDistillationOrchestrator( + distillation_config, prompt_manager + ) + result = await orchestrator.process_dataset() + + print(f"\n✅ Knowledge distillation completed!") + print(f"📊 Total processed: {result.total_processed}") + print(f"✅ Successful: {result.successful}") + print(f"❌ Failed: {result.failed}") + if result.processing_time: + print(f"⏱️ Processing time: {result.processing_time:.2f} seconds") + print(f"📁 Output file: {result.output_file}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/modeling/llm_post_training/data_process/llm_api_interface.py 
b/modeling/llm_post_training/data_process/llm_api_interface.py new file mode 100644 index 0000000..742707f --- /dev/null +++ b/modeling/llm_post_training/data_process/llm_api_interface.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +""" +Generic LLM API interface for knowledge distillation. + +This module provides an abstract base class for different LLM providers, +enabling easy switching between OpenAI, Anthropic, and other APIs. +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union +from dataclasses import dataclass +import asyncio +import logging + + +@dataclass +class LLMRequest: + """Represents a single LLM request.""" + + prompt: str + max_tokens: Optional[int] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class LLMResponse: + """Represents a single LLM response.""" + + content: str + usage: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None + request_metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class BatchLLMRequest: + """Represents a batch of LLM requests.""" + + requests: List[LLMRequest] + batch_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class BatchLLMResponse: + """Represents a batch of LLM responses.""" + + responses: List[LLMResponse] + batch_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +class LLMAPIProvider(ABC): + """ + Abstract base class for LLM API providers. + + This interface allows for easy switching between different LLM providers + while maintaining consistent batch processing capabilities. + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize the LLM provider with configuration. 
+ + Args: + config: Provider-specific configuration dictionary + """ + self.config = config + self.logger = logging.getLogger(self.__class__.__name__) + + @abstractmethod + async def generate_single(self, request: LLMRequest) -> LLMResponse: + """ + Generate a single response from the LLM. + + Args: + request: Single LLM request + + Returns: + Single LLM response + """ + pass + + @abstractmethod + async def generate_batch(self, batch_request: BatchLLMRequest) -> BatchLLMResponse: + """ + Generate responses for a batch of requests. + + Args: + batch_request: Batch of LLM requests + + Returns: + Batch of LLM responses + """ + pass + + @abstractmethod + def validate_config(self) -> bool: + """ + Validate the provider configuration. + + Returns: + True if configuration is valid, False otherwise + """ + pass + + def get_provider_name(self) -> str: + """Get the name of this provider.""" + return self.__class__.__name__ + + def log_usage(self, response: Union[LLMResponse, BatchLLMResponse]) -> None: + """ + Log usage statistics for monitoring and cost tracking. + + Args: + response: Single or batch response to log + """ + if isinstance(response, LLMResponse): + if response.usage: + self.logger.info(f"Usage: {response.usage}") + elif isinstance(response, BatchLLMResponse): + total_usage = {} + for resp in response.responses: + if resp.usage: + for key, value in resp.usage.items(): + total_usage[key] = total_usage.get(key, 0) + value + if total_usage: + self.logger.info(f"Batch usage: {total_usage}") + + +class LLMAPIError(Exception): + """Base exception for LLM API errors.""" + + pass + + +class LLMAPIProviderError(LLMAPIError): + """Exception raised when provider-specific errors occur.""" + + pass + + +class LLMAPIValidationError(LLMAPIError): + """Exception raised when request validation fails.""" + + pass + + +def create_llm_provider(provider_type: str, config: Dict[str, Any]) -> LLMAPIProvider: + """ + Factory function to create LLM providers. 
+ + Args: + provider_type: Type of provider ('openai', 'anthropic', etc.) + config: Provider-specific configuration + + Returns: + Initialized LLM provider instance + + Raises: + ValueError: If provider type is not supported + """ + provider_type = provider_type.lower() + + if provider_type == "openai": + from providers.openai_provider import OpenAIProvider + + return OpenAIProvider(config) + elif provider_type == "anthropic": + from providers.anthropic_provider import AnthropicProvider + + return AnthropicProvider(config) + else: + raise ValueError(f"Unsupported provider type: {provider_type}") + + +async def process_requests_concurrently( + provider: LLMAPIProvider, requests: List[LLMRequest], max_concurrent: int = 10 +) -> List[LLMResponse]: + """ + Process multiple requests concurrently with rate limiting. + + Args: + provider: LLM provider instance + requests: List of requests to process + max_concurrent: Maximum number of concurrent requests + + Returns: + List of responses in the same order as requests + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_single(request: LLMRequest) -> LLMResponse: + async with semaphore: + return await provider.generate_single(request) + + tasks = [process_single(req) for req in requests] + responses = await asyncio.gather(*tasks, return_exceptions=True) + + # Handle exceptions + processed_responses = [] + for i, response in enumerate(responses): + if isinstance(response, Exception): + provider.logger.error(f"Request {i} failed: {response}") + # Create error response + error_response = LLMResponse( + content="", + metadata={"error": str(response)}, + request_metadata=requests[i].metadata, + ) + processed_responses.append(error_response) + else: + processed_responses.append(response) + + return processed_responses diff --git a/modeling/llm_post_training/data_process/prompt_configs/instruction_generation.yaml b/modeling/llm_post_training/data_process/prompt_configs/instruction_generation.yaml new file 
mode 100644 index 0000000..367ba8c --- /dev/null +++ b/modeling/llm_post_training/data_process/prompt_configs/instruction_generation.yaml @@ -0,0 +1,77 @@ +task_name: "instruction_generation" +description: "Generate instruction-following training data from raw text" + +input_format: + text: "string" # Raw text to generate instructions from + domain: "string" # Optional domain context + difficulty: "string" # Optional difficulty level + +output_format: + instruction: "string" # Generated instruction + input: "string" # Optional input context + output: "string" # Expected output + +prompts: + basic_instruction: + description: "Generate basic instruction-following examples" + variables: ["text", "domain"] + template: | + Given the following text, generate a clear instruction that would help someone understand or work with this content. + + Text: {{ text }} + {% if domain %}Domain: {{ domain }}{% endif %} + + Please provide: + 1. A clear, actionable instruction + 2. Any necessary input context (if applicable) + 3. The expected output format + + Format your response as: + Instruction: [your instruction here] + Input: [input context if needed, otherwise "None"] + Output: [expected output format] + + creative_task: + description: "Generate creative tasks from text content" + variables: ["text", "difficulty"] + template: | + Create a creative task based on the following text. The task should be engaging and educational. 
+ + Text: {{ text }} + Difficulty Level: {{ difficulty or "intermediate" }} + + Generate a creative task that: + - Is appropriate for the difficulty level + - Encourages critical thinking + - Is engaging and fun + - Has clear success criteria + + Format your response as: + Instruction: [creative task description] + Input: [any required input or context] + Output: [expected output format and criteria] + + analysis_prompt: + description: "Generate analytical tasks from text" + variables: ["text", "domain"] + template: | + Create an analytical task that helps users understand the deeper meaning or structure of this text. + + Text: {{ text }} + {% if domain %}Domain: {{ domain }}{% endif %} + + Design an analysis task that requires: + - Critical thinking + - Pattern recognition + - Synthesis of information + - Clear reasoning + + Format your response as: + Instruction: [analysis task description] + Input: [text to analyze and any additional context] + Output: [expected analysis format and key points to cover] + +metadata: + version: "1.0" + author: "Knowledge Distillation System" + created: "2024-01-01" diff --git a/modeling/llm_post_training/data_process/prompt_configs/qa_generation.yaml b/modeling/llm_post_training/data_process/prompt_configs/qa_generation.yaml new file mode 100644 index 0000000..e7a533a --- /dev/null +++ b/modeling/llm_post_training/data_process/prompt_configs/qa_generation.yaml @@ -0,0 +1,88 @@ +task_name: "qa_generation" +description: "Generate question-answer pairs from text content" + +input_format: + text: "string" # Source text to generate Q&A from + context: "string" # Optional additional context + question_type: "string" # Type of questions to generate + +output_format: + question: "string" # Generated question + answer: "string" # Generated answer + difficulty: "string" # Difficulty level + +prompts: + factual_qa: + description: "Generate factual question-answer pairs" + variables: ["text", "context"] + template: | + Generate factual 
question-answer pairs from the following text. Focus on key facts, definitions, and important details. + + Text: {{ text }} + {% if context %}Additional Context: {{ context }}{% endif %} + + Create 3-5 factual questions that: + - Test understanding of key concepts + - Cover important facts and details + - Have clear, concise answers + - Vary in difficulty (easy to medium) + + Format your response as: + Question 1: [question] + Answer 1: [answer] + Difficulty: [easy/medium/hard] + + Question 2: [question] + Answer 2: [answer] + Difficulty: [easy/medium/hard] + + [Continue for all questions...] + + analytical_qa: + description: "Generate analytical question-answer pairs" + variables: ["text", "question_type"] + template: | + Generate analytical question-answer pairs that require deeper thinking and analysis. + + Text: {{ text }} + Question Type: {{ question_type or "analytical" }} + + Create 2-3 analytical questions that: + - Require critical thinking + - Ask for comparisons, analysis, or synthesis + - Have detailed, well-reasoned answers + - Test higher-order thinking skills + + Format your response as: + Question 1: [analytical question] + Answer 1: [detailed analytical answer] + Difficulty: [medium/hard] + + [Continue for all questions...] + + application_qa: + description: "Generate application-based question-answer pairs" + variables: ["text", "context"] + template: | + Generate application-based questions that ask how to apply concepts from the text. + + Text: {{ text }} + {% if context %}Context: {{ context }}{% endif %} + + Create 2-3 application questions that: + - Ask how to apply concepts in practice + - Require problem-solving skills + - Have step-by-step answers + - Connect theory to real-world scenarios + + Format your response as: + Question 1: [application question] + Answer 1: [step-by-step application answer] + Difficulty: [medium/hard] + + [Continue for all questions...] 
+ +metadata: + version: "1.0" + author: "Knowledge Distillation System" + created: "2024-01-01" diff --git a/modeling/llm_post_training/data_process/prompt_configs/summarization.yaml b/modeling/llm_post_training/data_process/prompt_configs/summarization.yaml new file mode 100644 index 0000000..74b8a59 --- /dev/null +++ b/modeling/llm_post_training/data_process/prompt_configs/summarization.yaml @@ -0,0 +1,108 @@ +task_name: "summarization" +description: "Generate summaries and abstractive content from source text" + +input_format: + text: "string" # Source text to summarize + length: "string" # Desired summary length (short/medium/long) + focus: "string" # Optional focus area for summary + +output_format: + summary: "string" # Generated summary + key_points: "list" # List of key points + length: "string" # Actual summary length + +prompts: + extractive_summary: + description: "Generate extractive summary by selecting key sentences" + variables: ["text", "length"] + template: | + Create an extractive summary by selecting the most important sentences from the text. + + Text: {{ text }} + Desired Length: {{ length or "medium" }} + + Guidelines: + - For "short": Select 2-3 most important sentences + - For "medium": Select 4-6 key sentences + - For "long": Select 7-10 important sentences + + Focus on sentences that: + - Contain main ideas and key concepts + - Provide essential information + - Are clear and well-written + - Cover different aspects of the topic + + Format your response as: + Summary: [selected sentences combined into coherent summary] + Key Points: + 1. [first key point] + 2. [second key point] + 3. [third key point] + [Continue as needed...] + Length: [short/medium/long] + + abstractive_summary: + description: "Generate abstractive summary by paraphrasing and synthesizing" + variables: ["text", "length", "focus"] + template: | + Create an abstractive summary by paraphrasing and synthesizing the main ideas. 
+ + Text: {{ text }} + Desired Length: {{ length or "medium" }} + {% if focus %}Focus Area: {{ focus }}{% endif %} + + Guidelines: + - Paraphrase the main ideas in your own words + - Synthesize information from multiple parts of the text + - Maintain the original meaning and tone + - Create a coherent, flowing summary + + {% if focus %}Pay special attention to: {{ focus }}{% endif %} + + Format your response as: + Summary: [coherent abstractive summary] + Key Points: + 1. [synthesized key point] + 2. [synthesized key point] + 3. [synthesized key point] + [Continue as needed...] + Length: [short/medium/long] + + bullet_point_summary: + description: "Generate structured bullet-point summary" + variables: ["text", "focus"] + template: | + Create a structured bullet-point summary of the text. + + Text: {{ text }} + {% if focus %}Focus Area: {{ focus }}{% endif %} + + Structure your summary with: + - Main topic/theme + - Key concepts and ideas + - Important details and examples + - Conclusions or implications + + {% if focus %}Emphasize aspects related to: {{ focus }}{% endif %} + + Format your response as: + Main Topic: [brief topic statement] + + Key Concepts: + • [concept 1] + • [concept 2] + • [concept 3] + + Important Details: + • [detail 1] + • [detail 2] + • [detail 3] + + Conclusions: + • [conclusion 1] + • [conclusion 2] + +metadata: + version: "1.0" + author: "Knowledge Distillation System" + created: "2024-01-01" diff --git a/modeling/llm_post_training/data_process/prompt_manager.py b/modeling/llm_post_training/data_process/prompt_manager.py new file mode 100644 index 0000000..68587c7 --- /dev/null +++ b/modeling/llm_post_training/data_process/prompt_manager.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +""" +Prompt management system for knowledge distillation. + +This module handles loading and managing prompts from YAML configuration files, +enabling easy customization of different data generation tasks. 
@dataclass
class PromptConfig:
    """Configuration for a specific prompt template."""

    name: str  # prompt identifier, unique within its task
    template: str  # Jinja2 template source text
    variables: List[str]  # names validate_variables() treats as required
    description: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class TaskConfig:
    """Configuration for a knowledge distillation task."""

    task_name: str
    description: str
    prompts: Dict[str, PromptConfig]  # prompt name -> prompt configuration
    input_format: Dict[str, Any]
    output_format: Dict[str, Any]
    metadata: Optional[Dict[str, Any]] = None


class PromptManager:
    """
    Manages prompt templates and configurations for knowledge distillation.

    Loads prompts from YAML files and provides template rendering capabilities
    using Jinja2 templating engine.
    """

    def __init__(self, config_dir: Union[str, Path]):
        """
        Initialize prompt manager.

        Args:
            config_dir: Directory containing YAML configuration files
        """
        self.config_dir = Path(config_dir)
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initialize Jinja2 environment
        # NOTE(review): render_prompt()/render_prompt_from_config() build bare
        # Template objects, so trim_blocks/lstrip_blocks configured here do
        # not affect rendered prompts. Switching to this environment would
        # change whitespace in the output - confirm intent before doing so.
        self.jinja_env = Environment(
            loader=FileSystemLoader(self.config_dir),
            trim_blocks=True,
            lstrip_blocks=True,
        )

        # Cache for loaded configurations
        self._config_cache: Dict[str, TaskConfig] = {}

        # Load all available configurations
        self._load_all_configs()

    def _load_all_configs(self) -> None:
        """
        Load all YAML configuration files from the config directory.

        A file that fails to load is logged and skipped so that one broken
        configuration does not prevent the others from loading.
        """
        if not self.config_dir.exists():
            self.logger.warning(f"Config directory {self.config_dir} does not exist")
            return

        # Sorted so that, when two files declare the same task_name, the one
        # that wins does not depend on filesystem enumeration order.
        yaml_files = sorted(self.config_dir.glob("*.yaml")) + sorted(
            self.config_dir.glob("*.yml")
        )

        for yaml_file in yaml_files:
            try:
                self._load_config_file(yaml_file)
            except Exception as e:
                self.logger.error(f"Failed to load config {yaml_file}: {e}")

    def _load_config_file(self, config_file: Path) -> None:
        """
        Load a single YAML configuration file.

        Args:
            config_file: Path to YAML configuration file

        Raises:
            ValueError: If the file is not a YAML mapping, or a prompt entry
                is not a mapping with a 'template' key
        """
        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f)

        # safe_load() returns None for an empty file and may return a list or
        # scalar; fail with a readable message instead of an AttributeError.
        if not isinstance(config_data, dict):
            raise ValueError(
                f"{config_file} must contain a YAML mapping at the top level"
            )

        task_name = config_data.get("task_name", config_file.stem)

        # Parse prompts ("prompts:" with no value loads as None)
        raw_prompts = config_data.get("prompts") or {}
        if not isinstance(raw_prompts, dict):
            raise ValueError(f"'prompts' in {config_file} must be a mapping")

        prompts = {}
        for prompt_name, prompt_data in raw_prompts.items():
            if not isinstance(prompt_data, dict) or "template" not in prompt_data:
                raise ValueError(
                    f"Prompt '{prompt_name}' in {config_file} must be a mapping "
                    f"with a 'template' key"
                )
            prompts[prompt_name] = PromptConfig(
                name=prompt_name,
                template=prompt_data["template"],
                variables=prompt_data.get("variables", []),
                description=prompt_data.get("description"),
                metadata=prompt_data.get("metadata", {}),
            )

        # Create task configuration
        task_config = TaskConfig(
            task_name=task_name,
            description=config_data.get("description", ""),
            prompts=prompts,
            input_format=config_data.get("input_format", {}),
            output_format=config_data.get("output_format", {}),
            metadata=config_data.get("metadata", {}),
        )

        if task_name in self._config_cache:
            self.logger.warning(
                f"Task configuration {task_name} is being replaced by {config_file}"
            )
        self._config_cache[task_name] = task_config
        self.logger.info(f"Loaded task configuration: {task_name}")

    def get_available_tasks(self) -> List[str]:
        """
        Get list of available task names.

        Returns:
            List of task names
        """
        return list(self._config_cache.keys())

    def get_task_config(self, task_name: str) -> Optional[TaskConfig]:
        """
        Get configuration for a specific task.

        Args:
            task_name: Name of the task

        Returns:
            Task configuration or None if not found
        """
        return self._config_cache.get(task_name)

    def get_prompt_config(
        self, task_name: str, prompt_name: str
    ) -> Optional[PromptConfig]:
        """
        Get prompt configuration for a specific task and prompt.

        Args:
            task_name: Name of the task
            prompt_name: Name of the prompt

        Returns:
            Prompt configuration or None if not found
        """
        task_config = self.get_task_config(task_name)
        if task_config:
            return task_config.prompts.get(prompt_name)
        return None

    def _render(
        self, prompt_config: PromptConfig, variables: Dict[str, Any], label: str
    ) -> Optional[str]:
        """Render prompt_config's template; log under `label` and return None on failure."""
        try:
            template = Template(prompt_config.template)
            return template.render(**variables)
        except Exception as e:
            self.logger.error(f"Failed to render prompt {label}: {e}")
            return None

    def render_prompt(
        self, task_name: str, prompt_name: str, variables: Dict[str, Any]
    ) -> Optional[str]:
        """
        Render a prompt template with given variables.

        Args:
            task_name: Name of the task
            prompt_name: Name of the prompt
            variables: Variables to substitute in the template

        Returns:
            Rendered prompt string or None if not found or rendering fails
        """
        prompt_config = self.get_prompt_config(task_name, prompt_name)
        if not prompt_config:
            self.logger.error(f"Prompt not found: {task_name}.{prompt_name}")
            return None

        return self._render(prompt_config, variables, f"{task_name}.{prompt_name}")

    def render_prompt_from_config(
        self, prompt_config: PromptConfig, variables: Dict[str, Any]
    ) -> Optional[str]:
        """
        Render a prompt template from a PromptConfig object.

        Args:
            prompt_config: Prompt configuration object
            variables: Variables to substitute in the template

        Returns:
            Rendered prompt string or None if rendering fails
        """
        return self._render(prompt_config, variables, prompt_config.name)

    def validate_variables(
        self, task_name: str, prompt_name: str, variables: Dict[str, Any]
    ) -> bool:
        """
        Validate that all required variables are provided.

        Every name listed in the prompt's ``variables`` is treated as
        required; only the presence of the key is checked, not its value.

        Args:
            task_name: Name of the task
            prompt_name: Name of the prompt
            variables: Variables to validate

        Returns:
            True if all required variables are provided, False if any is
            missing or the prompt does not exist
        """
        prompt_config = self.get_prompt_config(task_name, prompt_name)
        if not prompt_config:
            return False

        missing_vars = set(prompt_config.variables) - set(variables.keys())
        if missing_vars:
            self.logger.error(
                f"Missing required variables for {task_name}.{prompt_name}: {missing_vars}"
            )
            return False

        return True

    def get_task_info(self, task_name: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a task.

        Args:
            task_name: Name of the task

        Returns:
            Dictionary with task information, or None if the task is unknown
        """
        task_config = self.get_task_config(task_name)
        if not task_config:
            return None

        return {
            "task_name": task_config.task_name,
            "description": task_config.description,
            "available_prompts": list(task_config.prompts.keys()),
            "input_format": task_config.input_format,
            "output_format": task_config.output_format,
            "metadata": task_config.metadata,
        }

    def reload_configs(self) -> None:
        """
        Reload all configuration files.

        The cache is cleared first, so prompts added with
        add_custom_prompt() are discarded as well.
        """
        self._config_cache.clear()
        self._load_all_configs()
        self.logger.info("Reloaded all prompt configurations")

    def add_custom_prompt(
        self,
        task_name: str,
        prompt_name: str,
        template: str,
        variables: List[str],
        description: Optional[str] = None,
    ) -> None:
        """
        Add a custom prompt configuration at runtime.

        If the task does not exist yet, an empty task described as
        "Custom task" is created to hold the prompt. An existing prompt with
        the same name is replaced.

        Args:
            task_name: Name of the task
            prompt_name: Name of the prompt
            template: Prompt template string
            variables: List of required variables
            description: Optional description
        """
        prompt_config = PromptConfig(
            name=prompt_name,
            template=template,
            # Copy so later mutation of the caller's list cannot silently
            # change which variables this prompt requires.
            variables=list(variables),
            description=description,
        )

        if task_name not in self._config_cache:
            # Create new task configuration
            self._config_cache[task_name] = TaskConfig(
                task_name=task_name,
                description="Custom task",
                prompts={},
                input_format={},
                output_format={},
            )

        self._config_cache[task_name].prompts[prompt_name] = prompt_config
        self.logger.info(f"Added custom prompt: {task_name}.{prompt_name}")
class OpenAIProvider(LLMAPIProvider):
    """
    OpenAI API provider implementation.

    Supports both single requests and batch inference API for efficient
    processing of large datasets.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize OpenAI provider.

        Args:
            config: Configuration dictionary with keys:
                - api_key: OpenAI API key
                - model: Model name (e.g., 'gpt-4', 'gpt-3.5-turbo')
                - base_url: Optional custom base URL
                - use_batch_api: Whether to use batch inference API
                - batch_timeout: Timeout for batch operations (seconds);
                  None disables the timeout
        """
        super().__init__(config)

        self.api_key = config.get("api_key")
        self.model = config.get("model", "gpt-3.5-turbo")
        self.base_url = config.get("base_url")
        self.use_batch_api = config.get("use_batch_api", True)
        self.batch_timeout = config.get("batch_timeout", 3600)  # 1 hour

        # Initialize OpenAI client
        client_kwargs = {"api_key": self.api_key}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url

        self.client = AsyncOpenAI(**client_kwargs)

        # Batch API specific (synchronous client; its calls are run in a
        # worker thread by _generate_batch_api)
        self.batch_client = (
            openai.OpenAI(**client_kwargs) if self.use_batch_api else None
        )

    def validate_config(self) -> bool:
        """Validate OpenAI configuration: an API key and a model are required."""
        if not self.api_key:
            self.logger.error("OpenAI API key is required")
            return False

        if not self.model:
            self.logger.error("OpenAI model is required")
            return False

        return True

    def _build_chat_params(self, request: LLMRequest) -> Dict[str, Any]:
        """Build chat-completion parameters for request, omitting unset (None) options."""
        params = {
            "model": self.model,
            "messages": [{"role": "user", "content": request.prompt}],
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
        }
        return {k: v for k, v in params.items() if v is not None}

    async def generate_single(self, request: LLMRequest) -> LLMResponse:
        """
        Generate a single response using OpenAI API.

        Args:
            request: Single LLM request

        Returns:
            Single LLM response

        Raises:
            LLMAPIProviderError: If API call fails
        """
        try:
            response = await self.client.chat.completions.create(
                **self._build_chat_params(request)
            )

            # content is None for some responses; normalise to a string
            content = response.choices[0].message.content or ""

            usage = (
                {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                }
                if response.usage
                else None
            )

            return LLMResponse(
                content=content,
                usage=usage,
                metadata={"model": self.model},
                request_metadata=request.metadata,
            )

        except Exception as e:
            self.logger.error(f"OpenAI API error: {e}")
            # Chain the cause so the original traceback is not lost
            raise LLMAPIProviderError(f"OpenAI API call failed: {e}") from e

    async def generate_batch(self, batch_request: BatchLLMRequest) -> BatchLLMResponse:
        """
        Generate responses for a batch of requests.

        Uses OpenAI batch inference API for efficiency when available,
        otherwise falls back to concurrent single requests.

        Args:
            batch_request: Batch of LLM requests

        Returns:
            Batch of LLM responses
        """
        if self.use_batch_api and self.batch_client:
            return await self._generate_batch_api(batch_request)
        else:
            return await self._generate_batch_concurrent(batch_request)

    async def _generate_batch_api(
        self, batch_request: BatchLLMRequest
    ) -> BatchLLMResponse:
        """
        Generate batch using OpenAI batch inference API.

        The blocking SDK calls (upload, polling, download) run in a worker
        thread so the event loop stays responsive while the batch is
        pending. Any failure, including a timeout, falls back to concurrent
        single requests.

        Args:
            batch_request: Batch of LLM requests

        Returns:
            Batch of LLM responses, one per request and in request order
        """
        requests = batch_request.requests
        if not requests:
            # Nothing to upload; an empty batch file would only be rejected.
            return await self._generate_batch_concurrent(batch_request)

        custom_ids = [
            f"request_{i}_{batch_request.batch_id or ''}"
            for i in range(len(requests))
        ]

        try:
            loop = asyncio.get_running_loop()

            batch_id = await loop.run_in_executor(
                None, self._submit_batch, requests, custom_ids
            )
            self.logger.info(f"Created batch {batch_id}, waiting for completion...")

            batch = await loop.run_in_executor(
                None, self._wait_for_batch_completion, batch_id
            )
            results_by_id = await loop.run_in_executor(
                None, self._fetch_batch_results, batch
            )

            # Output lines are not guaranteed to follow input order, so each
            # request is matched to its result through custom_id.
            responses = [
                self._batch_result_to_response(
                    results_by_id.get(custom_id), custom_id, request, batch.id
                )
                for custom_id, request in zip(custom_ids, requests)
            ]

            return BatchLLMResponse(
                responses=responses,
                batch_id=batch.id,
                metadata={"provider": "openai_batch_api"},
            )

        except Exception as e:
            self.logger.error(f"OpenAI batch API error: {e}")
            # Fallback to concurrent requests
            self.logger.info("Falling back to concurrent single requests...")
            return await self._generate_batch_concurrent(batch_request)

    def _submit_batch(self, requests: List[LLMRequest], custom_ids: List[str]) -> str:
        """Upload the requests as a JSONL batch file, create the batch and return its ID."""
        import tempfile  # local import: only needed on the batch path

        lines = [
            json.dumps(
                {
                    "custom_id": custom_id,
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": self._build_chat_params(request),
                }
            )
            for custom_id, request in zip(custom_ids, requests)
        ]

        # A uniquely named temp file avoids collisions between batches
        # started within the same second.
        with tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            prefix="openai_batch_",
            suffix=".jsonl",
            delete=False,
        ) as tmp:
            tmp.write("\n".join(lines) + "\n")
            batch_file_path = Path(tmp.name)

        try:
            with open(batch_file_path, "rb") as f:
                batch_file = self.batch_client.files.create(file=f, purpose="batch")
        finally:
            # Remove the local copy whether or not the upload succeeded
            batch_file_path.unlink(missing_ok=True)

        batch = self.batch_client.batches.create(
            input_file_id=batch_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )
        return batch.id

    def _fetch_batch_results(self, batch: Any) -> Dict[str, Any]:
        """Download the batch's output and error files, keyed by custom_id."""
        results_by_id: Dict[str, Any] = {}
        # Requests that failed are reported in a separate error file
        for file_id in (batch.output_file_id, getattr(batch, "error_file_id", None)):
            if not file_id:
                continue
            payload = self.batch_client.files.content(file_id)
            # split("\n") rather than splitlines(): the latter also breaks on
            # Unicode line separators that may appear inside JSON strings.
            for line in payload.text.split("\n"):
                if line.strip():
                    result = json.loads(line)
                    results_by_id[result.get("custom_id")] = result
        return results_by_id

    def _batch_result_to_response(
        self,
        result: Optional[Dict[str, Any]],
        custom_id: str,
        request: LLMRequest,
        batch_id: str,
    ) -> LLMResponse:
        """Convert one batch result line (or None if absent) into an LLMResponse."""
        metadata: Dict[str, Any] = {"batch_id": batch_id, "custom_id": custom_id}

        if result is None:
            metadata["error"] = "No result returned for request"
            return LLMResponse(
                content="", metadata=metadata, request_metadata=request.metadata
            )

        # "response" is null on failed lines, so .get("response", {}) is not enough
        response = result.get("response") or {}
        body = response.get("body") or {}

        if response.get("status_code") == 200:
            return LLMResponse(
                content=body["choices"][0]["message"]["content"] or "",
                usage=body.get("usage"),
                metadata=metadata,
                # Same shape as generate_single(): the request's own metadata
                request_metadata=request.metadata,
            )

        metadata["error"] = result.get("error") or body.get("error") or "Unknown error"
        return LLMResponse(
            content="", metadata=metadata, request_metadata=request.metadata
        )

    def _wait_for_batch_completion(
        self, batch_id: str, check_interval: int = 30
    ) -> Any:
        """
        Wait for batch completion with polling.

        Blocks the calling thread; call it from a worker thread when inside
        an event loop.

        Args:
            batch_id: OpenAI batch ID
            check_interval: Seconds between status checks

        Returns:
            Completed batch object

        Raises:
            LLMAPIProviderError: If the batch fails, expires, is cancelled,
                or does not complete within batch_timeout seconds
        """
        deadline = (
            None
            if self.batch_timeout is None
            else time.monotonic() + self.batch_timeout
        )

        while True:
            batch = self.batch_client.batches.retrieve(batch_id)

            if batch.status == "completed":
                return batch
            elif batch.status == "failed":
                raise LLMAPIProviderError(f"Batch {batch_id} failed: {batch.errors}")
            elif batch.status == "expired":
                # Terminal state: without this branch the loop never ends
                raise LLMAPIProviderError(
                    f"Batch {batch_id} expired before completing"
                )
            elif batch.status in ["cancelled", "cancelling"]:
                raise LLMAPIProviderError(f"Batch {batch_id} was cancelled")

            sleep_for = check_interval
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    # Stop the remote job so the fallback path does not run
                    # the same requests a second time alongside it.
                    self._cancel_batch_quietly(batch_id)
                    raise LLMAPIProviderError(
                        f"Batch {batch_id} did not complete within "
                        f"{self.batch_timeout} seconds"
                    )
                sleep_for = min(check_interval, remaining)

            self.logger.info(f"Batch {batch_id} status: {batch.status}, waiting...")
            time.sleep(sleep_for)

    def _cancel_batch_quietly(self, batch_id: str) -> None:
        """Best-effort cancellation of a batch; failures are logged, not raised."""
        try:
            self.batch_client.batches.cancel(batch_id)
        except Exception as e:
            self.logger.warning(f"Could not cancel batch {batch_id}: {e}")

    async def _generate_batch_concurrent(
        self, batch_request: BatchLLMRequest
    ) -> BatchLLMResponse:
        """
        Generate batch using concurrent single requests.

        Args:
            batch_request: Batch of LLM requests

        Returns:
            Batch of LLM responses
        """
        # Use the utility function from the interface
        from llm_api_interface import process_requests_concurrently

        responses = await process_requests_concurrently(
            self, batch_request.requests, max_concurrent=10
        )

        return BatchLLMResponse(
            responses=responses,
            batch_id=batch_request.batch_id,
            metadata={"provider": "openai_concurrent"},
        )

    def get_cost_estimate(
        self,
        requests: List[LLMRequest],
        *,
        input_cost_per_1k: float = 0.0015,
        output_cost_per_1k: float = 0.002,
    ) -> Dict[str, Any]:
        """
        Estimate cost for a list of requests.

        Input tokens are approximated as one token per four prompt
        characters; output tokens assume each request uses its full
        max_tokens (100 when unset). The default rates are the
        GPT-3.5-turbo prices (as of 2024) and do not follow self.model;
        pass explicit rates for other models.

        Args:
            requests: List of requests to estimate
            input_cost_per_1k: USD price per 1,000 input tokens
            output_cost_per_1k: USD price per 1,000 output tokens

        Returns:
            Dictionary with token counts, cost estimates and the currency
        """
        # Rough token estimation (4 chars per token)
        total_input_tokens = sum(len(req.prompt) // 4 for req in requests)
        estimated_output_tokens = sum((req.max_tokens or 100) for req in requests)

        input_cost = (total_input_tokens / 1000) * input_cost_per_1k
        output_cost = (estimated_output_tokens / 1000) * output_cost_per_1k

        return {
            "input_tokens": total_input_tokens,
            "output_tokens": estimated_output_tokens,
            "input_cost": input_cost,
            "output_cost": output_cost,
            "total_cost": input_cost + output_cost,
            "currency": "USD",
        }