utils.py
import csv
import importlib
import logging
from enum import Enum

from arekit.contrib.source.synonyms.utils import iter_synonym_groups
from arelight.pipelines.demo.labels.scalers import ThreeLabelScaler
from arelight.utils import auto_import

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def create_sentence_parser(arg):
    if arg == "linesplit":
        return lambda text: [t.strip() for t in text.split('\n')]
    elif arg == "ru":
        # Using the ru_sent_tokenize library.
        ru_sent_tokenize = importlib.import_module("ru_sent_tokenize")
        return lambda text: ru_sent_tokenize.ru_sent_tokenize(text)
    elif arg == "nltk_en":
        # Using the nltk library.
        nltk = importlib.import_module("nltk")
        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
        return tokenizer.tokenize
    else:
        raise Exception("Sentence parser `{}` is not supported".format(arg))
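
# A usage sketch (illustrative, not part of the original file): the factory
# returns a callable mapping a raw text string to a list of sentence strings.
#
#   parse = create_sentence_parser("linesplit")
#   parse("first line\nsecond line")  # ["first line", "second line"]
#
# The "nltk_en" variant assumes the punkt tokenizer data has been fetched
# beforehand, e.g. via nltk.download("punkt").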

def create_translate_model(arg):
    if arg == "googletrans":
        # Auto-import is used so that we do not depend on whether the
        # library is actually installed.
        translate_value = auto_import("arelight.third_party.googletrans.translate_value")
        # Translates a list of strings and returns a list of strings.
        return lambda str_list, src, dest: [translate_value(s, dest=dest, src=src) for s in str_list]
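
# A usage sketch (illustrative): the returned callable translates a list of
# strings, where `src` and `dest` are language codes as accepted by
# googletrans (e.g. "ru", "en"):
#
#   translate = create_translate_model("googletrans")
#   translate(["пример текста"], src="ru", dest="en")
#
# Note that this factory implicitly returns None for any other `arg`.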

def create_labels_scaler(labels_count):
    assert isinstance(labels_count, int)
    if labels_count == 3:
        return ThreeLabelScaler()
    raise NotImplementedError("`{}` labels are not supported".format(labels_count))

def iter_group_values(filepath):
    # Iteration is delegated to an inner generator; otherwise this function
    # would itself be a generator, and `return None` would produce an empty
    # iterator rather than None.
    if filepath is None:
        return None

    def _iter_groups():
        with open(filepath, 'r') as file:
            for group in iter_synonym_groups(file):
                yield group

    return _iter_groups()
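
# A usage sketch (illustrative, the file name is hypothetical):
#
#   groups = iter_group_values("synonyms.txt")
#   if groups is not None:
#       for group in groups:
#           ...  # each `group` comes from AREkit's iter_synonym_groups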

class EnumConversionService(object):

    _data = None

    @classmethod
    def is_supported(cls, name):
        assert isinstance(cls._data, dict)
        return name in cls._data

    @classmethod
    def name_to_type(cls, name):
        assert isinstance(cls._data, dict)
        assert isinstance(name, str)
        return cls._data[name]

    @classmethod
    def iter_names(cls):
        assert isinstance(cls._data, dict)
        return iter(list(cls._data.keys()))

    @classmethod
    def type_to_name(cls, enum_type):
        assert isinstance(cls._data, dict)
        assert isinstance(enum_type, Enum)
        for item_name, item_type in cls._data.items():
            if item_type == enum_type:
                return item_name
        raise NotImplementedError("Type '{}' is not supported".format(enum_type))
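
# A usage sketch (illustrative): subclasses are expected to populate `_data`
# with a name-to-enum mapping, for example:
#
#   class Color(Enum):
#       RED = 1
#
#   class ColorService(EnumConversionService):
#       _data = {"red": Color.RED}
#
#   ColorService.name_to_type("red")       # Color.RED
#   ColorService.type_to_name(Color.RED)   # "red"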

def merge_dictionaries(dict_iter):
    merged_dict = {}
    for d in dict_iter:
        for key, value in d.items():
            if key in merged_dict:
                raise Exception("Key `{}` is already registered!".format(key))
            merged_dict[key] = value
    return merged_dict
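
# A usage sketch (illustrative): the input dictionaries must have disjoint keys.
#
#   merge_dictionaries([{"a": 1}, {"b": 2}])  # {"a": 1, "b": 2}
#   merge_dictionaries([{"a": 1}, {"a": 2}])  # raises Exception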

def iter_csv_lines(csv_filepath, column_name, delimiter=","):
    with open(csv_filepath, mode='r', encoding='utf-8-sig') as csv_file:
        csv_reader = csv.DictReader(csv_file, delimiter=delimiter)
        if column_name not in csv_reader.fieldnames:
            raise KeyError("Column `{}` was not found in `{}`.".format(column_name, csv_filepath))
        for row in csv_reader:
            yield row[column_name]
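
# A usage sketch (illustrative, the file name is hypothetical):
#
#   for text in iter_csv_lines("data.csv", column_name="text"):
#       print(text)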

def read_files(paths, delimiter, csv_column):
    if paths is None:
        return None
    file_contents = []
    for path in paths:
        if path.endswith(".csv"):
            # Handle as a column of the CSV file.
            file_contents.extend(iter_csv_lines(path, column_name=csv_column, delimiter=delimiter))
        else:
            # Handle as a plain text file.
            with open(path) as f:
                file_contents.append(f.read().rstrip())
    return file_contents
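
# A usage sketch (illustrative, the file names are hypothetical): plain text
# files are read whole, while ".csv" paths contribute one entry per row of
# the given column.
#
#   contents = read_files(["doc.txt", "data.csv"], delimiter=",", csv_column="text")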