-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_perturbations.py
executable file
·56 lines (49 loc) · 1.95 KB
/
test_perturbations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from data.utils import (
sentence_permutation,
document_rotation,
text_infilling,
token_masking,
token_deletion,
)
from transformers import AutoTokenizer
text = """
Questa è una frase lunga che verrà usata per fare dei test.
Le parole sono scritte in modo casuale per simulare un testo reale.
In questo modo si possono testare le funzioni di perturbazione.
Le funzioni di perturbazione sono state implementate in modo da poter essere usate in modo indipendente.
"""
text = text.replace("\n", " ")
tokenizer = AutoTokenizer.from_pretrained("tokenizer_bart_it")
list_special_tokens = tokenizer.all_special_ids
# sentence_permutation - operates on text strings
perturbed_text = sentence_permutation(text)
print("\n\nText:", text)
print("PERTURBED sentence_permutation:", perturbed_text)
perturbed_text = document_rotation(text)
print("\n\nText:", text)
print("PERTURBED document_rotation:", perturbed_text)
tokenized_input_ids = tokenizer(text, return_tensors="pt")["input_ids"][0]
perturbed_tokenized_text = token_deletion(
tokenized_input_ids, list_special_tokens=list_special_tokens
)
perturbed_text = tokenizer.decode(perturbed_tokenized_text)
print("\n\nText:", text)
print("PERTURBED token_deletion:", perturbed_text)
tokenized_input_ids = tokenizer(text, return_tensors="pt")["input_ids"][0]
perturbed_tokenized_text = token_masking(
tokenized_input_ids,
mask_token_id=tokenizer.mask_token_id,
list_special_tokens=list_special_tokens,
)
perturbed_text = tokenizer.decode(perturbed_tokenized_text)
print("\n\nText:", text)
print("PERTURBED token_masking:", perturbed_text)
tokenized_input_ids = tokenizer(text, return_tensors="pt")["input_ids"][0]
perturbed_tokenized_text = token_infilling(
tokenized_input_ids,
mask_token_id=tokenizer.mask_token_id,
list_special_tokens=list_special_tokens,
)
perturbed_text = tokenizer.decode(perturbed_tokenized_text)
print("\n\nText:", text)
print("PERTURBED token_infilling:", perturbed_text)