fix(CVE-2024-39705): remove nltk download (#3361)
### Summary

Addresses
[CVE-2024-39705](https://nvd.nist.gov/vuln/detail/CVE-2024-39705), which
highlights the risk of remote code execution when running
`nltk.download`. Replaces `nltk.download` with a download of a `.tgz` file
containing the required NLTK data files, validated by checking its SHA-256
hash. An error is now raised if `nltk.download` is invoked.

The logic for determining the NLTK download directory is borrowed from
`nltk`, so users can still set `NLTK_DATA` as they did previously.

### Testing

1. Create a directory called `~/tmp/nltk_test`. Set
`NLTK_DATA=${HOME}/tmp/nltk_test`.
2. From a Python interactive session, run:
```python
from unstructured.nlp.tokenize import download_nltk_packages

download_nltk_packages()
```
3. Run `ls ~/tmp/nltk_test/nltk_data`. You should see the downloaded
data.
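
As an optional, hypothetical follow-up check (not part of the original test plan), the downloaded data can be exercised through the wrappers in `unstructured.nlp.tokenize`; the sample sentences are illustrative only:
```python
from unstructured.nlp.tokenize import pos_tag, sent_tokenize

# Both wrappers call _download_nltk_packages_if_not_present() first, which is a
# no-op once "punkt" and "averaged_perceptron_tagger" are already on disk
# (assumes nltk can locate the extracted data on its search path).
print(sent_tokenize("ITEM 1A. PROPERTIES. The company owns several buildings."))
print(pos_tag("The company owns several buildings."))
```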

---------

Co-authored-by: Steve Canny <[email protected]>
MthwRobinson and scanny committed Jul 8, 2024
1 parent d48fa3b commit 7b25dfc
Showing 12 changed files with 179 additions and 27 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -256,6 +256,8 @@ jobs:
matrix:
python-version: [ "3.9","3.10" ]
runs-on: ubuntu-latest
env:
NLTK_DATA: ${{ github.workspace }}/nltk_data
needs: [ setup_ingest, lint ]
steps:
# actions/checkout MUST come before auth
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,4 +1,4 @@
## 0.14.10-dev13
## 0.14.10

### Enhancements

@@ -14,6 +14,7 @@

* **Fix counting false negatives and false positives in table structure evaluation**
* **Fix Slack CI test** Change channel that Slack test is pointing to because previous test bot expired
* **Remove NLTK download** Removes `nltk.download` in favor of downloading from an S3 bucket we host to mitigate CVE-2024-39705

## 0.14.9

5 changes: 2 additions & 3 deletions Dockerfile
@@ -1,4 +1,4 @@
FROM quay.io/unstructured-io/base-images:wolfi-base-d46498e@sha256:3db0544df1d8d9989cd3c3b28670d8b81351dfdc1d9129004c71ff05996fd51e as base
FROM quay.io/unstructured-io/base-images:wolfi-base-e48da6b@sha256:8ad3479e5dc87a86e4794350cca6385c01c6d110902c5b292d1a62e231be711b as base

USER root

@@ -18,8 +18,7 @@ USER notebook-user

RUN find requirements/ -type f -name "*.txt" -exec pip3.11 install --no-cache-dir --user -r '{}' ';' && \
pip3.11 install unstructured.paddlepaddle && \
python3.11 -c "import nltk; nltk.download('punkt')" && \
python3.11 -c "import nltk; nltk.download('averaged_perceptron_tagger')" && \
python3.11 -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" && \
python3.11 -c "from unstructured.partition.model_init import initialize; initialize()" && \
python3.11 -c "from unstructured_inference.models.tables import UnstructuredTableTransformerModel; model = UnstructuredTableTransformerModel(); model.initialize('microsoft/table-transformer-structure-recognition')"

14 changes: 10 additions & 4 deletions test_unstructured/nlp/test_tokenize.py
@@ -2,22 +2,28 @@
from unittest.mock import patch

import nltk
import pytest

from test_unstructured.nlp.mock_nltk import mock_sent_tokenize, mock_word_tokenize
from unstructured.nlp import tokenize


def test_error_raised_on_nltk_download():
with pytest.raises(ValueError):
tokenize.nltk.download("tokenizers/punkt")


def test_nltk_packages_download_if_not_present():
with patch.object(nltk, "find", side_effect=LookupError):
with patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_package_if_not_present("fake_package", "tokenizers")
with patch.object(tokenize, "download_nltk_packages") as mock_download:
tokenize._download_nltk_packages_if_not_present()

mock_download.assert_called_with("fake_package")
mock_download.assert_called_once()


def test_nltk_packages_do_not_download_if():
with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_package_if_not_present("fake_package", "tokenizers")
tokenize._download_nltk_packages_if_not_present()

mock_download.assert_not_called()

17 changes: 17 additions & 0 deletions typings/nltk/__init__.pyi
@@ -0,0 +1,17 @@
from __future__ import annotations

from nltk import data, internals
from nltk.data import find
from nltk.downloader import download
from nltk.tag import pos_tag
from nltk.tokenize import sent_tokenize, word_tokenize

__all__ = [
"data",
"download",
"find",
"internals",
"pos_tag",
"sent_tokenize",
"word_tokenize",
]
7 changes: 7 additions & 0 deletions typings/nltk/data.pyi
@@ -0,0 +1,7 @@
from __future__ import annotations

from typing import Sequence

path: list[str]

def find(resource_name: str, paths: Sequence[str] | None = None) -> str: ...
5 changes: 5 additions & 0 deletions typings/nltk/downloader.pyi
@@ -0,0 +1,5 @@
from __future__ import annotations

from typing import Callable

download: Callable[..., bool]
3 changes: 3 additions & 0 deletions typings/nltk/internals.pyi
@@ -0,0 +1,3 @@
from __future__ import annotations

def is_writable(path: str) -> bool: ...
5 changes: 5 additions & 0 deletions typings/nltk/tag.pyi
@@ -0,0 +1,5 @@
from __future__ import annotations

def pos_tag(
tokens: list[str], tagset: str | None = None, lang: str = "eng"
) -> list[tuple[str, str]]: ...
4 changes: 4 additions & 0 deletions typings/nltk/tokenize.pyi
@@ -0,0 +1,4 @@
from __future__ import annotations

def sent_tokenize(text: str, language: str = ...) -> list[str]: ...
def word_tokenize(text: str, language: str = ..., preserve_line: bool = ...) -> list[str]: ...
2 changes: 1 addition & 1 deletion unstructured/__version__.py
@@ -1 +1 @@
__version__ = "0.14.10-dev13" # pragma: no cover
__version__ = "0.14.10" # pragma: no cover
139 changes: 121 additions & 18 deletions unstructured/nlp/tokenize.py
@@ -1,11 +1,13 @@
from __future__ import annotations

import hashlib
import os
import sys
import tarfile
import tempfile
import urllib.request
from functools import lru_cache
from typing import List, Tuple

if sys.version_info < (3, 8):
from typing_extensions import Final # pragma: no cover
else:
from typing import Final
from typing import Any, Final, List, Tuple

import nltk
from nltk import pos_tag as _pos_tag
@@ -14,42 +14,143 @@

CACHE_MAX_SIZE: Final[int] = 128

NLTK_DATA_URL = "https://utic-public-cf.s3.amazonaws.com/nltk_data.tgz"
NLTK_DATA_SHA256 = "126faf671cd255a062c436b3d0f2d311dfeefcd92ffa43f7c3ab677309404d61"


def _raise_on_nltk_download(*args: Any, **kwargs: Any):
raise ValueError("NLTK download disabled. See CVE-2024-39705")


nltk.download = _raise_on_nltk_download


# NOTE(robinson) - mimic default dir logic from NLTK
# https://github.com/nltk/nltk/
# blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046
def get_nltk_data_dir() -> str | None:
"""Locates the directory the nltk data will be saved too. The directory
set by the NLTK environment variable takes highest precedence. Otherwise
the default is determined by the rules indicated below. Returns None when
the directory is not writable.
On Windows, the default download directory is
``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the
directory containing Python, e.g. ``C:\\Python311``.
On all other platforms, the default directory is the first of
the following which exists or which can be created with write
permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``,
``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``.
"""
# Check if we are on GAE where we cannot write into filesystem.
if "APPENGINE_RUNTIME" in os.environ:
return

# Check if we have sufficient permissions to install in a
# variety of system-wide locations.
for nltkdir in nltk.data.path:
if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir):
return nltkdir

# On Windows, use %APPDATA%
if sys.platform == "win32" and "APPDATA" in os.environ:
homedir = os.environ["APPDATA"]

# Otherwise, install in the user's home directory.
else:
homedir = os.path.expanduser("~/")
if homedir == "~/":
raise ValueError("Could not find a default download directory")

# NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already
# present in the tar file so we don't have to do that here.
return homedir


def download_nltk_packages():
nltk_data_dir = get_nltk_data_dir()

if nltk_data_dir is None:
raise OSError("NLTK data directory does not exist or is not writable.")

def sha256_checksum(filename: str, block_size: int = 65536):
sha256 = hashlib.sha256()
with open(filename, "rb") as f:
for block in iter(lambda: f.read(block_size), b""):
sha256.update(block)
return sha256.hexdigest()

with tempfile.NamedTemporaryFile() as tmp_file:
tgz_file = tmp_file.name
urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file)

file_hash = sha256_checksum(tgz_file)
if file_hash != NLTK_DATA_SHA256:
os.remove(tgz_file)
raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}")

# Extract the contents
if not os.path.exists(nltk_data_dir):
os.makedirs(nltk_data_dir)

with tarfile.open(tgz_file, "r:gz") as tar:
tar.extractall(path=nltk_data_dir)


def check_for_nltk_package(package_name: str, package_category: str) -> bool:
"""Checks to see if the specified NLTK package exists on the file system"""
paths: list[str] = []
for path in nltk.data.path:
if not path.endswith("nltk_data"):
path = os.path.join(path, "nltk_data")
paths.append(path)

def _download_nltk_package_if_not_present(package_name: str, package_category: str):
"""If the required nlt package is not present, download it."""
try:
nltk.find(f"{package_category}/{package_name}")
nltk.find(f"{package_category}/{package_name}", paths=paths)
return True
except LookupError:
nltk.download(package_name)
return False


def _download_nltk_packages_if_not_present():
"""If required NLTK packages are not available, download them."""

tagger_available = check_for_nltk_package(
package_category="taggers",
package_name="averaged_perceptron_tagger",
)
tokenizer_available = check_for_nltk_package(
package_category="tokenizers", package_name="punkt"
)

if not (tokenizer_available and tagger_available):
download_nltk_packages()


@lru_cache(maxsize=CACHE_MAX_SIZE)
def sent_tokenize(text: str) -> List[str]:
"""A wrapper around the NLTK sentence tokenizer with LRU caching enabled."""
_download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
_download_nltk_packages_if_not_present()
return _sent_tokenize(text)


@lru_cache(maxsize=CACHE_MAX_SIZE)
def word_tokenize(text: str) -> List[str]:
"""A wrapper around the NLTK word tokenizer with LRU caching enabled."""
_download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
_download_nltk_packages_if_not_present()
return _word_tokenize(text)


@lru_cache(maxsize=CACHE_MAX_SIZE)
def pos_tag(text: str) -> List[Tuple[str, str]]:
"""A wrapper around the NLTK POS tagger with LRU caching enabled."""
_download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
_download_nltk_package_if_not_present(
package_category="taggers",
package_name="averaged_perceptron_tagger",
)
_download_nltk_packages_if_not_present()
# NOTE(robinson) - Splitting into sentences before tokenizing. This helps with
# situations like "ITEM 1A. PROPERTIES" where "PROPERTIES" can be mistaken
# for a verb because it looks like it's in verb form and "ITEM 1A." looks like the subject.
sentences = _sent_tokenize(text)
parts_of_speech = []
parts_of_speech: list[tuple[str, str]] = []
for sentence in sentences:
tokens = _word_tokenize(sentence)
parts_of_speech.extend(_pos_tag(tokens))