Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global PyTorch Config #19

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/config.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
If you're using the ML dependencies, the config class manages the computational device settings for your project.

### Config

This is a simple wrapper around PyTorch. Requesting the `cuda` device will fail if PyTorch is not installed or CUDA is not available.

```python
import openparse

openparse.config.set_device("cpu")
```

Note: if you're on Apple silicon, setting the device to `mps` runs significantly slower than `cpu`.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ nav:
- Customization: processing/customization.md
- Serializing Results: serialization.md
- Visualization: visualization.md
- Config: config.md

plugins:
- search:
Expand Down
8 changes: 7 additions & 1 deletion src/openparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from openparse.doc_parser import (
DocumentParser,
)
from openparse import processing
from openparse import processing, version
from openparse.config import config
from openparse.schemas import (
Bbox,
LineElement,
Expand All @@ -13,13 +14,18 @@
)

# Names exported by `from openparse import *`, grouped by kind.
__all__ = [
# Core
"DocumentParser",
"Pdf",
# Schemas
"Bbox",
"LineElement",
"Node",
"TableElement",
"TextElement",
"TextSpan",
# Modules
"processing",
"version",
"config",
]
42 changes: 42 additions & 0 deletions src/openparse/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Literal


TorchDevice = Literal["cuda", "cpu", "mps"]


class Config:
    """Hold the compute device used by the torch-backed parts of the package.

    On construction the device defaults to ``"cuda"`` when torch is
    importable and reports CUDA as available, and to ``"cpu"`` otherwise.
    torch is an optional dependency: when it cannot be imported only the
    ``"cpu"`` device can be selected.
    """

    _VALID_DEVICES = ("cuda", "cpu", "mps")

    def __init__(self) -> None:
        self._device = "cpu"  # Default to CPU
        self._torch_available = False
        self._cuda_available = False
        self._mps_available = False
        try:
            import torch

            self._torch_available = True
            # `torch.backends.mps` is absent on older torch releases, so
            # look it up defensively instead of assuming it exists.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                self._mps_available = True
            if torch.cuda.is_available():
                self._device = "cuda"
                self._cuda_available = True
        except ImportError:
            # torch is optional; stay on the CPU default.
            pass

    def set_device(self, device: TorchDevice) -> None:
        """Select the device returned by subsequent ``get_device`` calls.

        Args:
            device: One of ``"cuda"``, ``"cpu"`` or ``"mps"``.

        Raises:
            ValueError: If ``device`` is not one of the supported names.
            RuntimeError: If an accelerator (``"cuda"`` or ``"mps"``) is
                requested but torch is not installed, or the requested
                backend is not available.
        """
        if device not in self._VALID_DEVICES:
            raise ValueError("Device must be 'cuda', 'cpu' or 'mps'")

        if device != "cpu":
            # Every non-CPU device needs torch, not just CUDA.
            if not self._torch_available:
                raise RuntimeError(
                    f"{device.upper()} device requested but torch is not available. "
                    "Have you installed ml dependencies?"
                )
            if device == "cuda" and not self._cuda_available:
                raise RuntimeError("CUDA device requested but CUDA is not available")
            if device == "mps" and not self._mps_available:
                raise RuntimeError("MPS device requested but MPS is not available")

        self._device = device

    def get_device(self):
        """Return the currently selected device.

        Returns:
            A ``torch.device`` when torch is installed, otherwise the plain
            device name as a string.
        """
        if self._torch_available:
            import torch

            return torch.device(self._device)
        return self._device


# Module-level instance; other modules import this object to read the device.
config = Config()
8 changes: 2 additions & 6 deletions src/openparse/tables/table_transformers/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TableTransformerForObjectDetection, # type: ignore
) # type: ignore

from openparse.config import config
from ..schemas import (
BBox,
Size,
Expand All @@ -34,12 +35,7 @@
)

t0 = time.time()

cuda_available = torch.cuda.is_available()
user_preferred_device = "cuda"
device = torch.device(
"cuda" if cuda_available and user_preferred_device != "cpu" else "cpu"
)
device = config.get_device()


class MaxResize:
Expand Down
2 changes: 0 additions & 2 deletions src/openparse/tables/unitable/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import torch
from pydantic import BaseModel
from pathlib import Path
import sys
from openparse import consts

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root = Path(consts.__file__).parent


Expand Down
4 changes: 3 additions & 1 deletion src/openparse/tables/unitable/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torchvision import transforms # type: ignore
from torch import nn, Tensor # type: ignore

from .config import device
from .tokens import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN
from .utils import (
subsequent_mask,
Expand All @@ -27,10 +26,13 @@
cell_model,
EncoderDecoder,
)
from openparse.config import config

Size = Tuple[int, int]
BBox = Tuple[int, int, int, int]

device = config.get_device()


def _image_to_tensor(image: Image, size: Size) -> Tensor:
T = transforms.Compose(
Expand Down
4 changes: 3 additions & 1 deletion src/openparse/tables/unitable/unitable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from torch import nn
from functools import partial

from .config import device, config
from .config import config
from .tabular_transformer import (
EncoderDecoder,
ImgLinearBackbone,
Encoder,
Decoder,
)
from openparse.config import config as global_config

device = global_config.get_device()
warnings.filterwarnings("ignore")


Expand Down
43 changes: 43 additions & 0 deletions src/openparse/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
OPEN_PARSE_VERSION = "0.5.1"


def version_info() -> str:
    """Return complete version information for OpenParse and its dependencies.

    The result is a multi-line string with one right-aligned ``key: value``
    pair per line. It covers the Python version, operating system, the
    OpenParse version and install path, and the installed versions of a
    fixed set of related packages.
    """
    import importlib.metadata as importlib_metadata
    import platform
    import re
    import sys
    from pathlib import Path

    def _canonical(name: str) -> str:
        # PEP 503 normalization: distribution names compare equal regardless
        # of case and of which of "-", "_" or "." is used as a separator.
        return re.sub(r"[-_.]+", "-", name).lower()

    python_version = sys.version.split()[0]
    operating_system = platform.system()
    os_version = platform.release()

    package_names = {
        "email-validator",
        "torch",
        "torchvision",
        "transformers",
        "tokenizers",
        "PyMuPDF",
        "pydantic",
    }
    wanted = {_canonical(name) for name in package_names}

    # A set drops duplicates when the same distribution is discoverable from
    # more than one path entry.
    related_packages = set()
    for dist in importlib_metadata.distributions():
        name = dist.metadata.get("Name")
        # Broken installs can yield a distribution without a Name field.
        if name and _canonical(name) in wanted:
            related_packages.add(f"{name}-{dist.version}")

    info = {
        "python_version": python_version,
        "operating_system": operating_system,
        "os_version": os_version,
        "open-parse version": OPEN_PARSE_VERSION,
        "install path": Path(__file__).resolve().parent,
        "python version": sys.version,
        "platform": platform.platform(),
        # Sorted so the report is stable across runs and environments.
        "related packages": " ".join(sorted(related_packages, key=str.lower)),
    }
    return "\n".join(
        "{:>30} {}".format(k + ":", str(v).replace("\n", " ")) for k, v in info.items()
    )
Loading