Commit: More e2e tests
nagkumar91 committed Apr 18, 2024
1 parent ab6a3cc commit ea5d5f3
Showing 3 changed files with 59 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/promptflow-evals/promptflow/evals/synthetic/qa.py
@@ -186,7 +186,7 @@ class QADataGenerator:
     _PARSING_ERR_UNEQUAL_Q_AFTER_MOD = "Parsing error: Unequal question count after modification"
     _PARSING_ERR_FIRST_LINE = "Parsing error: First line must be a question"
 
-    def __init__(self, model_config: Dict):
+    def __init__(self, *, model_config: Dict):
         """Initialize QADataGenerator using Azure OpenAI details."""
 
         api_key = "OPENAI_API_KEY"
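
The bare "*" added here makes model_config keyword-only, which is why every call site touched later in this commit switches to QADataGenerator(model_config=...). A minimal sketch of the effect, using a stub class and placeholder config values rather than the real implementation:

from typing import Dict

class StubQADataGenerator:
    # Stub standing in for promptflow.evals.synthetic.qa.QADataGenerator;
    # only the constructor signature matters for this illustration.
    def __init__(self, *, model_config: Dict):
        self.model_config = model_config

config = {"deployment": "gpt-4", "model": "gpt-4", "max_tokens": 2000}
StubQADataGenerator(model_config=config)  # OK: passed by keyword
try:
    StubQADataGenerator(config)  # positional use is now rejected
except TypeError as err:
    print(err)  # e.g. "__init__() takes 1 positional argument but 2 were given"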
54 changes: 54 additions & 0 deletions src/promptflow-evals/tests/simulator/e2etests/test_qa_generator.py
@@ -0,0 +1,54 @@
+import os
+
+import pytest
+
+from promptflow.evals.synthetic.qa import QADataGenerator, QAType
+
+
+@pytest.mark.usefixtures("model_config", "recording_injection")
+@pytest.mark.e2etest
+class TestQAGenerator:
+    def setup(self, model_config):
+        os.environ["AZURE_OPENAI_ENDPOINT"] = model_config.azure_endpoint
+        os.environ["AZURE_OPENAI_KEY"] = model_config.api_key
+        text = (
+            "Leonardo di ser Piero da Vinci (15 April 1452 - 2 May 1519) was an Italian "
+            "polymath of the High Renaissance who was active as a painter, draughtsman, "
+            "engineer, scientist, theorist, sculptor, and architect. While his fame "
+            "initially rested on his achievements as a painter, he has also become known "
+            "for his notebooks, in which he made drawings and notes on a variety of "
+            "subjects, including anatomy, astronomy, botany, cartography, painting, and "
+            "paleontology. Leonardo epitomized the Renaissance humanist ideal, and his "
+            "collective works comprise a contribution to later generations of artists "
+            "matched only by that of his younger contemporary Michelangelo."
+        )
+        return text
+
+    def test_qa_generator_basic_conversation(self, model_config):
+        model_name = "gpt-4"
+        text = self.setup(model_config)
+        model_config = dict(
+            deployment=model_name,
+            model=model_name,
+            max_tokens=2000,
+        )
+        qa_generator = QADataGenerator(model_config=model_config)
+        qa_type = QAType.CONVERSATION
+        result = qa_generator.generate(text=text, qa_type=qa_type, num_questions=5)
+        assert "question_answers" in result.keys()
+        assert len(result["question_answers"]) == 5
+
+    def test_qa_generator_basic_summary(self, model_config):
+        model_name = "gpt-4"
+        text = self.setup(model_config)
+        model_config = dict(
+            deployment=model_name,
+            model=model_name,
+            max_tokens=2000,
+        )
+        qa_generator = QADataGenerator(model_config=model_config)
+        qa_type = QAType.SUMMARY
+        result = qa_generator.generate(text=text, qa_type=qa_type)
+        assert "question_answers" in result.keys()
+        assert len(result["question_answers"]) == 1
+        assert result["question_answers"][0][0].startswith("Write a summary in 100 words")
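
These asserts pin down an output shape for generate: a dict whose "question_answers" value holds (question, answer) pairs, with the summary prompt fixed to start with "Write a summary in 100 words". A hypothetical illustration of that shape, inferred only from the asserts rather than taken from the library:

# Inferred from the asserts above; the real return value may carry more keys,
# and both strings here are placeholders.
result = {
    "question_answers": [
        ("Write a summary in 100 words for the text provided.",
         "Leonardo da Vinci was an Italian polymath of the High Renaissance..."),
    ],
}
question, answer = result["question_answers"][0]
assert question.startswith("Write a summary in 100 words")

Assuming the repository's conftest supplies the model_config and recording_injection fixtures, a command along the lines of "pytest -m e2etest src/promptflow-evals/tests/simulator/e2etests/test_qa_generator.py" should select these tests by their e2etest marker.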
8 changes: 4 additions & 4 deletions
@@ -48,7 +48,7 @@ def test_extract_qa_from_response(self):
             "Answer after space.\n\n",
         ]
         model_config = dict(api_base=API_BASE, api_key=API_KEY, deployment=DEPLOYMENT, model=MODEL)
-        qa_generator = QADataGenerator(model_config)
+        qa_generator = QADataGenerator(model_config=model_config)
         questions, answers = qa_generator._parse_qa_from_response(response_text=response_text)
         for i, question in enumerate(questions):
             assert expected_questions[i] == question, "Question not equal"
@@ -57,15 +57,15 @@ def test_extract_qa_from_response(self):
 
     def test_unsupported_num_questions_for_summary(self):
         model_config = dict(api_base=API_BASE, api_key=API_KEY, deployment=DEPLOYMENT, model=MODEL)
-        qa_generator = QADataGenerator(model_config)
+        qa_generator = QADataGenerator(model_config=model_config)
         with pytest.raises(ValueError) as excinfo:
             qa_generator.generate("", QAType.SUMMARY, 10)
         assert str(excinfo.value) == "num_questions unsupported for Summary QAType"
 
     @pytest.mark.parametrize("num_questions", [0, -1])
     def test_invalid_num_questions(self, num_questions):
         model_config = dict(api_base=API_BASE, api_key=API_KEY, deployment=DEPLOYMENT, model=MODEL)
-        qa_generator = QADataGenerator(model_config)
+        qa_generator = QADataGenerator(model_config=model_config)
         with pytest.raises(ValueError) as excinfo:
             qa_generator.generate("", QAType.SHORT_ANSWER, num_questions)
         assert str(excinfo.value) == "num_questions must be an integer greater than zero"
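
The two error messages asserted above imply the input validation inside generate. A plausible reconstruction under that reading, not the library's actual code (the QAType member values are assumed):

from enum import Enum
from typing import Optional

class QAType(Enum):
    # Stub enum covering only the members these tests use.
    SUMMARY = "summary"
    SHORT_ANSWER = "short_answer"

def validate_num_questions(qa_type: QAType, num_questions: Optional[int]) -> None:
    # Reconstructed from the asserted messages; the real checks may differ.
    if qa_type == QAType.SUMMARY:
        if num_questions is not None:
            raise ValueError("num_questions unsupported for Summary QAType")
    elif not isinstance(num_questions, int) or num_questions <= 0:
        raise ValueError("num_questions must be an integer greater than zero")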
@@ -89,7 +89,7 @@ def test_export_format(self, qa_type, structure):
         ]
 
         model_config = dict(api_base=API_BASE, api_key=API_KEY, deployment=DEPLOYMENT, model=MODEL)
-        qa_generator = QADataGenerator(model_config)
+        qa_generator = QADataGenerator(model_config=model_config)
         qas = list(zip(questions, answers))
         filepath = os.path.join(pathlib.Path(__file__).parent.parent.resolve(), "test_configs")
         output_file = os.path.join(filepath, f"test_{qa_type.value}_{structure.value}.jsonl")
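
test_export_format zips questions with answers and builds a per-(qa_type, structure) .jsonl path; the export call itself falls outside the lines shown in this hunk. For orientation, a JSONL file of QA pairs is one JSON object per line, as in this sketch (the field names are assumptions, not the library's actual export schema):

import json

qas = [("What is X?", "X is ..."), ("Why Y?", "Because ...")]
with open("example.jsonl", "w", encoding="utf-8") as f:
    for question, answer in qas:
        # One JSON object per line is the JSONL convention.
        f.write(json.dumps({"question": question, "answer": answer}) + "\n")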