test_models.py
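# Evaluation script: loads a pre-trained word2vec model together with previously
# trained (pickled) language and translation models, splits the parallel corpus
# into training and test sets, and runs test_translation on the held-out test
# sentences, reporting the elapsed time at the end.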
import os
import re
import sys
import argparse
import random
import time
import pickle
from torch.cuda import device_count
from utils import convert_time, readfile, split_data, create_ngram, split_data_features_labels
from test_translations import test_translation
from vectorization import load_gensim_model, remove_words, get_vocabulary, generate_indices, gen_tri_vec_split, get_w2v_vectors, generate_translation_vectors
parser = argparse.ArgumentParser(description="Evaluate trained feed-forward neural network translation models.")
parser.add_argument("targetfile", type=str, default="UN-english", nargs='?', help="File used as target language.")
parser.add_argument("sourcefile", type=str, default="UN-french", nargs='?', help="File used as source language..")
parser.add_argument("googlemodelfile", type=str, default="GoogleNews-vectors-negative300.bin", nargs='?', help="Pre-trained word2vec 300-dimensional vectors.")
parser.add_argument("languagemodel", type=str, default="trained_language_model", nargs='?', help="Trained language model")
parser.add_argument("translationmodel", type=str, default="trained_translation_model", nargs='?', help="Trained translation model")
parser.add_argument("-T", "--top", metavar="T", dest="top_predicted", type=int, default=50, help="Top n predicted words (default 50).")
parser.add_argument("-P", "--processor", metavar="P", dest="processor", type=str, default="cpu", help="Select processing unit (default cpu).")
parser.add_argument("-S", "--test_size", metavar="T", dest="test_size", type=float, default=0.2, help="Size in percentage of test set (default 0.2).")
args = parser.parse_args()
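# Example invocation (illustrative; the positional file names below are simply
# the argparse defaults and may differ in practice):
#   python test_models.py UN-english UN-french GoogleNews-vectors-negative300.bin \
#       trained_language_model trained_translation_model -T 50 -P cuda:0 -S 0.2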
# Variables
test_size = args.test_size  # fraction of the data held out for testing (default 0.2)
p = args.processor
start = time.time()
if test_size < 0 or test_size > 1:
    exit("Error: Test size must be a number between 0 and 1, e.g. 0.2")
processor_valid = False
if p.lower() == "cpu":
    processor_valid = True
else:
    gpu_rg = r'cuda\:(\d{1,2})'
    m = re.search(gpu_rg, p, flags=re.I)
    if m:
        gpu_num = int(m.group(1))
        if 0 <= gpu_num < device_count():
            processor_valid = True
if not processor_valid:
    exit("Error: Processor type is invalid - only 'cpu' and 'cuda:N' are valid device types, and only up to cuda:%d is available." % (device_count() - 1))
print("Using {}.".format(args.processor))
print("Loading target language from {} and source language from {}.".format(args.targetfile, args.sourcefile))
target, source = readfile(args.targetfile, args.sourcefile)
print("Loading model from {}.".format(args.googlemodelfile))
pre_trained_model = load_gensim_model(args.googlemodelfile)
print("Removing words not found in the model.")
target_data, source_data = remove_words(target, source, pre_trained_model)
print("Splitting data into training and testing sets, {}/{}.".format((test_size*100), (100-(test_size*100))))
target_train, target_test, source_train, source_test = split_data(target_data, source_data, test_size)
print("Generating vocabulary for source text.")
source_vocab = get_vocabulary(source_data)
print("Generating vocabulary for target text.")
target_vocab = get_vocabulary(target_data)
print("Generating word indices.")
source_indices = generate_indices(source_vocab)
target_indices = generate_indices(target_vocab)
print("Fetching vectors from model.")
vectors = get_w2v_vectors(pre_trained_model, target_vocab)
print("Loading language model.")
with open(args.languagemodel, 'rb') as lmf:
    trigram_target_model = pickle.load(lmf)[1]
print("Loading translation model.")
with open(args.translationmodel, 'rb') as tmf:
    translation_model = pickle.load(tmf)[1]
test_translation(target_test, source_test, target_vocab, source_vocab, vectors, target_indices, source_indices, trigram_target_model, translation_model, p, args.top_predicted)
stop = time.time()
hour, minute, second = convert_time(start, stop)
print("Predicted {} sentences in {} hours, {} minutes and {} seconds.".format(len(source_test), hour, minute, second))