diff --git a/docs/config.md b/docs/config.md new file mode 100644 index 0000000..5480a58 --- /dev/null +++ b/docs/config.md @@ -0,0 +1,13 @@ +If you're using the ml dependencies, the config class manages the computational device settings for your project, + +### Config + +This is a simple wrapper around pytorch. Setting the device will fail if pytorch is not installed. + +```python +import openparse + +openparse.config.set_device("cpu") +``` + +Note if you're on apple silicon, setting this to `mps` runs significantly slower than on `cpu`. diff --git a/mkdocs.yml b/mkdocs.yml index ebcfc16..eeb71c9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -79,6 +79,7 @@ nav: - Customization: processing/customization.md - Serializing Results: serialization.md - Visualization: visualization.md + - Config: config.md plugins: - search: diff --git a/src/openparse/__init__.py b/src/openparse/__init__.py index ed8f7eb..fa61a23 100644 --- a/src/openparse/__init__.py +++ b/src/openparse/__init__.py @@ -2,7 +2,8 @@ from openparse.doc_parser import ( DocumentParser, ) -from openparse import processing +from openparse import processing, version +from openparse.config import config from openparse.schemas import ( Bbox, LineElement, @@ -13,13 +14,18 @@ ) __all__ = [ + # core "DocumentParser", "Pdf", + # Schemas "Bbox", "LineElement", "Node", "TableElement", "TextElement", "TextSpan", + # Modules "processing", + "version", + "config", ] diff --git a/src/openparse/config.py b/src/openparse/config.py new file mode 100644 index 0000000..3c3b25a --- /dev/null +++ b/src/openparse/config.py @@ -0,0 +1,42 @@ +from typing import Literal + + +TorchDevice = Literal["cuda", "cpu", "mps"] + + +class Config: + def __init__(self): + self._device = "cpu" # Default to CPU + self._torch_available = False + self._cuda_available = False + try: + import torch + + self._torch_available = True + if torch.cuda.is_available(): + self._device = "cuda" + self._cuda_available = True + except ImportError: + pass + + def set_device(self, device: TorchDevice) -> None: + if not self._torch_available and device == "cuda": + raise RuntimeError( + "CUDA device requested but torch is not available. Have you installed ml dependencies?" + ) + if not self._cuda_available and device == "cuda": + raise RuntimeError("CUDA device requested but CUDA is not available") + if device not in ["cuda", "cpu", "mps"]: + raise ValueError("Device must be 'cuda', 'cpu' or 'mps'") + self._device = device + + def get_device(self): + if self._torch_available: + import torch + + return torch.device(self._device) + else: + return self._device + + +config = Config() diff --git a/src/openparse/tables/table_transformers/ml.py b/src/openparse/tables/table_transformers/ml.py index ad958af..ad4ec48 100644 --- a/src/openparse/tables/table_transformers/ml.py +++ b/src/openparse/tables/table_transformers/ml.py @@ -10,6 +10,7 @@ TableTransformerForObjectDetection, # type: ignore ) # type: ignore +from openparse.config import config from ..schemas import ( BBox, Size, @@ -34,12 +35,7 @@ ) t0 = time.time() - -cuda_available = torch.cuda.is_available() -user_preferred_device = "cuda" -device = torch.device( - "cuda" if cuda_available and user_preferred_device != "cpu" else "cpu" -) +device = config.get_device() class MaxResize: diff --git a/src/openparse/tables/unitable/config.py b/src/openparse/tables/unitable/config.py index 987b837..fabb424 100644 --- a/src/openparse/tables/unitable/config.py +++ b/src/openparse/tables/unitable/config.py @@ -1,10 +1,8 @@ -import torch from pydantic import BaseModel from pathlib import Path import sys from openparse import consts -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") root = Path(consts.__file__).parent diff --git a/src/openparse/tables/unitable/core.py b/src/openparse/tables/unitable/core.py index 456900c..4acdfa5 100644 --- a/src/openparse/tables/unitable/core.py +++ b/src/openparse/tables/unitable/core.py @@ -6,7 +6,6 @@ from torchvision import transforms # type: ignore from torch import nn, Tensor # type: ignore -from .config import device from .tokens import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN from .utils import ( subsequent_mask, @@ -27,10 +26,13 @@ cell_model, EncoderDecoder, ) +from openparse.config import config Size = Tuple[int, int] BBox = Tuple[int, int, int, int] +device = config.get_device() + def _image_to_tensor(image: Image, size: Size) -> Tensor: T = transforms.Compose( diff --git a/src/openparse/tables/unitable/unitable_model.py b/src/openparse/tables/unitable/unitable_model.py index 7c83c66..676f637 100644 --- a/src/openparse/tables/unitable/unitable_model.py +++ b/src/openparse/tables/unitable/unitable_model.py @@ -7,14 +7,16 @@ from torch import nn from functools import partial -from .config import device, config +from .config import config from .tabular_transformer import ( EncoderDecoder, ImgLinearBackbone, Encoder, Decoder, ) +from openparse.config import config as global_config +device = global_config.get_device() warnings.filterwarnings("ignore") diff --git a/src/openparse/version.py b/src/openparse/version.py new file mode 100644 index 0000000..8327712 --- /dev/null +++ b/src/openparse/version.py @@ -0,0 +1,43 @@ +OPEN_PARSE_VERSION = "0.5.1" + + +def version_info() -> str: + """Return complete version information for OpenParse and its dependencies.""" + import importlib.metadata as importlib_metadata + import platform + import sys + from pathlib import Path + + python_version = sys.version.split()[0] + operating_system = platform.system() + os_version = platform.release() + + package_names = { + "email-validator", + "torch", + "torchvision", + "transformers", + "tokenizers", + "PyMuPDF", + "pydantic", + } + related_packages = [] + + for dist in importlib_metadata.distributions(): + name = dist.metadata["Name"] + if name in package_names: + related_packages.append(f"{name}-{dist.version}") + + info = { + "python_version": python_version, + "operating_system": operating_system, + "os_version": os_version, + "open-parse version": OPEN_PARSE_VERSION, + "install path": Path(__file__).resolve().parent, + "python version": sys.version, + "platform": platform.platform(), + "related packages": " ".join(related_packages), + } + return "\n".join( + "{:>30} {}".format(k + ":", str(v).replace("\n", " ")) for k, v in info.items() + )