Skip to content

Commit

Permalink
Merge pull request #3490 from MattGPT-ai/gh-3487/fix-gpu-memory-leak-…
Browse files Browse the repository at this point in the history
…text-pair-regressor

Fix GPU memory leak in TextPairRegressor
  • Loading branch information
alanakbik committed Jul 13, 2024
2 parents 472b2c7 + 05622e6 commit 7887678
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 61 deletions.
95 changes: 52 additions & 43 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from operator import itemgetter
from os import PathLike
from pathlib import Path
from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast
from typing import Any, DefaultDict, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast

import torch
from deprecated.sphinx import deprecated
Expand Down Expand Up @@ -49,7 +50,7 @@ class BoundingBox(NamedTuple):
class Dictionary:
"""This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings."""

def __init__(self, add_unk=True) -> None:
def __init__(self, add_unk: bool = True) -> None:
# init dictionaries
self.item2idx: Dict[bytes, int] = {}
self.idx2item: List[bytes] = []
Expand Down Expand Up @@ -143,21 +144,21 @@ def is_span_prediction_problem(self) -> bool:
def start_stop_tags_are_set(self) -> bool:
return {b"<START>", b"<STOP>"}.issubset(self.item2idx.keys())

def save(self, savefile):
def save(self, savefile: PathLike):
import pickle

with open(savefile, "wb") as f:
mappings = {"idx2item": self.idx2item, "item2idx": self.item2idx}
pickle.dump(mappings, f)

def __setstate__(self, d):
def __setstate__(self, d: Dict) -> None:
self.__dict__ = d
# set 'add_unk' if the dictionary was created with a version of Flair older than 0.9
if "add_unk" not in self.__dict__:
self.__dict__["add_unk"] = b"<unk>" in self.__dict__["idx2item"]

@classmethod
def load_from_file(cls, filename: Union[str, Path]):
def load_from_file(cls, filename: Union[str, Path]) -> "Dictionary":
import pickle

with Path(filename).open("rb") as f:
Expand All @@ -174,7 +175,7 @@ def load_from_file(cls, filename: Union[str, Path]):
return dictionary

@classmethod
def load(cls, name: str):
def load(cls, name: str) -> "Dictionary":
from flair.file_utils import cached_path

hu_path: str = "https://flair.informatik.hu-berlin.de/resources/characters"
Expand Down Expand Up @@ -282,11 +283,11 @@ class DataPoint:
def __init__(self) -> None:
self.annotation_layers: Dict[str, List[Label]] = {}
self._embeddings: Dict[str, torch.Tensor] = {}
self._metadata: Dict[str, typing.Any] = {}
self._metadata: Dict[str, Any] = {}

@property
@abstractmethod
def embedding(self):
def embedding(self) -> torch.Tensor:
pass

def set_embedding(self, name: str, vector: torch.Tensor):
Expand Down Expand Up @@ -316,35 +317,35 @@ def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> Lis
embeddings.append(embed)
return embeddings

def to(self, device: str, pin_memory: bool = False):
def to(self, device: str, pin_memory: bool = False) -> None:
for name, vector in self._embeddings.items():
if str(vector.device) != str(device):
if pin_memory:
self._embeddings[name] = vector.to(device, non_blocking=True).pin_memory()
else:
self._embeddings[name] = vector.to(device, non_blocking=True)

def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
def clear_embeddings(self, embedding_names: Optional[List[str]] = None) -> None:
if embedding_names is None:
self._embeddings = {}
else:
for name in embedding_names:
if name in self._embeddings:
del self._embeddings[name]

def has_label(self, type) -> bool:
def has_label(self, type: str) -> bool:
return type in self.annotation_layers

def add_metadata(self, key: str, value: typing.Any) -> None:
def add_metadata(self, key: str, value: Any) -> None:
self._metadata[key] = value

def get_metadata(self, key: str) -> typing.Any:
def get_metadata(self, key: str) -> Any:
return self._metadata[key]

def has_metadata(self, key: str) -> bool:
return key in self._metadata

def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata) -> "DataPoint":
label = Label(self, value, score, **metadata)

if typename not in self.annotation_layers:
Expand All @@ -358,16 +359,16 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
self.annotation_layers[typename] = [Label(self, value, score, **metadata)]
return self

def remove_labels(self, typename: str):
def remove_labels(self, typename: str) -> None:
if typename in self.annotation_layers:
del self.annotation_layers[typename]

def get_label(self, label_type: Optional[str] = None, zero_tag_value="O"):
def get_label(self, label_type: Optional[str] = None, zero_tag_value: str = "O") -> Label:
if len(self.get_labels(label_type)) == 0:
return Label(self, zero_tag_value)
return self.get_labels(label_type)[0]

def get_labels(self, typename: Optional[str] = None):
def get_labels(self, typename: Optional[str] = None) -> List[Label]:
if typename is None:
return self.labels

Expand All @@ -385,7 +386,7 @@ def labels(self) -> List[Label]:
def unlabeled_identifier(self):
raise NotImplementedError

