-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
678 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
dataset | ||
date_understanding.zip | ||
logs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
// Input relation: a single String column holding the question text.
type question(question: String) | ||
|
||
// The attribute below supplies a header, three prompts and five worked examples
// for the three relations declared after it. Every example pairs one question
// with three answer lists; presumably the i-th list answers the i-th prompt and
// fills the i-th declared relation -- TODO confirm against the plugin docs.
// NOTE(review): `cot` has one flag per prompt and is only true for the third
// (relationships) prompt; presumably this enables chain-of-thought -- confirm.
@gpt_extract_info( | ||
header="Please extract dates and their relationships from the provided question.", | ||
prompts=[ | ||
"What are the mentioned MM/DD/YYYY dates in JSON format?", | ||
"What is the goal in JSON format?", | ||
"What are the relationships of the dates in JSON format?" | ||
], | ||
examples=[ | ||
( | ||
["Yesterday is February 14, 2019. What is the date 1 month ago?"], | ||
[ | ||
[("yesterday", "02/14/2019")], | ||
[("1-month-ago")], | ||
[("yesterday", "today", "1 day"), ("1-month-ago", "today", "1 month")] | ||
] | ||
), | ||
( | ||
["Yesterday is February 14, 2019. What is the date 1 year later?"], | ||
[ | ||
[("yesterday", "02/14/2019")], | ||
[("1-year-later")], | ||
[("yesterday", "today", "1 day"), ("today", "1-year-later", "1 year")] | ||
] | ||
), | ||
( | ||
["The deadline is August 15, 2023, which is today. What is the date today?"], | ||
[ | ||
[("deadline", "08/15/2023")], | ||
[("today")], | ||
[("today", "deadline", "0 days")] | ||
] | ||
), | ||
( | ||
["Jenny began her current job on the Christmas Eve of 2016. What is the 5th anniversary?"], | ||
[ | ||
[("start-date", "12/24/2016")], | ||
[("anniversary")], | ||
[("start-date", "anniversary", "5 years")] | ||
] | ||
), | ||
( | ||
["Today is March 5th, 2010. Mark earns $1000 per day starting from now. When can Mark earn $10000?"], | ||
[ | ||
[("start-date", "03/05/2010")], | ||
[("earned-10000")], | ||
[("start-date", "earned-10000", "10 days")] | ||
] | ||
) | ||
], | ||
model="gpt-4", | ||
cot=[false,false,true], | ||
debug=true, | ||
) | ||
// In all three relations the `question` column is marked `bound`; the other
// columns are the extracted values (a labelled date, a goal label, and an
// earlier/later label pair with the Duration between them).
type extract_mentioned_date(bound question: String, label: String, date: DateTime), | ||
extract_goal (bound question: String, goal: String), | ||
extract_relationship (bound question: String, earlier_date: String, later_date: String, diff: Duration) | ||
|
||
// Extract information from the question | ||
// Each rule joins the input `question` fact with one extract_* relation and
// drops the question column, leaving only the extracted values.
rel mentioned_date(label, date) = question(question) and extract_mentioned_date(question, label, date) | ||
// relationship(l1, l2, diff): l1 is the earlier label and l2 the later one,
// following the (earlier_date, later_date, diff) column order of extract_relationship.
rel relationship(l1, l2, diff) = question(question) and extract_relationship(question, l1, l2, diff) | ||
rel goal(label) = question(question) and extract_goal(question, label) | ||
|
||
// Do computation | ||
// Base case: a date stated in the question is already known for its label.
rel derived_date(label, date) = mentioned_date(label, date) | ||
// `label` is the earlier side of the relationship, so its date is the other
// label's derived date minus the difference.
rel derived_date(label, date - diff) = relationship(label, other, diff) and derived_date(other, date) | ||
// `label` is the later side of the relationship, so its date is the other
// label's derived date plus the difference. Both rules are recursive on
// derived_date, so dates propagate along chains of relationships.
rel derived_date(label, date + diff) = relationship(other, label, diff) and derived_date(other, date) | ||
|
||
// Get the answer corresponding to the goal | ||
// The answer is every derived date whose label equals the extracted goal label.
rel answer(date) = goal(label) and derived_date(label, date) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# date understanding | ||
Download and extract the dataset from [this link](https://drive.google.com/file/d/1aXUpQna54Dx4lt4wqRIdzTRanMzjXWjy/view?usp=drive_link). | ||
|
||
| Model | Score | Accuracy | | ||
| --- | --- | --- | | ||
| Scallop (5-shot + 0-shot CoT) | 341 | 92.41% | | ||
| Scallop (5-shot) | 304 | 82.38% | | ||
| GPT-4 (5-shot CoT) | 350 | **94.85%** | | ||
| GPT-4 (0-shot CoT) | 322 | 87.26% | | ||
| GPT-4 (0-shot) | 262 | 71.00% | |
71 changes: 71 additions & 0 deletions
71
experiments/big-bench/date-understanding/test_gpt_baseline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import openai | ||
import os | ||
import json | ||
from tqdm import tqdm | ||
|
||
# Resolve the dataset relative to this script so the current working directory
# does not matter.
DATASET = os.path.abspath(os.path.join(__file__, "../dataset/task-corrected.json"))
# Load inside a context manager so the file handle is closed as soon as the
# JSON has been parsed instead of being left to the garbage collector.
with open(DATASET) as _dataset_file:
    TASK = json.load(_dataset_file)
# Prompt prefix for the zero-shot baseline: states the required MM/DD/YYYY
# output format and shows a single formatting example.
ZERO_SHOT_HEADER = """Answer only in the format of MM/DD/YYYY, where MM is 2 digits for the month, DD is 2 digits for the day, and YYYY is 4 digits for the year. Do not include anything else with your answer.
Examples:
April 20, 2021 is 04/20/2021 in MM/DD/YYYY form.
Answer the following:
"""
# Prompt prefix for a few-shot baseline: same format instruction followed by
# three solved questions. The third example previously misspelled "earn" as
# "earh"; it is corrected so the model is not shown a typo.
FEW_SHOT_HEADER = """Answer only in the format of MM/DD/YYYY, where MM is 2 digits for the month, DD is 2 digits for the day, and YYYY is 4 digits for the year. Do not include anything else with your answer.
Examples:
Question: The deadline is Jun 1, 2021, which is 2 days away from now. What is the date 10 days ago in MM/DD/YYYY?
05/20/2021
Question: Jenny began her current job on the Christmas Eve of 2016. What is the 5th anniversary in MM/DD/YYYY?
12/24/2021
Question: Today is March 5th, 2010. Mark earns $1000 per day starting from now. When can Mark earn $10000 in MM/DD/YYYY?
03/15/2010
Answer the following:
"""
|
||
|
||
def run_gpt(question, header=None):
    """Send one question to GPT-4 and return the text of the first choice.

    Args:
        question: The question text appended after the prompt header.
        header: Optional prompt prefix. When omitted, ``ZERO_SHOT_HEADER`` is
            used, which is the behaviour existing callers rely on; pass
            ``FEW_SHOT_HEADER`` to run the few-shot variant.

    Returns:
        The ``content`` of the first choice's message in the API response.
    """
    # Resolve the default lazily so the function definition does not depend on
    # the module constant at import time.
    if header is None:
        header = ZERO_SHOT_HEADER
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[{"role": "user", "content": header + question}],
        temperature=0,
    )
    return response["choices"][0]["message"]["content"]
