Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make gensim optional #3493

Merged
merged 14 commits into from
Jul 19, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Install Torch cpu
run: pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Install Flair dependencies
run: pip install -e .
run: pip install -e .[word-embeddings]
- name: Install unittest dependencies
run: pip install -r requirements-dev.txt
- name: Show installed dependencies
Expand Down
28 changes: 27 additions & 1 deletion flair/class_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib
import inspect
from typing import Iterable, Optional, Type, TypeVar
from types import ModuleType
from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload

T = TypeVar("T")

Expand All @@ -17,3 +19,27 @@ def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]
if sub_cls.__name__ == cls_name:
return sub_cls
raise ValueError(f"Could not find any class with name '{cls_name}'")


@overload
def lazy_import(group: str, module: str, first_symbol: None) -> ModuleType: ...


@overload
def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> List[Any]: ...


def lazy_import(
group: str, module: str, first_symbol: Optional[str] = None, *symbols: str
) -> Union[List[Any], ModuleType]:
try:
imported_module = importlib.import_module(module)
except ImportError:
raise ImportError(
f"Could not import {module}. Please install the optional '{group}' dependency. Via 'pip install flair[{group}]'"
)
if first_symbol is None:
return imported_module
symbols = (first_symbol, *symbols)

return [getattr(imported_module, symbol) for symbol in symbols]
1 change: 0 additions & 1 deletion flair/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@

# Expose token embedding classes
from .token import (
BPEmbSerializable,
BytePairEmbeddings,
CharacterEmbeddings,
FastTextEmbeddings,
Expand Down
Loading
Loading