make muse warn immediately if gensim is not installed
helpmefindaname committed Jul 12, 2024
1 parent c17dd2f commit fb60047
Showing 1 changed file with 14 additions and 5 deletions.
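In short: the MUSE embeddings class previously touched the gensim dependency only inside _add_embeddings_internal, so a missing gensim installation went unnoticed until the first sentence was embedded. This commit resolves the import once in __init__ via flair's lazy_import helper, so the failure is reported at construction time instead. A minimal sketch of the shifted failure point (class and module names as in flair; this usage example is not part of the commit):

    from flair.data import Sentence
    from flair.embeddings import MuseCrosslingualEmbeddings

    # With gensim missing, the ImportError now fires here, at construction ...
    embeddings = MuseCrosslingualEmbeddings()

    # ... rather than only here, on first use, as before this commit.
    embeddings.embed(Sentence("Hello world"))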
flair/embeddings/token.py: 14 additions & 5 deletions
@@ -2,7 +2,7 @@
 import logging
 import re
 import tempfile
-from collections import Counter, Mapping
+from collections import Counter
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

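The import cleanup in this first hunk also future-proofs the module: collections.Mapping was removed from the top-level collections module in Python 3.10 (its home has been collections.abc since Python 3.3), so the old line fails on current interpreters. The diff drops Mapping entirely, presumably because it was unused. For reference (standard-library fact, not part of the commit):

    from collections import Counter        # still valid on all supported versions
    from collections.abc import Mapping    # required spelling on Python 3.10+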
@@ -1283,6 +1283,8 @@ def __init__(
         self.static_embeddings = True
         self.__embedding_length: int = 300
         self.language_embeddings: Dict[str, Any] = {}
+        (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors")
+        self.kv = KeyedVectors
         super().__init__()
         self.eval()

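lazy_import here comes from flair's utilities; conceptually it imports the module right away and, if the package is missing, raises an error naming the optional dependency group to install. Because the call now happens in __init__, a missing gensim surfaces immediately. A minimal sketch of such a helper, under the assumption that flair's real implementation differs in its details and wording:

    import importlib
    from typing import Any, List

    def lazy_import(group: str, module: str, *symbols: str) -> List[Any]:
        # Resolve the module eagerly so a missing optional dependency
        # fails right here, with a hint about which extra to install.
        try:
            imported = importlib.import_module(module)
        except ImportError as err:
            raise ImportError(
                f"Optional dependency '{module}' is not installed. "
                f"Try: pip install flair[{group}]"  # install hint is an assumption
            ) from err
        # Return the requested attributes so callers can tuple-unpack them,
        # e.g. (KeyedVectors,) = lazy_import(...).
        return [getattr(imported, symbol) for symbol in symbols]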
@@ -1345,7 +1347,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                 embeddings_file = cached_path(f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir)
 
                 # load the model
-                self.language_embeddings[language_code] = gensim.models.KeyedVectors.load(str(embeddings_file))
+                self.language_embeddings[language_code] = self.kv.load(str(embeddings_file))
 
             for token, _token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
                 word_embedding = self.get_cached_vec(language_code=language_code, word=token.text)
@@ -1401,7 +1403,7 @@ def __init__(
         else:
             if not language and model_file_path is None:
                 raise ValueError("Need to specify model_file_path if no language is give in BytePairEmbeddings")
-            BPEmb, = lazy_import("word-embeddings", "bpemb", "BPEmb")
+            (BPEmb,) = lazy_import("word-embeddings", "bpemb", "BPEmb")
 
             if language:
                 self.name: str = f"bpe-{language}-{syllables}-{dim}"
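The BytePairEmbeddings hunk is purely cosmetic: BPEmb, = ... and (BPEmb,) = ... are the same single-element unpacking of the list returned by lazy_import; the parenthesized form is simply easier to read and is what formatters such as black emit. A standalone illustration (not from the commit):

    values = ["only item"]
    a, = values      # legal, but easy to misread as a stray comma
    (b,) = values    # identical unpacking, clearer intent
    assert a == b == "only item"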
@@ -1504,7 +1506,14 @@ def from_params(cls, params):
         else:
             embedding_file_path = None
         dim = params["dim"]
-        return cls(name=params["name"], dim=dim, model_file_path=model_file_path, embedding_file_path=embedding_file_path, field=params.get("field"), preprocess=params.get("preprocess", True))
+        return cls(
+            name=params["name"],
+            dim=dim,
+            model_file_path=model_file_path,
+            embedding_file_path=embedding_file_path,
+            field=params.get("field"),
+            preprocess=params.get("preprocess", True),
+        )
 
     def to_params(self):
         return {
@@ -1541,7 +1550,7 @@ def state_dict(self, *args, **kwargs):
         return super().state_dict(*args, **kwargs)
 
     def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
         if not state_dict:
             # old embeddings do not have a torch-embedding and therefore do not store the weights in the saved torch state_dict