def _printout_labels(self, main_label=None, add_score: bool = True, add_metadata: bool = True):
def _printout_labels(self, main_label=None, add_score: bool = True, add_metadata: bool = True) -> str:
all_labels = []
keys = [main_label] if main_label is not None else self.annotation_layers.keys()

Expand Down Expand Up @@ -431,7 +432,7 @@ def tag(self):
def score(self):
return self.labels[0].score

def __lt__(self, other):
def __lt__(self, other: "DataPoint"):
return self.start_position < other.start_position

def __len__(self) -> int:
Expand Down Expand Up @@ -482,7 +483,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return str(self)

def to_dict(self) -> Dict[str, typing.Any]:
def to_dict(self) -> Dict[str, Any]:
return {
"concept_id": self.concept_id,
"concept_name": self.concept_name,
Expand Down Expand Up @@ -517,7 +518,7 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
super().set_label(typename, value, score, **metadata)
return self

def remove_labels(self, typename: str):
def remove_labels(self, typename: str) -> None:
# labels also need to be deleted at Sentence object
for label in self.get_labels(typename):
self.sentence.annotation_layers[typename].remove(label)
Expand Down Expand Up @@ -567,7 +568,7 @@ def text(self) -> str:
def unlabeled_identifier(self) -> str:
return f'Token[{self.idx - 1}]: "{self.text}"'

def add_tags_proba_dist(self, tag_type: str, tags: List[Label]):
def add_tags_proba_dist(self, tag_type: str, tags: List[Label]) -> None:
self.tags_proba_dist[tag_type] = tags

def get_tags_proba_dist(self, tag_type: str) -> List[Label]:
Expand Down Expand Up @@ -616,7 +617,7 @@ def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
else:
DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata)

def to_dict(self, tag_type: Optional[str] = None):
def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
return {
"text": self.text,
"start_pos": self.start_position,
Expand Down Expand Up @@ -958,7 +959,7 @@ def right_context(self, context_length: int, respect_document_boundaries: bool =
def __str__(self) -> str:
return self.to_tagged_string()

def to_tagged_string(self, main_label=None) -> str:
def to_tagged_string(self, main_label: Optional[str] = None) -> str:
already_printed = [self]

output = super().__str__()
Expand All @@ -978,7 +979,7 @@ def to_tagged_string(self, main_label=None) -> str:
return output

@property
def text(self):
def text(self) -> str:
return self.to_original_text()

def to_tokenized_string(self) -> str:
Expand All @@ -987,7 +988,7 @@ def to_tokenized_string(self) -> str:

return self.tokenized

def to_plain_string(self):
def to_plain_string(self) -> str:
plain = ""
for token in self.tokens:
plain += token.text
Expand Down Expand Up @@ -1036,7 +1037,7 @@ def to_original_text(self) -> str:
[t.text + t.whitespace_after * " " for t in self.tokens]
).strip()

def to_dict(self, tag_type: Optional[str] = None):
def to_dict(self, tag_type: Optional[str] = None) -> Dict[str, Any]:
return {
"text": self.to_original_text(),
"labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self],
Expand All @@ -1045,7 +1046,7 @@ def to_dict(self, tag_type: Optional[str] = None):
"tokens": [token.to_dict(tag_type) for token in self.tokens],
}

def get_span(self, start: int, stop: int):
def get_span(self, start: int, stop: int) -> Span:
span_slice = slice(start, stop)
return self[span_slice]

Expand Down Expand Up @@ -1090,7 +1091,8 @@ def get_language_code(self) -> str:

try:
self.language_code = langdetect.detect(self.to_plain_string())
except Exception:
except Exception as e:
log.debug(e)
self.language_code = "en"

return self.language_code
Expand Down Expand Up @@ -1223,6 +1225,7 @@ def __init__(self, first: DT, second: DT2) -> None:
super().__init__()
self.first = first
self.second = second
self.concatenated_data: Optional[Union[DT, DT2]] = None

def to(self, device: str, pin_memory: bool = False):
self.first.to(device, pin_memory)
Expand All @@ -1231,6 +1234,8 @@ def to(self, device: str, pin_memory: bool = False):
def clear_embeddings(self, embedding_names: Optional[List[str]] = None):
self.first.clear_embeddings(embedding_names)
self.second.clear_embeddings(embedding_names)
if self.concatenated_data is not None:
self.concatenated_data.clear_embeddings(embedding_names)

@property
def embedding(self):
Expand Down Expand Up @@ -1304,7 +1309,7 @@ def text(self):


class Image(DataPoint):
def __init__(self, data=None, imageURL=None) -> None:
def __init__(self, data=None, imageURL=None):
super().__init__()

self.data = data
Expand Down Expand Up @@ -1403,7 +1408,7 @@ def downsample(
downsample_dev: bool = True,
downsample_test: bool = True,
random_seed: Optional[int] = None,
):
) -> "Corpus":
"""Reduce all datasets in corpus proportionally to the given percentage."""
if downsample_train and self._train is not None:
self._train = self._downsample_to_proportion(self._train, percentage, random_seed)
Expand Down Expand Up @@ -1470,7 +1475,7 @@ def _filter_empty_sentences(dataset) -> Dataset:

