import re
import html
import itertools
from collections import Counter

import numpy as np
import torch
import torch.utils.data
from nltk.tokenize import TweetTokenizer


class LanguageModelDataset(torch.utils.data.Dataset):
    PAD_TOKEN = '<pad>'
    INIT_TOKEN = '<sos>'
    EOS_TOKEN = '<eos>'
    UNK_TOKEN = '<unk>'

    def __init__(self, sentences, max_len=20):
        super().__init__()
        self.max_len = max_len
        self.token2id = {}
        self.id2token = {}
        self.token_counts = Counter()
        self.special_tokens = [
            LanguageModelDataset.PAD_TOKEN, LanguageModelDataset.UNK_TOKEN,
            LanguageModelDataset.INIT_TOKEN, LanguageModelDataset.EOS_TOKEN,
        ]
        # register the special tokens first so they receive the lowest ids
        self._build_vocab([self.special_tokens])
        self._build_vocab(sentences)
        self._prune_vocab(min_count=2)
        # truncate to max_len - 1 tokens and append the end-of-sentence token
        sentences = [s[:max_len - 1] for s in sentences]
        sentences = [s + [LanguageModelDataset.EOS_TOKEN] for s in sentences]
        self.sentences = sentences
        self.nb_sentences = len(sentences)

    def _build_vocab(self, data):
        for token in itertools.chain.from_iterable(data):
            self.token_counts[token] += 1
            if token not in self.token2id:
                self.token2id[token] = len(self.token2id)
        self.id2token = {i: t for t, i in self.token2id.items()}

    def _prune_vocab(self, min_count=2):
        nb_tokens_before = len(self.token2id)
        tokens_to_delete = {t for t, c in self.token_counts.items() if c < min_count}
        # never prune the special tokens, even when they are rare
        tokens_to_delete -= set(self.special_tokens)
        for token in tokens_to_delete:
            self.token_counts.pop(token)
        # reassign contiguous ids; the special tokens keep the lowest ones
        self.token2id = {t: i for i, t in enumerate(self.token_counts.keys())}
        self.id2token = {i: t for t, i in self.token2id.items()}
        print('Vocab pruned: {} -> {}'.format(nb_tokens_before, len(self.token2id)))

    def __getitem__(self, index):
        sentence = self.sentences[index]
        # pad to max_len
        nb_pads = self.max_len - len(sentence)
        if nb_pads > 0:
            sentence = sentence + [LanguageModelDataset.PAD_TOKEN] * nb_pads
        # convert tokens to ids, mapping out-of-vocabulary tokens to <unk>
        unk_id = self.token2id[LanguageModelDataset.UNK_TOKEN]
        sentence = [self.token2id.get(t, unk_id) for t in sentence]
        return np.array(sentence)

    def __len__(self):
        return self.nb_sentences
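

# A minimal usage sketch (the toy sentences below are made up, not part of the
# original corpus): wrapping the dataset in a standard DataLoader yields
# LongTensor batches of shape (batch_size, max_len). With min_count=2 pruning,
# tokens seen only once ('cat', 'dog') are mapped to <unk> at lookup time:
#
#   sentences = [['the', 'cat', 'sat'], ['the', 'dog', 'sat']]
#   dataset = LanguageModelDataset(sentences, max_len=10)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=2)
#   batch = next(iter(loader))  # LongTensor of shape (2, 10)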
class TrumpSpeechesDataset(LanguageModelDataset):
    def __init__(self, filename, *args, **kwargs):
        sentences = self._load_file(filename)
        super().__init__(sentences, *args, **kwargs)

    def _load_file(self, filename):
        """
        Load a file of tweets, one tweet per line.

        :param filename: The path to the file
        :return: A list of token lists,
                 e.g. [['i', 'am', 'great', ...], ['it', 'is', 'going', 'to', ...], ...]
        """
        tweets = []
        toknzr = TweetTokenizer()
        with open(filename, 'r') as tweet_file:
            for line in tweet_file:
                line = line.strip().lower()
                # replace bare URLs with a placeholder token, then unescape HTML entities
                line = re.sub(r'((www\.[\S]+)|(https?://[\S]+))', ' URL ', line)
                line = html.unescape(line)
                processed_line = toknzr.tokenize(line)
                tweets.append(processed_line)
        return tweets
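

# A minimal smoke-test sketch. The path 'tweets.txt' is a placeholder, not a
# file shipped with this code; it is assumed to contain one tweet per line.
if __name__ == '__main__':
    dataset = TrumpSpeechesDataset('tweets.txt', max_len=20)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    print(batch.shape)  # expected: torch.Size([32, 20])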