Skip to content

Commit

Permalink
Improve Python type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesmishra committed Oct 23, 2023
1 parent b5d33d4 commit f5935f6
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 110 deletions.
78 changes: 3 additions & 75 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -60,88 +60,17 @@ confidence=
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=print-statement,
parameter-unpacking,
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
disable=raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
apply-builtin,
basestring-builtin,
buffer-builtin,
cmp-builtin,
coerce-builtin,
execfile-builtin,
file-builtin,
long-builtin,
raw_input-builtin,
reduce-builtin,
standarderror-builtin,
unicode-builtin,
xrange-builtin,
coerce-method,
delslice-method,
getslice-method,
setslice-method,
no-absolute-import,
old-division,
dict-iter-method,
dict-view-method,
next-method-called,
metaclass-assignment,
indexing-exception,
raising-string,
reload-builtin,
oct-method,
hex-method,
nonzero-method,
cmp-method,
input-builtin,
round-builtin,
intern-builtin,
unichr-builtin,
map-builtin-not-iterating,
zip-builtin-not-iterating,
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
exception-message-attribute,
invalid-str-codec,
sys-max-int,
bad-python3-import,
deprecated-string-function,
deprecated-str-translate-call,
deprecated-itertools-function,
deprecated-types-field,
next-method-defined,
dict-items-not-iterating,
dict-keys-not-iterating,
dict-values-not-iterating,
deprecated-operator-function,
deprecated-urllib-function,
xreadlines-attribute,
deprecated-sys-function,
exception-escape,
comprehension-escape,
duplicate-code,
invalid-name,
no-self-use
use-dict-literal

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down Expand Up @@ -590,5 +519,4 @@ min-public-methods=2

# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception
overgeneral-exceptions=
10 changes: 8 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ fmt-check:
$(POETRY) run black --check .
.PHONY: fmt-check

lint:
mypy:
$(POETRY) run mypy --strict log_with_context.py test_log_with_context.py
.PHONY: mypy

pylint:
$(POETRY) run pylint log_with_context.py test_log_with_context.py
$(POETRY) run mypy log_with_context.py test_log_with_context.py
.PHONY: pylint

lint: pylint mypy
.PHONY: lint

build:
Expand Down
48 changes: 30 additions & 18 deletions log_with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import logging
import os
import threading
from typing import Any, Callable, Mapping, Optional, Union
from types import TracebackType
from typing import Any, Callable, DefaultDict, Mapping, Optional, Type, Union

from typing_extensions import Literal

_EXTRA_TYPE = Mapping[str, Any]

Expand All @@ -14,7 +17,7 @@
_THREAD_LOCAL.pids = collections.defaultdict(dict)


def init_extra():
def init_extra() -> None:
"""Initialize our thread-local variable for storing contexts."""
if not hasattr(_THREAD_LOCAL, "pids"):
_THREAD_LOCAL.pids = collections.defaultdict(dict)
Expand All @@ -28,7 +31,8 @@ def get_extra() -> _EXTRA_TYPE:
"""
init_extra()
pid = os.getpid()
return _THREAD_LOCAL.pids[pid]
pids: Mapping[int, _EXTRA_TYPE] = _THREAD_LOCAL.pids
return pids[pid]


def set_extra(extra: _EXTRA_TYPE) -> _EXTRA_TYPE:
Expand All @@ -41,8 +45,9 @@ def set_extra(extra: _EXTRA_TYPE) -> _EXTRA_TYPE:
"""
init_extra()
pid = os.getpid()
_THREAD_LOCAL.pids[pid] = extra
return _THREAD_LOCAL.pids[pid]
pids: DefaultDict[int, _EXTRA_TYPE] = _THREAD_LOCAL.pids
pids[pid] = extra
return pids[pid]


