Skip to content

Commit

Permalink
Merge pull request #28 from mbkroese/mkroese/add-algo
Browse files Browse the repository at this point in the history
Improve least_duration algorithm by sorting durations
  • Loading branch information
jerry-git committed Jul 21, 2021
2 parents ce4309f + 3236751 commit 32b73ea
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 29 deletions.
51 changes: 35 additions & 16 deletions src/pytest_split/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import functools
import heapq
from operator import itemgetter
from typing import TYPE_CHECKING, NamedTuple

if TYPE_CHECKING:
Expand All @@ -18,17 +19,26 @@ class TestGroup(NamedTuple):
def least_duration(splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]") -> "List[TestGroup]":
"""
Split tests into groups by runtime.
Assigns the test with the largest runtime to the test with the smallest
duration sum.
It walks the test items, starting with the test with largest duration.
It assigns the test with the largest runtime to the group with the smallest duration sum.
The algorithm sorts the items by their duration. Since the sorting algorithm is stable, ties will be broken by
maintaining the original order of items. It is therefore important that the order of items be identical on all nodes
that use this plugin. Due to issue #25 this might not always be the case.
:param splits: How many groups we're splitting in.
:param items: Test items passed down by Pytest.
:param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
:return:
List of groups
"""
durations = _remove_irrelevant_durations(items, durations)
avg_duration_per_test = _get_avg_duration_per_test(durations)
items_with_durations = _get_items_with_durations(items, durations)

# add index of item in list
items_with_durations = [(*tup, i) for i, tup in enumerate(items_with_durations)]

# sort in ascending order
sorted_items_with_durations = sorted(items_with_durations, key=lambda tup: tup[1], reverse=True)

