Skip to content

Commit

Permalink
fix: exception handling for OCRAgent.get_agent() (#3335)
Browse files Browse the repository at this point in the history
The purpose of this PR is to help investigate
#3202.
  • Loading branch information
christinestraub committed Jul 3, 2024
1 parent d86d15c commit 493bfcc
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
## 0.14.10-dev8

### Enhancements

* **Update unstructured-client dependency** Change the unstructured-client dependency pin back to
a greater-than-minimum-version constraint and update the tests that were failing after the update.

* **`.doc` files are now supported in the `arm64` image.** `libreoffice24` is added to the `arm64` image, meaning `.doc` files are now supported. We have follow-on work planned to investigate adding `.ppt` support for `arm64` as well.

### Features

### Fixes
- Fix counting false negatives and false positives in table structure evaluation

* **Fix counting false negatives and false positives in table structure evaluation**
* **Fix Slack CI test** Change the channel that the Slack test points to because the previous test bot expired.

## 0.14.9
Expand Down
33 changes: 25 additions & 8 deletions test_unstructured/partition/utils/ocr_models/test_ocr_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from __future__ import annotations

from unittest.mock import patch

import pytest

from test_unstructured.unit_utils import (
Expand Down Expand Up @@ -39,18 +41,23 @@ def it_provides_access_to_the_configured_OCR_agent(
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT)
assert ocr_agent is ocr_agent_

def but_it_raises_when_the_requested_agent_is_not_whitelisted(
    self, _get_ocr_agent_cls_qname_: Mock
):
    """`OCRAgent.get_agent()` raises `ValueError` for a non-whitelisted agent module.

    `get_instance()` checks the module part of the qualified name against the whitelist before
    it attempts any import, so neither `get_instance()` nor `importlib` is mocked here.
    """
    # -- assumes the module "Invalid.Ocr.Agent" is not in OCR_AGENT_MODULES_WHITELIST --
    _get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"

    with pytest.raises(ValueError, match="must be set to a whitelisted module"):
        OCRAgent.get_agent()

    # -- the configured qualified name is looked up exactly once, with no arguments --
    _get_ocr_agent_cls_qname_.assert_called_once_with()

@pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
def and_it_raises_when_the_requested_agent_cannot_be_loaded(
    self, _get_ocr_agent_cls_qname_: Mock, exception_cls: type[Exception], _clear_cache
):
    """`OCRAgent.get_agent()` raises `RuntimeError` when loading the agent class fails.

    The import performed inside `ocr_interface` is patched to raise `exception_cls`, so the test
    does not depend on which OCR packages are installed. `_clear_cache` is requested so a
    previously memoized agent instance cannot short-circuit the patched import.
    """
    _get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
    import_module_target = (
        "unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module"
    )

    with patch(import_module_target, side_effect=exception_cls):
        with pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"):
            OCRAgent.get_agent()

@pytest.mark.parametrize(
("OCR_AGENT", "expected_value"),
Expand All @@ -77,6 +84,16 @@ def and_it_logs_a_warning_when_the_OCR_AGENT_module_name_is_obsolete(

# -- fixtures --------------------------------------------------------------------------------

@pytest.fixture()
def _clear_cache(self):
    """Empty the `OCRAgent.get_instance()` cache around any test that requests this fixture.

    `get_instance()` is memoized with `functools.lru_cache(maxsize=None)`, so an agent created
    by an earlier test would otherwise be returned without running the loading code again.
    """
    # -- setup: start the test with an empty cache --
    OCRAgent.get_instance.cache_clear()
    yield
    # -- teardown: clear again so nothing cached here is left behind (just in case) --
    OCRAgent.get_instance.cache_clear()

@pytest.fixture()
def get_instance_(self, request: FixtureRequest):
    """Provide the mock that `method_mock()` creates for `OCRAgent.get_instance`.

    NOTE(review): presumably the patch is undone when the requesting test finishes; confirm
    against `method_mock()` in `test_unstructured.unit_utils`.
    """
    get_instance_mock = method_mock(request, OCRAgent, "get_instance")
    return get_instance_mock
Expand Down
25 changes: 13 additions & 12 deletions unstructured/partition/utils/ocr_models/ocr_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,27 @@ def get_agent(cls) -> OCRAgent:
The OCR package used by the agent is determined by the `OCR_AGENT` environment variable.
"""
ocr_agent_cls_qname = cls._get_ocr_agent_cls_qname()
try:
return cls.get_instance(ocr_agent_cls_qname)
except (ImportError, AttributeError):
raise ValueError(
f"Environment variable OCR_AGENT must be set to an existing OCR agent module,"
f" not {ocr_agent_cls_qname}."
)
return cls.get_instance(ocr_agent_cls_qname)

@staticmethod
@functools.lru_cache(maxsize=None)
def get_instance(ocr_agent_module: str) -> "OCRAgent":
    """Load and instantiate the OCR agent class named by `ocr_agent_module`.

    Results are memoized by `functools.lru_cache`, so repeated calls with the same qualified
    name return the same instance. A call that raises is not cached and is retried next time.

    Args:
        ocr_agent_module: Fully-qualified class name of the agent, i.e. the module path and the
            class name separated by a final ".".

    Returns:
        An instance of the named class, constructed with no arguments.

    Raises:
        ValueError: The module part of `ocr_agent_module` is missing (no "." in the name) or is
            not listed in `OCR_AGENT_MODULES_WHITELIST`.
        RuntimeError: Importing the module, looking up the class, or constructing it raised
            `ImportError` or `AttributeError`. The original exception is chained as the cause.
    """
    # -- `rpartition` yields an empty module name when there is no "." at all, so a malformed
    # -- name is reported with the whitelist message rather than a tuple-unpacking error.
    module_name, _, class_name = ocr_agent_module.rpartition(".")

    # -- refuse non-whitelisted modules *before* importing anything --
    if not module_name or module_name not in OCR_AGENT_MODULES_WHITELIST:
        raise ValueError(
            f"Environment variable OCR_AGENT module name {module_name} must be set to a "
            f"whitelisted module part of {OCR_AGENT_MODULES_WHITELIST}."
        )

    try:
        module = importlib.import_module(module_name)
        loaded_class = getattr(module, class_name)
        return loaded_class()
    except (ImportError, AttributeError) as e:
        logger.error("Failed to get OCRAgent instance: %s", e)
        # -- chain the cause so the traceback shows which import or attribute failed --
        raise RuntimeError(
            "Could not get the OCRAgent instance. Please check the OCR package and the "
            "OCR_AGENT environment variable."
        ) from e

@abstractmethod
Expand Down

0 comments on commit 493bfcc

Please sign in to comment.