|
||
|
||
def test_date_understanding():
    """Score the GPT baseline on every dataset example and save the results.

    Each example's ``input`` is sent through ``run_gpt``. The reply scores the
    value stored under it in the example's ``target_scores`` when it is one of
    the keys there, and 0 otherwise. Any exception raised while handling an
    example is recorded with a score of 0 and the run continues. After the
    loop, the accumulated results are written to ``data.json`` in the current
    working directory.
    """
    results = {"score": 0, "data": []}

    progress = tqdm(TASK["examples"])
    for example in progress:
        try:
            reply = run_gpt(example["input"])
            targets = example["target_scores"]
            points = int(targets[reply]) if reply in targets else 0
            results["score"] += points
            results["data"].append(
                {
                    "question": example["input"],
                    "answer": reply,
                    "score": points,
                }
            )
        except Exception as e:
            # Deliberately best-effort: one failed request must not abort the run.
            results["data"].append(
                {"question": example["input"], "exception": str(e), "score": 0}
            )

        # Show the running total next to the progress bar.
        progress.set_postfix({"score": results["score"]})

    # Serialise before opening the file so a failure here leaves no empty file.
    serialized = json.dumps(results, indent=4)
    with open("data.json", "w") as outfile:
        outfile.write(serialized)
|
||
|
||
test_date_understanding() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import openai | ||
import os | ||
import json | ||
import re | ||
from tqdm import tqdm | ||
|
||
|
||
FEW_SHOT = True | ||
SHOTS = """ | ||
Examples: | ||
Q: Yesterday is February 14, 2019. What is the date 1 month ago? | ||
A: Let's think step by step. Since yesterday is February 14, 2019, that means today is February 15, 2019. 1 month ago from today is January 15, 2019. Therefore, the answer is 01/15/2019. | ||
Q: Yesterday is February 14, 2019. What is the date 1 year later? | ||
A: Let's think step by step. Since yesterday is February 14, 2019, that means today is February 15, 2019. 1 year later from today is February 15, 2020. Therefore, the answer is 02/15/2020. | ||
Q: The deadline is August 15, 2023, which is today. What is the date today? | ||
A: Let's think step by step. Since the deadline is August 15, 2023, and since today is the deadline, that means today is August 15, 2023. Therefore, the answer is 08/15/2023. | ||
Q: Jenny began her current job on the Christmas Eve of 2016. What is the 5th anniversary? | ||
A: Let's think step by step. Since Jenny began her current job on Christmas Eve of 2016, that means she began on her job on December 24, 2016. The 5th anniversary is 5 years later from that date, so that date is December 24, 2021. Therefore, the answer is 12/24/2021. | ||
Q: Today is March 5th, 2010. Mark earns $1000 per day starting from now. When can Mark earn $10000? | ||
A: Let's think step by step. Since Mark earns $1000 per day, it takes $10000 / $1000 = 10 days for Mark to earn $10000. Because today is March 5th, 2010, Mark will earn $10000 after 10 days, which is March 15th, 2010. Therefore, the answer is 03/15/2010. | ||
Now here is the question: | ||
""" | ||
COT_PROMPT = "Let's think step by step." | ||
COT_EXTRACTION = "Therefore, in MM/DD/YYYY form, the answer is" | ||
REGEX = r"\d\d\/\d\d\/\d\d\d\d" | ||
|
||
# Resolve the dataset relative to this script so the current working directory
# does not matter.
DATASET = os.path.abspath(os.path.join(__file__, "../dataset/task-corrected.json"))
# Load inside a context manager so the file handle is closed as soon as the
# JSON has been parsed instead of being left to the garbage collector.
with open(DATASET) as _dataset_file:
    TASK = json.load(_dataset_file)
