gh-3243: support pickle & deepcopy #3245

Open · wants to merge 2 commits into master
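Background for the diff below: the old Span.__new__ and Relation.__new__ interned instances in sentence._known_spans, but pickle and deepcopy reconstruct objects by calling cls.__new__(cls) with no further arguments, which a __new__ that requires tokens cannot survive (gh-3243). A minimal, self-contained sketch of the failure mode, not flair code, with illustrative names:

    import copy

    class Interned:
        _instances: dict = {}

        def __new__(cls, key):
            # intern instances by key, as Span/Relation did before this PR
            if key not in cls._instances:
                cls._instances[key] = super().__new__(cls)
            return cls._instances[key]

        def __init__(self, key):
            self.key = key

    copy.deepcopy(Interned("a"))
    # TypeError: __new__() missing 1 required positional argument: 'key'
    # pickle.dumps succeeds, but pickle.loads fails the same way

The PR therefore removes both __new__ methods and moves the caching into Sentence.__getitem__, so spans and relations become plain objects that copy and pickle can reconstruct.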
74 changes: 29 additions & 45 deletions flair/data.py
@@ -629,25 +629,9 @@ def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
 class Span(_PartOfSentence):
     """This class represents one textual span consisting of Tokens."""
 
-    def __new__(self, tokens: List[Token]):
-        # check if the span already exists. If so, return it
-        unlabeled_identifier = self._make_unlabeled_identifier(tokens)
-        if unlabeled_identifier in tokens[0].sentence._known_spans:
-            span = tokens[0].sentence._known_spans[unlabeled_identifier]
-            return span
-
-        # else make a new span
-        else:
-            span = super().__new__(self)
-            span.initialized = False
-            tokens[0].sentence._known_spans[unlabeled_identifier] = span
-            return span
-
     def __init__(self, tokens: List[Token]) -> None:
-        if not self.initialized:
-            super().__init__(tokens[0].sentence)
-            self.tokens = tokens
-            self.initialized: bool = True
+        super().__init__(tokens[0].sentence)
+        self.tokens = tokens
 
     @property
     def start_position(self) -> int:
@@ -696,26 +680,10 @@ def to_dict(self, tag_type: Optional[str] = None):
 
 
 class Relation(_PartOfSentence):
-    def __new__(self, first: Span, second: Span):
-        # check if the relation already exists. If so, return it
-        unlabeled_identifier = self._make_unlabeled_identifier(first, second)
-        if unlabeled_identifier in first.sentence._known_spans:
-            span = first.sentence._known_spans[unlabeled_identifier]
-            return span
-
-        # else make a new relation
-        else:
-            span = super().__new__(self)
-            span.initialized = False
-            first.sentence._known_spans[unlabeled_identifier] = span
-            return span
-
     def __init__(self, first: Span, second: Span) -> None:
-        if not self.initialized:
-            super().__init__(sentence=first.sentence)
-            self.first: Span = first
-            self.second: Span = second
-            self.initialized: bool = True
+        super().__init__(sentence=first.sentence)
+        self.first: Span = first
+        self.second: Span = second
 
     def __repr__(self) -> str:
         return str(self)
@@ -793,7 +761,7 @@ def __init__(
         self.tokens: List[Token] = []
 
         # private field for all known spans
-        self._known_spans: Dict[str, _PartOfSentence] = {}
+        self._known_parts: Dict[str, _PartOfSentence] = {}
 
         self.language_code: Optional[str] = language_code
 
@@ -870,7 +838,7 @@ def get_relations(self, label_type: Optional[str] = None) -> List[Relation]:
 
     def get_spans(self, label_type: Optional[str] = None) -> List[Span]:
         spans: List[Span] = []
-        for potential_span in self._known_spans.values():
+        for potential_span in self._known_parts.values():
             if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)):
                 spans.append(potential_span)
         return sorted(spans)
@@ -1047,18 +1015,34 @@ def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
         }
 
     def get_span(self, start: int, stop: int) -> Span:
-        span_slice = slice(start, stop)
-        return self[span_slice]
+        return self[start:stop]
 
     @typing.overload
     def __getitem__(self, idx: int) -> Token: ...
 
     @typing.overload
     def __getitem__(self, s: slice) -> Span: ...
 
+    @typing.overload
+    def __getitem__(self, s: typing.Tuple[Span, Span]) -> Relation: ...
+
     def __getitem__(self, subscript):
-        if isinstance(subscript, slice):
-            return Span(self.tokens[subscript])
+        if isinstance(subscript, tuple):
+            first, second = subscript
+            identifier = ""
+            if isinstance(first, Span) and isinstance(second, Span):
+                identifier = Relation._make_unlabeled_identifier(first, second)
+            if identifier not in self._known_parts:
+                self._known_parts[identifier] = Relation(first, second)
+
+            return self._known_parts[identifier]
+        elif isinstance(subscript, slice):
+            identifier = Span._make_unlabeled_identifier(self.tokens[subscript])
+
+            if identifier not in self._known_parts:
+                self._known_parts[identifier] = Span(self.tokens[subscript])
+
+            return self._known_parts[identifier]
         else:
            return self.tokens[subscript]
 
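With the added tuple overload, indexing becomes the single entry point for tokens, spans, and relations, and repeated lookups return the cached object from _known_parts. A short usage sketch, with an illustrative sentence:

    from flair.data import Sentence

    sentence = Sentence("Berlin is the capital of Germany")

    token = sentence[0]     # Token
    span = sentence[3:4]    # Span, created once and cached in _known_parts
    relation = sentence[sentence[0:1], sentence[3:4]]  # Relation between two Spans

    assert sentence[3:4] is span  # subsequent lookups return the cached object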
@@ -1210,11 +1194,11 @@ def remove_labels(self, typename: str):
             token.remove_labels(typename)
 
         # labels also need to be deleted at all known spans
-        for span in self._known_spans.values():
+        for span in self._known_parts.values():
             span.remove_labels(typename)
 
         # remove spans without labels
-        self._known_spans = {k: v for k, v in self._known_spans.items() if len(v.labels) > 0}
+        self._known_parts = {k: v for k, v in self._known_parts.items() if len(v.labels) > 0}
 
         # delete labels at object itself
         super().remove_labels(typename)
5 changes: 1 addition & 4 deletions flair/datasets/sequence_labeling.py
@@ -26,7 +26,6 @@
     Corpus,
     FlairDataset,
     MultiCorpus,
-    Relation,
     Sentence,
     Token,
     get_spans_from_bio,
@@ -731,9 +730,7 @@ def _convert_lines_to_sentence(
                     tail_end = int(indices[3])
                     label = indices[4]
                     # head and tail span indices are 1-indexed and end index is inclusive
-                    relation = Relation(
-                        first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end]
-                    )
+                    relation = sentence[sentence[head_start - 1 : head_end], sentence[tail_start - 1 : tail_end]]
                    remapped = self._remap_label(label)
                    if remapped != "O":
                        relation.add_label(typename="relation", value=remapped)
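The corpus indices are 1-based with inclusive ends, while Python slices are 0-based with exclusive ends, so only the start needs the -1 shift. A quick check with illustrative numbers:

    head_start, head_end = 2, 3  # 1-indexed, end inclusive: tokens 2 and 3
    assert list(range(head_start - 1, head_end)) == [1, 2]  # 0-indexed slice [1:3]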
2 changes: 1 addition & 1 deletion flair/models/regexp_tagger.py
@@ -41,7 +41,7 @@ def get_token_span(self, span: Tuple[int, int]) -> Span:
         """
         span_start: int = self.__tokens_start_pos.index(span[0])
         span_end: int = self.__tokens_end_pos.index(span[1])
-        return Span(self.tokens[span_start : span_end + 1])
+        return self.sentence[span_start : span_end + 1]
 
 
 class RegexpTagger:
17 changes: 8 additions & 9 deletions flair/models/relation_classifier_model.py
Expand Up @@ -372,11 +372,9 @@ def _entity_pair_permutations(
"""
valid_entities: List[_Entity] = list(self._valid_entities(sentence))

# Use a dictionary to find gold relation annotations for a given entity pair
relation_to_gold_label: Dict[str, str] = {
relation.unlabeled_identifier: relation.get_label(self.label_type, zero_tag_value=self.zero_tag_value).value
for relation in sentence.get_relations(self.label_type)
}
# ensure that all existing relations without label have the label set to zero_tag_value.
for relation in sentence.get_relations(self.label_type):
relation.set_label(self.label_type, relation.get_label(self.label_type, self.zero_tag_value).value)

# Yield head and tail entity pairs from the cross product of all entities
for head, tail in itertools.product(valid_entities, repeat=2):
@@ -393,9 +391,10 @@
                 continue
 
             # Obtain gold label, if existing
-            original_relation: Relation = Relation(first=head.span, second=tail.span)
-            gold_label: Optional[str] = relation_to_gold_label.get(original_relation.unlabeled_identifier)
-
+            gold_relation = sentence[head.span, tail.span]
+            gold_label: Optional[str] = gold_relation.get_label(self.label_type, zero_tag_value="O").value
+            if gold_label == "O":
+                gold_label = None
             yield head, tail, gold_label
 
     def _encode_sentence(
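Note the behavioral nuance of the tuple index compared to the old dictionary lookup: sentence[head.span, tail.span] returns the annotated Relation when one exists, but otherwise creates an unlabeled one and caches it in _known_parts. The fallback to None preserves the previous semantics, roughly:

    gold_relation = sentence[head.span, tail.span]  # cached gold relation, or a fresh unlabeled one
    value = gold_relation.get_label(self.label_type, zero_tag_value="O").value
    gold_label = None if value == "O" else value    # None still means "no gold annotation"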
@@ -479,7 +478,7 @@ def _encode_sentence_for_inference(
                 tail=tail,
                 gold_label=gold_label if gold_label is not None else self.zero_tag_value,
             )
-            original_relation: Relation = Relation(first=head.span, second=tail.span)
+            original_relation: Relation = sentence[head.span, tail.span]
             yield masked_sentence, original_relation
 
     def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
2 changes: 1 addition & 1 deletion flair/models/relation_extractor_model.py
@@ -82,7 +82,7 @@ def _get_data_points_from_sentence(self, sentence: Sentence) -> List[Relation]:
             ):
                 continue
 
-            relation = Relation(span_1, span_2)
+            relation = sentence[span_1, span_2]
             if self.training and self.train_on_gold_pairs_only and relation.get_label(self.label_type).value == "O":
                 continue
             entity_pairs.append(relation)
25 changes: 11 additions & 14 deletions flair/models/tars_model.py
Expand Up @@ -412,10 +412,9 @@ def _get_tars_formatted_sentence(self, label, sentence):

for entity_label in sentence.get_labels(self.label_type):
if entity_label.value == label:
new_span = Span(
[tars_sentence.get_token(token.idx + label_length) for token in entity_label.data_point]
)
new_span.add_label(self.static_label_type, value="entity")
start_pos = entity_label.data_point[0].idx + label_length - 1
end_pos = entity_label.data_point[-1].idx + label_length
tars_sentence[start_pos:end_pos].add_label(self.static_label_type, value="entity")
tars_sentence.copy_context_from_sentence(sentence)
return tars_sentence

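The slice arithmetic above relies on flair's Token.idx being 1-based while sentence slicing is 0-based and end-exclusive, so shifting a span right by label_length tokens subtracts 1 at the start only. A minimal check with illustrative numbers:

    label_length = 2
    first_idx, last_idx = 3, 4                 # 1-based Token.idx of the original span
    start_pos = first_idx + label_length - 1   # 4: 0-based slice start in tars_sentence
    end_pos = last_idx + label_length          # 6: exclusive slice end
    assert list(range(start_pos, end_pos)) == [4, 5]  # 0-based positions of the shifted tokens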
@@ -572,19 +571,16 @@ def predict(
 
                     already_set_indices: List[int] = []
 
-                    sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1))
-                    sorted_x.reverse()
-                    for tuple in sorted_x:
-                        # get the span and its label
-                        label = tuple[0]
-
+                    sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1), reverse=True)
+                    for label, _ in sorted_x:
+                        span = typing.cast(Span, label.data_point)
                         label_length = (
                             0 if not self.prefix else len(label.value.split(" ")) + len(self.separator.split(" "))
                         )
 
                         # determine whether tokens in this span already have a label
                         tag_this = True
-                        for token in label.data_point:
+                        for token in span:
                             corresponding_token = sentence.get_token(token.idx - label_length)
                             if corresponding_token is None:
                                 tag_this = False
@@ -596,9 +592,10 @@
                         # only add if all tokens have no label
                         if tag_this:
                             # make and add a corresponding predicted span
-                            predicted_span = Span(
-                                [sentence.get_token(token.idx - label_length) for token in label.data_point]
-                            )
+                            start_pos = span.tokens[0].idx - label_length - 1
+                            end_pos = span.tokens[-1].idx - label_length
+
+                            predicted_span = sentence[start_pos:end_pos]
                             predicted_span.add_label(label_name, value=label.value, score=label.score)
 
                             # set indices so that no token can be tagged twice
6 changes: 3 additions & 3 deletions tests/test_labels.py
@@ -189,9 +189,9 @@ def test_relation_tags():
     sentence = Sentence("Humboldt Universität zu Berlin is located in Berlin .")
 
     # create two relation label
-    Relation(sentence[0:4], sentence[7:8]).add_label("rel", "located in")
-    Relation(sentence[0:2], sentence[3:4]).add_label("rel", "university of")
-    Relation(sentence[0:2], sentence[3:4]).add_label("syntactic", "apposition")
+    sentence[sentence[0:4], sentence[7:8]].add_label("rel", "located in")
+    sentence[sentence[0:2], sentence[3:4]].add_label("rel", "university of")
+    sentence[sentence[0:2], sentence[3:4]].add_label("syntactic", "apposition")
 
     # there should be two relation labels
     labels: List[Label] = sentence.get_labels("rel")
37 changes: 37 additions & 0 deletions tests/test_sentence.py
@@ -1,3 +1,6 @@
+import copy
+import pickle
+
 from flair.data import Sentence
 
 
@@ -73,3 +76,37 @@ def test_start_end_position_pretokenized() -> None:
         (10, 18),
         (19, 20),
     ]
+
+
+def test_spans_support_deepcopy() -> None:
+    sentence = Sentence(["I", "live", "in", "Vienna", "."])
+    sentence[3:4].add_label("ner", "LOC")
+
+    _ = copy.deepcopy(sentence)
+
+
+def test_spans_support_pickle() -> None:
+    sentence = Sentence(["I", "live", "in", "Vienna", "."])
+    sentence[3:4].add_label("ner", "LOC")
+
+    pickle_data = pickle.dumps(sentence)
+    _ = pickle.loads(pickle_data)
+
+
+def test_relations_support_deepcopy() -> None:
+    sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"])
+    sentence[0:1].add_label("ner", "LOC")
+    sentence[5:6].add_label("ner", "LOC")
+    sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital")
+
+    _ = copy.deepcopy(sentence)
+
+
+def test_relations_support_pickle() -> None:
+    sentence = Sentence(["Vienna", "is", "the", "capital", "of", "Austria"])
+    sentence[0:1].add_label("ner", "LOC")
+    sentence[5:6].add_label("ner", "LOC")
+    sentence[sentence[0:1], sentence[5:6]].add_label("rel", "capital")
+
+    pickle_data = pickle.dumps(sentence)
+    _ = pickle.loads(pickle_data)
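A possible extension of these tests, not part of this PR, would assert that labels survive the round trip rather than only checking that it does not raise:

    restored = pickle.loads(pickle.dumps(sentence))
    assert [label.value for label in restored.get_labels("rel")] == ["capital"]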