return subset

def make_vocab_dictionary(self, max_tokens=-1, min_freq=1) -> Dictionary:
def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dictionary:
"""Creates a dictionary of all tokens contained in the corpus.
By defining `max_tokens` you can set the maximum number of tokens that should be contained in the dictionary.
Expand All @@ -1492,7 +1497,7 @@ def make_vocab_dictionary(self, max_tokens=-1, min_freq=1) -> Dictionary:

return vocab_dictionary

def _get_most_common_tokens(self, max_tokens, min_freq) -> List[str]:
def _get_most_common_tokens(self, max_tokens: int, min_freq: int) -> List[str]:
tokens_and_frequencies = Counter(self._get_all_tokens())

tokens: List[str] = []
Expand Down Expand Up @@ -1561,20 +1566,20 @@ def _obtain_statistics_for(sentences, name, tag_type) -> dict:
}

@staticmethod
def _get_tokens_per_sentence(sentences):
def _get_tokens_per_sentence(sentences: Iterable[Sentence]) -> List[int]:
return [len(x.tokens) for x in sentences]

@staticmethod
def _count_sentence_labels(sentences):
label_count = defaultdict(lambda: 0)
def _count_sentence_labels(sentences: Iterable[Sentence]) -> DefaultDict[str, int]:
label_count: DefaultDict[str, int] = defaultdict(lambda: 0)
for sent in sentences:
for label in sent.labels:
label_count[label.value] += 1
return label_count

@staticmethod
def _count_token_labels(sentences, label_type):
label_count = defaultdict(lambda: 0)
def _count_token_labels(sentences: Iterable[Sentence], label_type: str) -> DefaultDict[str, int]:
label_count: DefaultDict[str, int] = defaultdict(lambda: 0)
for sent in sentences:
for token in sent.tokens:
if label_type in token.annotation_layers:
Expand Down Expand Up @@ -1665,7 +1670,9 @@ def make_label_dictionary(
[f"'{label[0]}' (in {label[1]} sentences)" for label in sentence_label_type_counter.most_common()]
)
log.error(f"ERROR: The corpus contains the following label types: {contained_labels}")
raise Exception
raise ValueError(
f"You specified a label type ({label_type}) that is not contained in the corpus:\n{contained_labels}"
)

log.info(
f"Dictionary created for label '{label_type}' with {len(label_dictionary)} "
Expand Down Expand Up @@ -1888,7 +1895,7 @@ def __init__(self, datasets: Iterable[Dataset], ids: Iterable[str]) -> None:
def __len__(self) -> int:
return self.cumulative_sizes[-1]

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Sentence:
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
Expand All @@ -1900,11 +1907,11 @@ def __getitem__(self, idx):
return sentence

@property
def cummulative_sizes(self):
def cummulative_sizes(self) -> List[int]:
return self.cumulative_sizes


def iob2(tags):
def iob2(tags: List) -> bool:
"""Converts the tags to the IOB2 format.
Check that tags have a valid IOB format.
Expand Down Expand Up @@ -1951,7 +1958,9 @@ def randomly_split_into_two_datasets(
return Subset(dataset, first_dataset), Subset(dataset, second_dataset)


def get_spans_from_bio(bioes_tags: List[str], bioes_scores=None) -> List[typing.Tuple[List[int], float, str]]:
def get_spans_from_bio(
bioes_tags: List[str], bioes_scores: Optional[List[float]] = None
) -> List[typing.Tuple[List[int], float, str]]:
# add a dummy "O" to close final prediction
bioes_tags.append("O")
# return complex list
Expand Down
2 changes: 1 addition & 1 deletion flair/models/lemmatizer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def predict(
# option 1: greedy decoding
if self.beam_size == 1:
# predictions
predicted: List[List[int]] = [[] for _ in range(number_tokens)]
predicted: List[List[Union[int, float]]] = [[] for _ in range(number_tokens)]

for _decode_step in range(max_length):
# decode next character
Expand Down
19 changes: 11 additions & 8 deletions flair/models/pairwise_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,17 @@ def _get_embedding_for_data_point(self, prediction_data_point: TextPair) -> torc
0,
)
else:
concatenated_sentence = Sentence(
prediction_data_point.first.to_tokenized_string()
+ self.sep
+ prediction_data_point.second.to_tokenized_string(),
use_tokenizer=False,
)
self.embeddings.embed(concatenated_sentence)
return concatenated_sentence.get_embedding(embedding_names)
# If the concatenated version of the text pair does not exist yet, create it
if prediction_data_point.concatenated_data is None:
concatenated_sentence = Sentence(
prediction_data_point.first.to_tokenized_string()
+ self.sep
+ prediction_data_point.second.to_tokenized_string(),
use_tokenizer=False,
)
prediction_data_point.concatenated_data = concatenated_sentence
self.embeddings.embed(prediction_data_point.concatenated_data)
return prediction_data_point.concatenated_data.get_embedding(embedding_names)

def _get_state_dict(self):
model_state = {
Expand Down
Loading

0 comments on commit 7887678

Please sign in to comment.