Merge pull request #282 from MLSysOps/hz/llama-1b
[MRG] ask users to provide absolute path and test deepseek
huangyz0918 authored Feb 3, 2025
2 parents 766e505 + 9dd859d commit 9f84f71
Showing 3 changed files with 40 additions and 4 deletions.
mle/agents/planner.py (2 changes: 1 addition & 1 deletion)
@@ -94,7 +94,7 @@ def plan(self, user_prompt):
            self.chat_history,
            response_format={"type": "json_object"}
        )

        self.chat_history.append({"role": "assistant", "content": text})

        try:
mle/model/ollama.py (40 changes: 38 additions & 2 deletions)
@@ -1,4 +1,5 @@
import importlib.util
+import re

from .common import Model
@@ -27,29 +28,64 @@ def __init__(self, model, host_url=None):
                "More information, please refer to: https://github.com/ollama/ollama-python"
            )

+    def _clean_think_tags(self, text):
+        """
+        Remove content between <think> tags and empty think tags from the text.
+        Args:
+            text (str): The input text to clean.
+        Returns:
+            str: The cleaned text with think tags and their content removed.
+        """
+        # Remove content between <think> tags
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        # Remove empty think tags
+        text = re.sub(r'<think></think>', '', text)
+        return text.strip()
+
+    def _process_message(self, message, **kwargs):
+        """
+        Process the message before sending to the model.
+        Args:
+            message: The message to process.
+            **kwargs: Additional arguments.
+        Returns:
+            dict: The processed message.
+        """
+        if isinstance(message, dict) and 'content' in message:
+            message['content'] = self._clean_think_tags(message['content'])
+        return message
+
    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.
        Args:
            chat_history: The context (chat history).
            **kwargs: Additional arguments for the model.
        Returns:
            str: The model's response.
        """
        # Check if 'response_format' exists in kwargs
        format = None
        if 'response_format' in kwargs and kwargs['response_format'].get('type') == 'json_object':
            format = 'json'

-        return self.client.chat(model=self.model, messages=chat_history, format=format)['message']['content']
+        response = self.client.chat(model=self.model, messages=chat_history, format=format)
+        return self._clean_think_tags(response['message']['content'])

    def stream(self, chat_history, **kwargs):
        """
        Stream the output from the LLM model.
        Args:
            chat_history: The context (chat history).
            **kwargs: Additional arguments for the model.
        Yields:
            str: Chunks of the model's response.
        """
        for chunk in self.client.chat(
            model=self.model,
            messages=chat_history,
            stream=True
        ):
-            yield chunk['message']['content']
+            yield self._clean_think_tags(chunk['message']['content'])
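
Reviewer note: reasoning models such as DeepSeek-R1, which this PR tests against via Ollama, emit their chain of thought inside <think> tags, and the new _clean_think_tags helper strips that block before the response is passed downstream. A minimal standalone sketch of the intended behavior; the sample response text below is made up for illustration:

import re

def clean_think_tags(text):
    # Drop <think>...</think> blocks, including multi-line ones (re.DOTALL)
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    # Drop any leftover empty tags
    text = re.sub(r'<think></think>', '', text)
    return text.strip()

raw = '<think>\nThe user wants JSON, so return only the plan fields.\n</think>\n{"tasks": []}'
print(clean_think_tags(raw))  # prints: {"tasks": []}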
mle/workflow/baseline.py (2 changes: 1 addition & 1 deletion)
@@ -46,7 +46,7 @@ def baseline(work_dir: str, model=None):
    dataset = ca.resume("dataset")
    if dataset is None:
        advisor = AdviseAgent(model, console)
-        dataset = ask_text("Please provide your dataset information (a public dataset name or a local file path)")
+        dataset = ask_text("Please provide your dataset information (a public dataset name or a local absolute filepath)")
        if not dataset:
            print_in_box("The dataset is empty. Aborted", console, title="Error", color="red")
            return
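The new wording nudges users toward absolute paths because a relative path is resolved against whatever working directory the workflow happens to run in. A hypothetical normalization helper, not part of this PR, that would make the same intent explicit in code:

import os

def normalize_dataset_path(dataset: str) -> str:
    # A relative local path becomes ambiguous once the workflow changes
    # directories, so pin it down before handing it to the agents.
    if os.path.exists(dataset) and not os.path.isabs(dataset):
        return os.path.abspath(dataset)
    return dataset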
