-
Notifications
You must be signed in to change notification settings - Fork 32
/
convert_msmarco_doc_to_t5_format.py
61 lines (48 loc) · 2.03 KB
/
convert_msmarco_doc_to_t5_format.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import re
import spacy
from tqdm import tqdm
def load_corpus(path):
    """Load an MS MARCO document-ranking TSV corpus into memory.

    Each line is split into four tab-separated fields: document id, URL,
    title and body text. The URL is discarded.

    Args:
        path: Path to the corpus TSV file (read as UTF-8).

    Returns:
        A dict mapping each document id (str) to a ``(title, text)`` tuple,
        where ``text`` has surrounding whitespace (including the trailing
        newline) stripped.

    Raises:
        ValueError: If a line has fewer than four tab-separated fields.
    """
    print('Loading corpus...')
    corpus = {}
    # Context manager guarantees the handle is closed when loading finishes.
    with open(path, encoding='utf-8') as fin:
        for line in tqdm(fin):
            # maxsplit=3: any tab inside the body stays part of the body
            # instead of causing a "too many values to unpack" error.
            doc_id, _url, doc_title, doc_text = line.split('\t', 3)
            corpus[doc_id] = (doc_title, doc_text.strip())
    return corpus
# Command-line interface. The sliding-window sizes are measured in sentences.
parser = argparse.ArgumentParser(
    description='Create T5-formatted tsv file from MS MARCO Document Ranking '
                'dataset.')
parser.add_argument('--corpus_path', required=True,
                    help='Input corpus TSV with doc_id, url, title and text '
                         'per line.')
parser.add_argument('--output_passage_texts_path', required=True,
                    help='Output file receiving one passage text per line.')
parser.add_argument('--output_passage_doc_ids_path', required=True,
                    help='Output file receiving the doc id of each passage, '
                         'line-aligned with --output_passage_texts_path.')
# type=int is essential: without it, values passed on the command line arrive
# as strings and break range()/slicing arithmetic downstream.
parser.add_argument('--stride', type=int, default=5,
                    help='Number of sentences to advance between windows.')
parser.add_argument('--max_length', type=int, default=10,
                    help='Maximum number of sentences per passage.')
args = parser.parse_args()
# Non-positive sizes would either crash range() (stride 0) or yield
# title-only passages (max_length 0); reject them up front.
if args.stride <= 0:
    parser.error('--stride must be a positive integer.')
if args.max_length <= 0:
    parser.error('--max_length must be a positive integer.')
# Blank English pipeline with only the rule-based sentencizer; no trained
# model is needed just to split documents into sentences.
nlp = spacy.blank("en")
try:
    # spaCy >= 3 adds built-in components by registered name.
    nlp.add_pipe("sentencizer")
except ValueError:
    # spaCy 2.x rejects string names and needs a component object.
    nlp.add_pipe(nlp.create_pipe("sentencizer"))

corpus = load_corpus(path=args.corpus_path)

n_passages = 0
n_no_passages = 0
with open(args.output_passage_texts_path, 'w',
          encoding='utf-8') as fout_passage_texts, \
        open(args.output_passage_doc_ids_path, 'w',
             encoding='utf-8') as fout_passage_doc_ids:
    for doc_id, (doc_title, doc_text) in tqdm(corpus.items(), total=len(corpus)):
        # Only the first 10,000 characters of each document are segmented.
        doc = nlp(doc_text[:10000])
        # Span.text exists in spaCy 2 and 3 (Span.string was removed in v3).
        sentences = [sent.text.strip() for sent in doc.sents]
        if not sentences:
            n_no_passages += 1
        # Emit windows of up to max_length sentences, advancing by stride.
        # Stop once a window reaches the last sentence so the tail is not
        # re-emitted as shorter, redundant passages.
        for i in range(0, len(sentences), args.stride):
            segment = ' '.join(sentences[i:i + args.max_length])
            segment = doc_title + ' ' + segment
            # Remove starting #'s as T5 skips those lines by default.
            segment = re.sub(r'^#*', '', segment)
            # The two outputs are line-aligned: line k of the ids file is the
            # source document of line k of the texts file.
            fout_passage_doc_ids.write(f'{doc_id}\n')
            fout_passage_texts.write(f'{segment}\n')
            n_passages += 1
            if i + args.max_length >= len(sentences):
                break
print(f'Wrote {n_passages} passages from {len(corpus)} docs.')
print(f'There were {n_no_passages} docs without passages.')
print('Done!')