Skip to content

Commit

Permalink
Merge branch 'threadpool'
Browse files Browse the repository at this point in the history
  • Loading branch information
JWCook committed Nov 28, 2023
2 parents 527130f + 5f7f719 commit 3233263
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 39 deletions.
41 changes: 26 additions & 15 deletions naturtag/app/threadpool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Adapted from examples in Python & Qt6 by Martin Fitzpatrick"""
from logging import getLogger
from threading import RLock
from typing import Callable
from typing import Callable, Optional

from PySide6.QtCore import (
QEasingCurve,
Expand All @@ -12,6 +12,7 @@
QThreadPool,
QTimer,
Signal,
Slot,
)
from PySide6.QtWidgets import QGraphicsOpacityEffect, QProgressBar

Expand All @@ -29,23 +30,30 @@ def __init__(self, **kwargs):
self.progress = ProgressBar()

def schedule(
self, callback: Callable, priority: QThread.Priority = QThread.NormalPriority, **kwargs
self,
callback: Callable,
priority: QThread.Priority = QThread.NormalPriority,
total_results: Optional[int] = None,
increment_length: bool = False,
**kwargs,
) -> 'WorkerSignals':
"""Schedule a task to be run by the next available worker thread"""
self.progress.add()
worker = Worker(callback, **kwargs)
self.progress.add(total_results or 1)
worker = Worker(callback, increment_length=increment_length, **kwargs)
worker.signals.on_progress.connect(self.progress.advance)
self.start(worker, priority.value)
return worker.signals

def schedule_all(self, callbacks: list[Callable], **kwargs) -> list['WorkerSignals']:
"""Schedule multiple tasks to be run by the next available worker thread"""
self.progress.add(len(callbacks))
for callback in callbacks:
worker = Worker(callback, **kwargs)
worker.signals.on_progress.connect(self.progress.advance)
self.start(worker)
return worker.signals
# def schedule_all(self, callbacks: list[Callable], **kwargs) -> list['WorkerSignals']:
# """Schedule multiple tasks to be run by the next available worker thread"""
# self.progress.add(len(callbacks))
# all_signals = []
# for callback in callbacks:
# worker = Worker(callback, **kwargs)
# worker.signals.on_progress.connect(self.progress.advance)
# self.start(worker)
# all_signals.append(worker.signals)
# return all_signals

def cancel(self):
"""Cancel all queued tasks and reset progress bar. Currently running tasks will be allowed
Expand All @@ -62,11 +70,12 @@ class Worker(QRunnable):
done.
"""

def __init__(self, callback: Callable, **kwargs):
def __init__(self, callback: Callable, increment_length: bool = False, **kwargs):
super().__init__()
self.callback = callback
self.kwargs = kwargs
self.signals = WorkerSignals()
self.increment_length = increment_length

def run(self):
try:
Expand All @@ -76,15 +85,16 @@ def run(self):
self.signals.on_error.emit(e)
else:
self.signals.on_result.emit(result)
self.signals.on_progress.emit()
increment = len(result) if self.increment_length and isinstance(result, list) else 1
self.signals.on_progress.emit(increment)


class WorkerSignals(QObject):
"""Signals used by a worker thread (can't be set directly on a QRunnable)"""

on_error = Signal(Exception) #: Return exception info on error
on_result = Signal(object) #: Return result on completion
on_progress = Signal() #: Increment progress bar
on_progress = Signal(int) #: Increment progress bar


class ProgressBar(QProgressBar):
Expand All @@ -108,6 +118,7 @@ def add(self, amount: int = 1):
with self.lock:
self.setMaximum(self.maximum() + amount)

@Slot(int)
def advance(self, amount: int = 1):
with self.lock:
new_value = min(self.value() + amount, self.maximum())
Expand Down
61 changes: 37 additions & 24 deletions naturtag/controllers/observation_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def __init__(self, *args, **kwargs):
# Add a delay before loading user observations on startup
QTimer.singleShot(1, self.load_user_observations)

# Actions triggered directly by UI
# ----------------------------------------

def select_observation(self, observation_id: int):
"""Select an observation to display full details"""
# Don't need to do anything if this observation is already selected
Expand All @@ -93,15 +96,40 @@ def select_observation(self, observation_id: int):
)
future.on_result.connect(self.display_observation)

def load_user_observations(self):
"""Fetch and display a single page of user observations"""
logger.info('Fetching user observations')
future = self.threadpool.schedule(self.get_user_observations, priority=QThread.LowPriority)
future.on_result.connect(self.display_user_observations)

def next_page(self):
if self.page < self.total_pages:
self.page += 1
self.load_user_observations()

def prev_page(self):
if self.page > 1:
self.page -= 1
self.load_user_observations()

def refresh(self):
self.page = 1
self.load_user_observations()

# UI helper functions (slots triggered after worker threads complete)
# ----------------------------------------

@Slot(Observation)
def display_observation(self, observation: Observation):
"""Display full details for a single observation"""
self.selected_observation = observation
self.on_select.emit(observation)
self.obs_info.load(observation)
logger.debug(f'Loaded observation {observation.id}')

@Slot(list)
def display_user_observations(self, observations: list[Observation]):
"""Display a page of observations"""
# Update observation list
self.user_observations.set_observations(observations)
self.bind_selection(self.user_observations.cards)
Expand All @@ -111,14 +139,18 @@ def display_user_observations(self, observations: list[Observation]):
self.next_button.setEnabled(self.page < self.total_pages)
self.page_label.setText(f'Page {self.page} / {self.total_pages}')

def load_user_observations(self):
logger.info('Fetching user observations')
future = self.threadpool.schedule(self.get_user_observations, priority=QThread.LowPriority)
future.on_result.connect(self.display_user_observations)
def bind_selection(self, obs_cards: Iterable[ObservationInfoCard]):
"""Connect click signal from each observation card"""
for obs_card in obs_cards:
obs_card.on_click.connect(self.select_observation)

# I/O bound functions run from worker threads
# ----------------------------------------

# TODO: Handle casual_observations setting
# TODO: Handle casual_observations setting?
# TODO: Store a Paginator object instead of page number?
def get_user_observations(self) -> list[Observation]:
"""Fetch a single page of user observations"""
if not self.settings.username:
return []

Expand All @@ -138,22 +170,3 @@ def get_user_observations(self) -> list[Observation]:
page=self.page,
)
return observations

def next_page(self):
if self.page < self.total_pages:
self.page += 1
self.load_user_observations()

def prev_page(self):
if self.page > 1:
self.page -= 1
self.load_user_observations()

def refresh(self):
self.page = 1
self.load_user_observations()

def bind_selection(self, obs_cards: Iterable[ObservationInfoCard]):
"""Connect click signal from each observation card"""
for obs_card in obs_cards:
obs_card.on_click.connect(self.select_observation)

0 comments on commit 3233263

Please sign in to comment.