selected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
deselected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
Expand All @@ -37,15 +47,13 @@ def least_duration(splits: int, items: "List[nodes.Item]", durations: "Dict[str,
# create a heap of the form (summed_durations, group_index)
heap: "List[Tuple[float, int]]" = [(0, i) for i in range(splits)]
heapq.heapify(heap)
for item in items:
item_duration = durations.get(item.nodeid, avg_duration_per_test)

for item, item_duration, original_index in sorted_items_with_durations:
# get group with smallest sum
summed_durations, group_idx = heapq.heappop(heap)
new_group_durations = summed_durations + item_duration

# store assignment
selected[group_idx].append(item)
selected[group_idx].append((item, original_index))
duration[group_idx] = new_group_durations
for i in range(splits):
if i != group_idx:
Expand All @@ -54,7 +62,14 @@ def least_duration(splits: int, items: "List[nodes.Item]", durations: "Dict[str,
# store new duration - in case of ties it sorts by the group_idx
heapq.heappush(heap, (new_group_durations, group_idx))

return [TestGroup(selected=selected[i], deselected=deselected[i], duration=duration[i]) for i in range(splits)]
groups = []
for i in range(splits):
# sort the items by their original index to maintain relative ordering
# we don't care about the order of deselected items
s = [item for item, original_index in sorted(selected[i], key=lambda tup: tup[1])]
group = TestGroup(selected=s, deselected=deselected[i], duration=duration[i])
groups.append(group)
return groups


def duration_based_chunks(splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]") -> "List[TestGroup]":
Expand All @@ -69,30 +84,34 @@ def duration_based_chunks(splits: int, items: "List[nodes.Item]", durations: "Di
:param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
:return: List of TestGroup
"""
durations = _remove_irrelevant_durations(items, durations)
avg_duration_per_test = _get_avg_duration_per_test(durations)

tests_and_durations = {item: durations.get(item.nodeid, avg_duration_per_test) for item in items}
time_per_group = sum(tests_and_durations.values()) / splits
items_with_durations = _get_items_with_durations(items, durations)
time_per_group = sum(map(itemgetter(1), items_with_durations)) / splits

selected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
deselected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
duration: "List[float]" = [0 for i in range(splits)]

group_idx = 0
for item in items:
for item, item_duration in items_with_durations:
if duration[group_idx] >= time_per_group:
group_idx += 1

selected[group_idx].append(item)
for i in range(splits):
if i != group_idx:
deselected[i].append(item)
duration[group_idx] += tests_and_durations.pop(item)
duration[group_idx] += item_duration

return [TestGroup(selected=selected[i], deselected=deselected[i], duration=duration[i]) for i in range(splits)]


def _get_items_with_durations(items, durations):
    """Pair every test item with a duration.

    Items that have no recorded duration are assigned the average duration of
    the tests that do have one, so they still contribute to group balancing.

    :param items: Test items passed down by Pytest.
    :param durations: Cached test runtimes keyed by nodeid.
    :return: List of ``(item, duration)`` tuples, in the original item order.
    """
    relevant_durations = _remove_irrelevant_durations(items, durations)
    fallback = _get_avg_duration_per_test(relevant_durations)
    return [
        (test_item, relevant_durations.get(test_item.nodeid, fallback))
        for test_item in items
    ]


def _get_avg_duration_per_test(durations: "Dict[str, float]") -> float:
if durations:
avg_duration_per_test = sum(durations.values()) / len(durations)
Expand Down
25 changes: 21 additions & 4 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ def test__split_tests_handles_tests_with_missing_durations(self, algo_name):
assert first.selected == [item("a")]
assert second.selected == [item("b")]

@pytest.mark.parametrize("algo_name", Algorithms.names())
@pytest.mark.skip("current algorithm does handle this well")
def test__split_test_handles_large_duration_at_end(self, algo_name):
def test__split_test_handles_large_duration_at_end(self):
"""NOTE: only least_duration does this correctly"""
durations = {"a": 1, "b": 1, "c": 1, "d": 3}
items = [item(x) for x in ["a", "b", "c", "d"]]
algo = Algorithms[algo_name].value
algo = Algorithms["least_duration"].value
splits = algo(splits=2, items=items, durations=durations)

first, second = splits
Expand Down Expand Up @@ -83,3 +82,21 @@ def test__split_tests_calculates_avg_test_duration_only_on_present_tests(self, a
expected_first, expected_second = expected
assert first.selected == expected_first
assert second.selected == expected_second

@pytest.mark.parametrize(
"algo_name, expected",
[
("duration_based_chunks", [[item("a"), item("b"), item("c"), item("d"), item("e")], []]),
("least_duration", [[item("e")], [item("a"), item("b"), item("c"), item("d")]]),
],
)
def test__split_tests_maintains_relative_order_of_tests(self, algo_name, expected):
durations = {"a": 2, "b": 3, "c": 4, "d": 5, "e": 10000}
items = [item(x) for x in ["a", "b", "c", "d", "e"]]
algo = Algorithms[algo_name].value
splits = algo(splits=2, items=items, durations=durations)

first, second = splits
expected_first, expected_second = expected
assert first.selected == expected_first
assert second.selected == expected_second
18 changes: 9 additions & 9 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,22 @@ class TestSplitToSuites:
),
(2, 1, "duration_based_chunks", ["test_1", "test_2", "test_3", "test_4", "test_5", "test_6", "test_7"]),
(2, 2, "duration_based_chunks", ["test_8", "test_9", "test_10"]),
(2, 1, "least_duration", ["test_1", "test_3", "test_5", "test_7", "test_9"]),
(2, 2, "least_duration", ["test_2", "test_4", "test_6", "test_8", "test_10"]),
(2, 1, "least_duration", ["test_3", "test_5", "test_6", "test_8", "test_10"]),
(2, 2, "least_duration", ["test_1", "test_2", "test_4", "test_7", "test_9"]),
(3, 1, "duration_based_chunks", ["test_1", "test_2", "test_3", "test_4", "test_5"]),
(3, 2, "duration_based_chunks", ["test_6", "test_7", "test_8"]),
(3, 3, "duration_based_chunks", ["test_9", "test_10"]),
(3, 1, "least_duration", ["test_1", "test_4", "test_7", "test_10"]),
(3, 2, "least_duration", ["test_2", "test_5", "test_8"]),
(3, 3, "least_duration", ["test_3", "test_6", "test_9"]),
(3, 1, "least_duration", ["test_3", "test_6", "test_9"]),
(3, 2, "least_duration", ["test_4", "test_7", "test_10"]),
(3, 3, "least_duration", ["test_1", "test_2", "test_5", "test_8"]),
(4, 1, "duration_based_chunks", ["test_1", "test_2", "test_3", "test_4"]),
(4, 2, "duration_based_chunks", ["test_5", "test_6", "test_7"]),
(4, 3, "duration_based_chunks", ["test_8", "test_9"]),
(4, 4, "duration_based_chunks", ["test_10"]),
(4, 1, "least_duration", ["test_1", "test_5", "test_9"]),
(4, 2, "least_duration", ["test_2", "test_6", "test_10"]),
(4, 3, "least_duration", ["test_3", "test_7"]),
(4, 4, "least_duration", ["test_4", "test_8"]),
(4, 1, "least_duration", ["test_6", "test_10"]),
(4, 2, "least_duration", ["test_1", "test_4", "test_7"]),
(4, 3, "least_duration", ["test_2", "test_5", "test_8"]),
(4, 4, "least_duration", ["test_3", "test_9"]),
]
legacy_duration = [True, False]
all_params = [(*param, legacy_flag) for param, legacy_flag in itertools.product(parameters, legacy_duration)]
Expand Down

0 comments on commit 32b73ea

Please sign in to comment.