# Number of examples in the dataset; used to build the index range to evaluate.
N = len(TASK["examples"])
|
||
|
||
def run_gpt(question):
    """Send ``question`` as a single user message to GPT-4.

    Args:
        question: The complete prompt text; it is sent unchanged.

    Returns:
        The ``content`` of the first choice's message in the API response.
    """
    completion = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[{"role": "user", "content": question}],
        temperature=0,
    )
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]
|
||
|
||
def test_tracking(range):
    """Evaluate two-step chain-of-thought prompting and save the results.

    For every index, the example's input is wrapped in a ``Q:``/``A:`` prompt
    ending with ``COT_PROMPT`` and sent to ``run_gpt`` (prefixed with ``SHOTS``
    when ``FEW_SHOT`` is true). The reply is appended to the un-prefixed
    prompt together with ``COT_EXTRACTION`` and sent again. The first match of
    ``REGEX`` in the second reply is the prediction, or an empty string when
    nothing matches. The prediction scores the value stored under it in the
    example's ``target_scores`` when present there, else 0. Exceptions are
    recorded with a score of 0. Results are written to ``data.json``.

    Args:
        range: Iterable of indices into ``TASK["examples"]``. The name shadows
            the builtin inside this function; it is kept because it is part
            of the existing signature.
    """
    results = {"score": 0, "data": []}

    progress = tqdm(range)
    for idx in progress:
        example = TASK["examples"][idx]
        question = f"Q: {example['input']}\nA: {COT_PROMPT}"
        try:
            # Step 1: let the model reason freely.
            first_prompt = SHOTS + question if FEW_SHOT else question
            reasoning = run_gpt(first_prompt)
            # Step 2: feed the reasoning back and ask for the final date.
            followup = f"{question} {reasoning}\n{COT_EXTRACTION}"
            extraction = run_gpt(followup)
            matches = re.findall(REGEX, extraction)
            pred = matches[0] if matches else ""
            targets = example["target_scores"]
            points = int(targets[pred]) if pred in targets else 0
            results["score"] += points
            results["data"].append(
                {
                    "id": idx,
                    "question": followup,
                    "response": extraction,
                    "answer": pred,
                    "score": points,
                }
            )
        except Exception as e:
            # Deliberately best-effort: one failed request must not abort the run.
            results["data"].append(
                {"id": idx, "question": question, "exception": str(e), "score": 0}
            )

        # Show the running total next to the progress bar.
        progress.set_postfix({"score": results["score"]})

    # Serialise before opening the file so a failure here leaves no empty file.
    serialized = json.dumps(results, indent=4)
    with open("data.json", "w") as outfile:
        outfile.write(serialized)
|
||
|
||
test_tracking(range(N)) |
95 changes: 95 additions & 0 deletions
95
experiments/big-bench/date-understanding/test_gpt_scallop.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import json | ||
from tqdm import tqdm | ||
from datetime import datetime | ||
import scallopy | ||
import scallopy_ext | ||
import os | ||
from io import StringIO | ||
import sys | ||
|
||
# Resolve the dataset relative to this script so the current working directory
# does not matter.
DATASET = os.path.abspath(os.path.join(__file__, "../dataset/task-corrected.json"))
# Load inside a context manager so the file handle is closed as soon as the
# JSON has been parsed instead of being left to the garbage collector.
with open(DATASET) as _dataset_file:
    TASK = json.load(_dataset_file)