class Logger:
Expand All @@ -69,7 +74,9 @@ def __init__(
else:
self.base_logger = logger or logging.getLogger(name=name)

def _msg(self, func: Callable, msg, *args, **kwargs):
def _msg(
self, func: Callable[..., None], msg: Any, *args: Any, **kwargs: Any
) -> None:
"""Log with our extra values,"""
extra = {
**self.extra,
Expand All @@ -87,35 +94,35 @@ def extra(self) -> _EXTRA_TYPE:
"""Return the extra metadata that this logger sends with every message."""
return get_extra()

def debug(self, msg, *args, **kwargs):
def debug(self, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Debug."""
return self._msg(self.base_logger.debug, msg, *args, **kwargs)

def info(self, msg, *args, **kwargs):
def info(self, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Info."""
return self._msg(self.base_logger.info, msg, *args, **kwargs)

def warning(self, msg, *args, **kwargs):
def warning(self, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Warning."""
return self._msg(self.base_logger.warning, msg, *args, **kwargs)

def warn(self, msg, *args, **kwargs):
def warn(self, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Warn. Deprecated. Use `warning` instead."""
return self._msg(self.base_logger.warn, msg, *args, **kwargs)

def error(self, msg, *args, **kwargs):
def error(self, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Error."""
return self._msg(self.base_logger.error, msg, *args, **kwargs)

def critical(self, msg, *args, **kwargs):
def critical(self, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Critical."""
return self._msg(self.base_logger.critical, msg, *args, **kwargs)

def log(self, level, msg, *args, **kwargs):
def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Log."""
return self._msg(self.base_logger.log, level, msg, *args, **kwargs)

def exception(self, msg, *args, **kwargs):
def exception(self, msg: Any, *args: Any, **kwargs: Any) -> None:
"""Exception."""
return self._msg(self.base_logger.exception, msg, *args, **kwargs)

Expand All @@ -125,18 +132,23 @@ class add_logging_context:
A context manager to push and pop "extra" dictionary keys.
"""

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
"""Create a new context manager."""
self._new_extra = kwargs
self._old_extra = {}
self._old_extra: Mapping[str, Any] = {}

def __enter__(self):
def __enter__(self) -> "add_logging_context":
"""Add the new kwargs to the thread-local state."""
self._old_extra = get_extra()
set_extra({**self._old_extra, **self._new_extra})
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType] = None,
) -> Literal[False]:
"""Return the thread-local state to what it used to be."""
set_extra(self._old_extra)
return False
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ license = "MIT"

[tool.poetry.dependencies]
python = "^3.8"
typing-extensions = "^4.8.0"

[tool.poetry.dev-dependencies]
pytest = "^5.2"
pylint = "^2.12.2"
mypy = "^0.941"
black = "^22.1.0"
isort = "^5.10.1"
pytest = "^7.4"
pylint = "^3.0"
mypy = "^1.6.1"
black = "^23.10.1"
isort = "^5.12"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
20 changes: 10 additions & 10 deletions test_log_with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import unittest
import unittest.mock
from typing import Any

from log_with_context import Logger, add_logging_context

Expand All @@ -20,7 +21,7 @@
]


def process_worker(val):
def process_worker(val: Any) -> None:
"""See how our logger interacts with multiprocessing."""
base_logger = unittest.mock.Mock(spec=LOGGER)
logger = Logger(logger=base_logger)
Expand All @@ -35,22 +36,21 @@ def process_worker(val):
class TestLogger(unittest.TestCase):
"""Unit tests for :py:class:`log_with_context.Logger`."""

def setUp(self):
def setUp(self) -> None:
self.base_logger = unittest.mock.Mock(spec=LOGGER)
self.logger = Logger(logger=self.base_logger)

def gl(self, method_name: str):
def gl(self, method_name: str) -> Any:
"""Execute a given method on our test logger."""
return getattr(self.logger, method_name)

def gb(self, method_name: str):
def gb(self, method_name: str) -> Any:
"""Execute a given method on our base logger mock."""
return getattr(self.base_logger, method_name)

def test_add_logging_context(self):
def test_add_logging_context(self) -> None:
"""Test that :py:class:`add_logging_context` works."""
for method_name in LOG_METHOD_NAMES:

with self.subTest(name=f"logger.{method_name}"):
self.gl(method_name)("1")
self.gb(method_name).assert_called_with("1", extra={}, stacklevel=3)
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_add_logging_context(self):
self.gl(method_name)("8")
self.gb(method_name).assert_called_with("8", extra={}, stacklevel=3)

def test_inline_extra(self):
def test_inline_extra(self) -> None:
"""Test that we can add one-off additions to our context."""
for method_name in LOG_METHOD_NAMES:
with self.subTest(name=f"logger.{method_name}"):
Expand All @@ -98,10 +98,10 @@ def test_inline_extra(self):
"%s", method_name, extra=dict(hello=method_name), stacklevel=3
)

def test_thread_local(self):
def test_thread_local(self) -> None:
"""Test that our logging context is indeed thread-local."""

def thread_worker(val):
def thread_worker(val: Any) -> None:
self.logger.debug("1")
self.base_logger.debug.assert_called_with("1", extra={}, stacklevel=3)
with add_logging_context(val=val):
Expand All @@ -114,7 +114,7 @@ def thread_worker(val):
with add_logging_context(a=1):
list(exc.map(thread_worker, [1, 2]))

def test_process_local(self):
def test_process_local(self) -> None:
"""Test how our logger works with multiple processes."""
with concurrent.futures.ProcessPoolExecutor(max_workers=2) as exc:
with add_logging_context(a=1):
Expand Down

0 comments on commit f5935f6

Please sign in to comment.