ChatQnA/chatqna.py: 87 changes (69 additions & 18 deletions)
@@ -3,9 +3,15 @@
 
 import argparse
 import json
+import logging
 import os
 import re
 
+# Configure logging
+logger = logging.getLogger(__name__)
+log_level = logging.DEBUG if os.getenv("LOGFLAG", "").lower() == "true" else logging.INFO
+logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
 from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.mega.utils import handle_message
 from comps.cores.proto.api_protocol import (
@@ -62,6 +68,10 @@ def generate_rag_prompt(question, documents):
 
 
 def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
+    logger.debug(
+        f"Aligning inputs for service: {self.services[cur_node].name}, type: {self.services[cur_node].service_type}"
+    )
+
     if self.services[cur_node].service_type == ServiceType.EMBEDDING:
         inputs["inputs"] = inputs["text"]
         del inputs["text"]
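
Note (illustration, payload invented): for the embedding service, align_inputs renames the "text" field to "inputs" so the outgoing request matches the embedding server's expected schema. The two context lines above are equivalent to:

    inputs = {"text": "What is OPEA?"}      # hypothetical incoming payload
    inputs["inputs"] = inputs.pop("text")   # same rename the diff does via assignment + del
    assert inputs == {"inputs": "What is OPEA?"}
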
@@ -83,6 +93,9 @@ def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
         # next_inputs["repetition_penalty"] = inputs["repetition_penalty"]
         next_inputs["temperature"] = inputs["temperature"]
         inputs = next_inputs
+
+    # Log the aligned inputs (be careful with sensitive data)
+    logger.debug(f"Aligned inputs for {self.services[cur_node].name}: {type(inputs)}")
     return inputs
 
 
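Note (values invented): with LOGFLAG=true, the debug line added above renders through the PR's basicConfig format along these lines:

    # 2025-01-01 12:00:00,000 - __main__ - DEBUG - Aligned inputs for llm: <class 'dict'>

Logging type(inputs) rather than the payload itself keeps potentially sensitive user content out of the logs, which is what the "be careful with sensitive data" comment is getting at.
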
Expand Down Expand Up @@ -123,7 +136,9 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_di
elif input_variables == ["question"]:
prompt = prompt_template.format(question=data["initial_query"])
else:
print(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
logger.warning(
f"{prompt_template} not used, we only support 2 input variables ['question', 'context']"
)
prompt = ChatTemplate.generate_rag_prompt(data["initial_query"], docs)
else:
prompt = ChatTemplate.generate_rag_prompt(data["initial_query"], docs)
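
Note (template strings invented): the branch above accepts exactly two custom chat-template shapes, keyed off input_variables; anything else logs a warning and falls back to the built-in RAG prompt. A sketch of the accepted shapes:

    # input_variables == ["context", "question"]
    template_qc = "Use this context to answer:\n{context}\nQuestion: {question}"
    # input_variables == ["question"]
    template_q = "Question: {question}"

    prompt = template_qc.format(context="<retrieved docs>", question="What is OPEA?")
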
@@ -152,7 +167,7 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
             elif input_variables == ["question"]:
                 prompt = prompt_template.format(question=prompt)
             else:
-                print(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
+                logger.warning(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
                 prompt = ChatTemplate.generate_rag_prompt(prompt, reranked_docs)
         else:
             prompt = ChatTemplate.generate_rag_prompt(prompt, reranked_docs)
@@ -171,29 +186,65 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
 
 
 def align_generator(self, gen, **kwargs):
-    # OpenAI response format
-    # b'data:{"id":"","object":"text_completion","created":1725530204,"model":"meta-llama/Meta-Llama-3-8B-Instruct","system_fingerprint":"2.0.1-native","choices":[{"index":0,"delta":{"role":"assistant","content":"?"},"logprobs":null,"finish_reason":null}]}\n\n'
-    for line in gen:
-        line = line.decode("utf-8")
-        start = line.find("{")
-        end = line.rfind("}") + 1
+    """Aligns the generator output to match ChatQnA's format of sending bytes.
+
+    Handles different LLM output formats (TGI, OpenAI) and properly filters
+    empty or null content chunks to avoid UI display issues.
+    """
+    # OpenAI response format example:
+    # b'data:{"id":"","object":"text_completion","created":1725530204,"model":"meta-llama/Meta-Llama-3-8B-Instruct",
+    # "system_fingerprint":"2.0.1-native","choices":[{"index":0,"delta":{"role":"assistant","content":"?"},
+    # "logprobs":null,"finish_reason":null}]}\n\n'
 
-        json_str = line[start:end]
+    for line in gen:
         try:
             # sometimes yield empty chunk, do a fallback here
+            line = line.decode("utf-8")
+            start = line.find("{")
+            end = line.rfind("}") + 1
+
+            # Skip lines with invalid JSON structure
+            if start == -1 or end <= start:
+                logger.debug("Skipping line with invalid JSON structure")
+                continue
+
+            json_str = line[start:end]
+
+            # Parse the JSON data
             json_data = json.loads(json_str)
+
+            # Handle TGI format responses
             if "ops" in json_data and "op" in json_data["ops"][0]:
                 if "value" in json_data["ops"][0] and isinstance(json_data["ops"][0]["value"], str):
                     yield f"data: {repr(json_data['ops'][0]['value'].encode('utf-8'))}\n\n"
-                else:
-                    pass
-            elif (
-                json_data["choices"][0]["finish_reason"] != "eos_token"
-                and "content" in json_data["choices"][0]["delta"]
-            ):
-                yield f"data: {repr(json_data['choices'][0]['delta']['content'].encode('utf-8'))}\n\n"
+                # Empty value chunks are silently skipped
+
+            # Handle OpenAI format responses
+            elif "choices" in json_data and len(json_data["choices"]) > 0:
+                # Only yield content if it exists and is not null
+                if (
+                    "delta" in json_data["choices"][0]
+                    and "content" in json_data["choices"][0]["delta"]
+                    and json_data["choices"][0]["delta"]["content"] is not None
+                ):
+                    content = json_data["choices"][0]["delta"]["content"]
+                    yield f"data: {repr(content.encode('utf-8'))}\n\n"
+                # Null content chunks are silently skipped
+                elif (
+                    "delta" in json_data["choices"][0]
+                    and "content" in json_data["choices"][0]["delta"]
+                    and json_data["choices"][0]["delta"]["content"] is None
+                ):
+                    logger.debug("Skipping null content chunk")
+
+        except json.JSONDecodeError as e:
+            # Log the error with the problematic JSON string for better debugging
+            logger.error(f"JSON parsing error in align_generator: {e}\nProblematic JSON: {json_str[:200]}")
+            # Skip sending invalid JSON to avoid UI issues
+            continue
         except Exception as e:
-            yield f"data: {repr(json_str.encode('utf-8'))}\n\n"
+            logger.error(f"Unexpected error in align_generator: {e}, line snippet: {line[:100]}...")
+            # Skip sending to avoid UI issues
+            continue
     yield "data: [DONE]\n\n"