# Number of examples in the dataset; also used as the OpenAI request budget.
N = len(TASK["examples"])
# Scallop program that is imported into the context, located next to this script.
SCALLOP_FILE = os.path.abspath(os.path.join(__file__, "../date-compute.scl"))
|
||
|
||
class Args:
    """Plain settings object passed to ``scallopy_ext.config.configure``.

    It carries the attribute names read by the extension library in place of
    parsed command-line arguments.
    """

    def __init__(self):
        # OpenAI settings: model, temperature, and a request budget equal to
        # the number of dataset examples.
        self.openai_gpt_model = "gpt-4"
        self.openai_gpt_temperature = 0
        self.num_allowed_openai_request = N
        # Device settings: CUDA is disabled and no GPU is selected.
        self.cuda = False
        self.gpu = None
|
||
|
||
def test_date_understanding(range):
    """Run the Scallop date program on the selected examples and save results.

    A Scallop context is configured once and cloned per example. The
    example's input is added as a ``question`` fact, the program is run, and
    every tuple of the ``answer`` relation is reformatted from ``YYYY-MM-DD``
    (the part before the first space) to ``MM/DD/YYYY``. The first key of the
    example's ``target_scores`` that appears among those answers determines
    the score; if none appears the score is 0. Exceptions are recorded with a
    score of 0. Everything written to ``sys.stdout`` while an example is
    processed is captured and stored under ``logs``. Results are written to
    ``data.json``.

    Args:
        range: Iterable of indices into ``TASK["examples"]``. The name shadows
            the builtin inside this function; it is kept because it is part
            of the existing signature.
    """
    # Imported locally so this function does not rely on a module-level import.
    import contextlib

    out = {"score": 0, "data": [], "logs": []}

    # Configure scallop extension library
    scallopy_ext.config.configure(Args())

    # Setup scallop context
    ctx = scallopy.ScallopContext(provenance="unit")
    scallopy_ext.extlib.load_extlib(ctx)
    ctx.import_file(SCALLOP_FILE)

    # Iterate through all the datapoints
    pbar = tqdm(range)
    for i in pbar:
        example = TASK["examples"][i]
        buffer = StringIO()

        # Capture stdout only while this example runs. The previous code
        # assigned sys.stdout directly and never restored it, leaving stdout
        # pointed at a StringIO after the function returned or raised.
        with contextlib.redirect_stdout(buffer):
            try:
                temp_ctx = ctx.clone()
                temp_ctx.add_facts("question", [(example["input"],)])
                temp_ctx.run()
                res = [
                    datetime.strptime(x[0].split(" ")[0], "%Y-%m-%d").strftime("%m/%d/%Y")
                    for x in list(temp_ctx.relation("answer"))
                ]
                score = 0
                final_answer = ""
                for answer in example["target_scores"]:
                    if answer in res:
                        final_answer = answer
                        score = example["target_scores"][answer]
                        break
                out["score"] += score
                out["data"] += [
                    {
                        "id": i,
                        "question": example["input"],
                        "final_answer": final_answer,
                        "score": score,
                        "mentioned_date": list(temp_ctx.relation("mentioned_date")),
                        "relationship": list(temp_ctx.relation("relationship")),
                        "goal": list(temp_ctx.relation("goal")),
                        "answer": list(temp_ctx.relation("answer")),
                    }
                ]
            except Exception as e:
                # Deliberately best-effort: one failed example must not abort the run.
                out["data"] += [
                    {
                        "id": i,
                        "question": example["input"],
                        "exception": str(e),
                        "score": 0,
                    }
                ]

        out["logs"] += [
            {
                "id": i,
                # Turn literal escape sequences in the captured text (e.g. a
                # backslash followed by "n") into the characters they denote.
                "log": buffer.getvalue().encode("utf-8").decode("unicode_escape"),
            }
        ]

        pbar.set_postfix({"score": out["score"]})

    json_object = json.dumps(out.copy(), indent=2)
    with open("data.json", "w") as outfile:
        outfile.write(json_object)
|
||
|
||
test_date_understanding(range(N)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# object tracking | ||
Download and extract the dataset from [this link](https://drive.google.com/file/d/1tpKWhBHSCyA7HE2RoBmpSRprgjqUJa1P/view?usp=drive_link). | ||
|
||
| Model | Score | Accuracy | | ||
| --- | --- | --- | | ||
| Scallop (1-shot) | 150 | **100.00%** | | ||
| GPT-4 (3-shot CoT) | 149 | 99.33% | | ||
| GPT-4 (0-shot CoT) | 126 | 84.00% | | ||
| GPT-4 (0-shot) | 45 | 30.00% | |
Oops, something went wrong.