Skip to content

Commit

Permalink
Merge pull request #19 from Filimoa/torch-device
Browse files Browse the repository at this point in the history
Global PyTorch Config
  • Loading branch information
Filimoa committed Apr 11, 2024
2 parents 256bd5e + 5fd1c84 commit ae91e86
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 11 deletions.
13 changes: 13 additions & 0 deletions docs/config.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
If you're using the ML dependencies, the config class manages the computational device settings for your project.

### Config

This is a simple wrapper around PyTorch. Setting the device will fail if PyTorch is not installed.

```python
import openparse

openparse.config.set_device("cpu")
```

Note: if you're on Apple silicon, setting this to `mps` runs significantly slower than setting it to `cpu`.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ nav:
- Customization: processing/customization.md
- Serializing Results: serialization.md
- Visualization: visualization.md
- Config: config.md

plugins:
- search:
Expand Down
8 changes: 7 additions & 1 deletion src/openparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from openparse.doc_parser import (
DocumentParser,
)
from openparse import processing
from openparse import processing, version
from openparse.config import config
from openparse.schemas import (
Bbox,
LineElement,
Expand All @@ -13,13 +14,18 @@
)

__all__ = [
# core
"DocumentParser",
"Pdf",
# Schemas
"Bbox",
"LineElement",
"Node",
"TableElement",
"TextElement",
"TextSpan",
# Modules
"processing",
"version",
"config",
]
42 changes: 42 additions & 0 deletions src/openparse/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Literal


TorchDevice = Literal["cuda", "cpu", "mps"]


class Config:
    """Hold the compute device that the ML-backed parts of the package use.

    On construction the class probes for an installed ``torch``. When torch
    is importable and CUDA is available the default device becomes
    ``"cuda"``; in every other case it stays ``"cpu"``.
    """

    def __init__(self):
        self._device = "cpu"  # Default to CPU
        self._torch_available = False
        self._cuda_available = False
        self._mps_available = False
        try:
            import torch

            self._torch_available = True
            # ``torch.backends.mps`` does not exist on older torch releases,
            # so look it up defensively instead of assuming the attribute.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                self._mps_available = True
            if torch.cuda.is_available():
                self._device = "cuda"
                self._cuda_available = True
        except ImportError:
            # torch is an optional dependency; without it we stay on "cpu".
            pass

    def set_device(self, device: TorchDevice) -> None:
        """Select the device returned by :meth:`get_device`.

        Args:
            device: One of ``"cuda"``, ``"cpu"`` or ``"mps"``.

        Raises:
            ValueError: If ``device`` is not one of the supported names.
            RuntimeError: If an accelerator (``"cuda"`` or ``"mps"``) is
                requested but torch is not installed, or the requested
                backend is not available on this machine.
        """
        # Reject unknown names first so the caller gets the most specific error.
        if device not in ("cuda", "cpu", "mps"):
            raise ValueError("Device must be 'cuda', 'cpu' or 'mps'")

        if device != "cpu":
            label = device.upper()
            if not self._torch_available:
                raise RuntimeError(
                    f"{label} device requested but torch is not available. "
                    "Have you installed ml dependencies?"
                )
            backend_ready = (
                self._cuda_available if device == "cuda" else self._mps_available
            )
            # Fail here rather than later, when the first tensor is moved.
            if not backend_ready:
                raise RuntimeError(
                    f"{label} device requested but {label} is not available"
                )

        self._device = device

    def get_device(self):
        """Return the currently selected device.

        Returns:
            A ``torch.device`` when torch is installed, otherwise the plain
            device name as a string.
        """
        if not self._torch_available:
            return self._device

        import torch

        return torch.device(self._device)


# Shared module-level instance; other modules import this object and call
# ``config.get_device()`` / ``config.set_device(...)`` on it.
config = Config()
8 changes: 2 additions & 6 deletions src/openparse/tables/table_transformers/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TableTransformerForObjectDetection, # type: ignore
) # type: ignore

