Skip to content

Commit

Permalink
Advanced Method - BERT
Browse files Browse the repository at this point in the history
  • Loading branch information
CeliTop committed Nov 14, 2022
1 parent 224c6d5 commit 7259caa
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 56 deletions.
22 changes: 14 additions & 8 deletions BM_25_pyterrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,34 @@
# for the next parts of the project

import pyterrier as pt
import pandas as pd
import gzip
import shutil
import os as os
from pyterrier.measures import *
import pandas as pd

# Init pyterrier
pt.init()

# Get MS Marco passages used in TREC-2019
dataset = pt.get_dataset("trec-deep-learning-passages")

print(dataset)
# Get corpus
pathCorpus = dataset.get_corpus()
print(pathCorpus)

# Get the index stemmed (Porter stemmer)
path = dataset.get_index("terrier_stemmed")
index = pt.IndexFactory.of(path)

# Get the queries
queries = dataset.get_topics("test-2019")
queries = dataset.get_topics("test-2020")
print("query examples")
print(queries)
print()

# Get the qrels
qrels = dataset.get_qrels("test-2019")
qrels = dataset.get_qrels("test-2020")
print("qrel examples:")
print(qrels)
print()
Expand All @@ -38,23 +40,25 @@
bm25 = pt.BatchRetrieve(index, wmodel="BM25")

# Run BM-25 on the whole test dataset
pt.Experiment(
results = pt.Experiment(
[bm25],
queries,
qrels,
eval_metrics=["map", "recip_rank", "ndcg"],
eval_metrics=["map", "recip_rank", "ndcg", "recall"],
save_dir="./",
save_mode="overwrite",
dataframe=True,
)
print(results)

# Run BM-25 on a subset of queries
queries_uni = queries.loc[queries["qid"] == str(156493)]
queries_uni = queries.loc[queries["qid"] == str(1037496)]
print(queries_uni)
pt.Experiment(
[bm25],
queries_uni,
qrels,
eval_metrics=["map", "recip_rank", "ndcg"],
eval_metrics=["map", "recip_rank", "ndcg", "recall"],
perquery=True,
dataframe=True,
)
Expand All @@ -79,3 +83,5 @@
with open("retrieved.txt", "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove("BR(BM25).res.gz")

print("The ranking is generated !")
48 changes: 0 additions & 48 deletions evaluate.py

This file was deleted.

92 changes: 92 additions & 0 deletions reranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from sentence_transformers import CrossEncoder
import pyterrier as pt
import pandas as pd

PATH_TO_TOP_1000 = "retrieved.txt"
OUTPUT_PATH = "x.txt"

# Init
pd.set_option("display.max_rows", None)
pd.set_option("display.max_colwidth", 150)
pt.init()

dataset = pt.get_dataset("trec-deep-learning-passages")

# Get the previously retrieved top 1000 (by a baseline method)
retrieved = pd.read_csv(PATH_TO_TOP_1000, sep=" ")
retrieved.columns = ["qid", "Q0", "docID", "rank", "score", "system"]
print(retrieved.dtypes)
print(retrieved.head(n=5))
print()

# Get the queries
queries = dataset.get_topics("test-2020")
queries = queries.astype({"qid": "int64", "query": "string"})
print(queries.dtypes)
print("query examples")
print(queries.head(n=5))
print()


# Get the text corpus
pathCorpus = dataset.get_corpus()
print(pathCorpus[0])
print("Load CSV...")
corpus = pd.read_csv(pathCorpus[0], sep="\t")
corpus.columns = ["docno", "text"]
corpus = corpus.astype({"text": "string"})
print("corpus examples:")
print(corpus.head(n=5))
print()

model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2", max_length=512)


def getReranked(qid):
querytext = queries.loc[queries["qid"] == qid].iloc[0]["query"]
print("query text: ", querytext)
docIds = retrieved.loc[retrieved["qid"] == qid]["docID"]
print(docIds.head(n=5))
docs = corpus.loc[corpus["docno"].isin(docIds)]
print(docs.head(n=5))
print()

print("Predict...")
couples = [(querytext, docText) for docText in docs["text"]]
scores = model.predict(couples)
print(scores)

print("Sort...")
sorted_indices = [i[0] for i in sorted(enumerate(scores), key=lambda x: -x[1])]

top = docs.iloc[sorted_indices]
return top


s = ""
numberquery = 0
for qid in retrieved["qid"].unique():
print(numberquery, " query processed ...")
numberquery += 1
top = getReranked(qid)
i = 0
for index, row in top.iterrows():
s += (
str(qid)
+ " "
+ "Q0"
+ " "
+ str(row["docno"])
+ " "
+ str(i)
+ " "
+ str(1 / (i + 1))
+ " "
+ "BERT"
+ "\n"
)
i += 1


with open(OUTPUT_PATH, "w+") as file:
file.write(s)

0 comments on commit 7259caa

Please sign in to comment.