from openparse.config import config
from ..schemas import (
BBox,
Size,
Expand All @@ -34,12 +35,7 @@
)

t0 = time.time()

cuda_available = torch.cuda.is_available()
user_preferred_device = "cuda"
device = torch.device(
"cuda" if cuda_available and user_preferred_device != "cpu" else "cpu"
)
device = config.get_device()


class MaxResize:
Expand Down
2 changes: 0 additions & 2 deletions src/openparse/tables/unitable/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import torch
from pydantic import BaseModel
from pathlib import Path
import sys
from openparse import consts

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root = Path(consts.__file__).parent


Expand Down
4 changes: 3 additions & 1 deletion src/openparse/tables/unitable/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torchvision import transforms # type: ignore
from torch import nn, Tensor # type: ignore

from .config import device
from .tokens import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN
from .utils import (
subsequent_mask,
Expand All @@ -27,10 +26,13 @@
cell_model,
EncoderDecoder,
)
from openparse.config import config

Size = Tuple[int, int]
BBox = Tuple[int, int, int, int]

device = config.get_device()


def _image_to_tensor(image: Image, size: Size) -> Tensor:
T = transforms.Compose(
Expand Down
4 changes: 3 additions & 1 deletion src/openparse/tables/unitable/unitable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from torch import nn
from functools import partial

from .config import device, config
from .config import config
from .tabular_transformer import (
EncoderDecoder,
ImgLinearBackbone,
Encoder,
Decoder,
)
from openparse.config import config as global_config

device = global_config.get_device()
warnings.filterwarnings("ignore")


Expand Down
43 changes: 43 additions & 0 deletions src/openparse/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
OPEN_PARSE_VERSION = "0.5.1"


def version_info() -> str:
    """Return complete version information for OpenParse and its dependencies.

    The result is a multi-line string with one ``key: value`` pair per line
    (keys right-aligned), covering the Python version, operating system,
    OpenParse version, install path and the installed versions of a fixed
    set of related packages.

    Returns:
        The formatted report. Related packages are listed in alphabetical
        order, each at most once.
    """
    import importlib.metadata as importlib_metadata
    import platform
    import re
    import sys
    from pathlib import Path

    def _canonical(name: str) -> str:
        # PEP 503 normalisation: distribution names are case-insensitive and
        # treat runs of "-", "_" and "." as equivalent.
        return re.sub(r"[-_.]+", "-", name).lower()

    python_version = sys.version.split()[0]
    operating_system = platform.system()
    os_version = platform.release()

    package_names = {
        _canonical(name)
        for name in (
            "email-validator",
            "torch",
            "torchvision",
            "transformers",
            "tokenizers",
            "PyMuPDF",
            "pydantic",
        )
    }

    # Keyed by canonical name so a distribution found on several sys.path
    # entries is reported once.
    found = {}
    for dist in importlib_metadata.distributions():
        name = dist.metadata.get("Name")
        if not name:
            # Broken or partially removed installs may lack a Name field.
            continue
        key = _canonical(name)
        if key in package_names and key not in found:
            found[key] = f"{name}-{dist.version}"
    # Sort for a stable report regardless of filesystem iteration order.
    related_packages = [found[key] for key in sorted(found)]

    info = {
        "python_version": python_version,
        "operating_system": operating_system,
        "os_version": os_version,
        "open-parse version": OPEN_PARSE_VERSION,
        "install path": Path(__file__).resolve().parent,
        "python version": sys.version,
        "platform": platform.platform(),
        "related packages": " ".join(related_packages),
    }
    # Newlines inside a value (e.g. sys.version) are flattened so that every
    # entry occupies exactly one line.
    return "\n".join(
        "{:>30} {}".format(k + ":", str(v).replace("\n", " ")) for k, v in info.items()
    )

0 comments on commit ae91e86

Please sign in to comment.