Tables detection f1 (#3341)
This pull request adds table detection metrics.

One case I considered:

Case: two tables are predicted and both are matched with one table in the ground
truth.
Question: is this matching correct for both predictions, or only for one of them?

There are two subcases:
- the table was predicted by object detection as two sub-tables (i.e., split in
half into two non-overlapping sub-tables) -> in my opinion both matches are correct
- it is a false positive from the table matching script in
get_table_level_alignment -> 1 correct, 1 wrong

As we don't have bounding boxes, I followed the notebook calculation script and
assumed the pessimistic, second-subcase interpretation (see the sketch below).
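
For illustration, here is a minimal sketch of how that pessimistic assumption plays out numerically; the index values are invented for this example, and the actual logic is the `calculate_table_detection_metrics` helper added further down in this diff:

```python
# Sketch only: both predictions are matched to ground-truth table 0,
# and there is a single ground-truth table.
matched_indices = [0, 0]
ground_truth_tables_number = 1

# Duplicate matches collapse to one true positive (the pessimistic reading).
true_positive = len(set(matched_indices) - {-1})               # 1
false_positive = len(matched_indices) - true_positive          # 1
recall = true_positive / ground_truth_tables_number            # 1.0
precision = true_positive / (true_positive + false_positive)   # 0.5
f1 = 2 * precision * recall / (precision + recall)             # ~0.67
```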
plutasnyy committed Jul 8, 2024
1 parent 3fc2342 commit caea73c
Showing 6 changed files with 96 additions and 4 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,10 +1,11 @@
## 0.14.10-dev10
## 0.14.10-dev11

### Enhancements

* **Update unstructured-client dependency** Change unstructured-client dependency pin back to
greater than min version and updated tests that were failing given the update.
* **`.doc` files are now supported in the `arm64` image.** `libreoffice24` is added to the `arm64` image, meaning `.doc` files are now supported. We have follow-on work planned to investigate adding `.ppt` support for `arm64` as well.
* Add table detection metrics: recall, precision and f1

### Features

4 changes: 2 additions & 2 deletions test_unstructured/metrics/test_evaluate.py
@@ -115,7 +115,7 @@ def test_text_extraction_evaluation():
UNSTRUCTURED_TABLE_STRUCTURE_DIRNAME,
GOLD_TABLE_STRUCTURE_DIRNAME,
Path("IRS-2023-Form-1095-A.pdf.json"),
17,
23,
{},
),
(
@@ -191,7 +191,7 @@ def test_table_structure_evaluation():
assert os.path.isfile(os.path.join(export_dir, "aggregate-table-structure-accuracy.tsv"))
df = pd.read_csv(os.path.join(export_dir, "all-docs-table-structure-accuracy.tsv"), sep="\t")
assert len(df) == 1
assert len(df.columns) == 17
assert len(df.columns) == 23
assert df.iloc[0].filename == "IRS-2023-Form-1095-A.pdf"


33 changes: 33 additions & 0 deletions test_unstructured/metrics/test_table_detection_metrics.py
@@ -0,0 +1,33 @@
import pytest

from unstructured.metrics.table.table_eval import calculate_table_detection_metrics


@pytest.mark.parametrize(
("matched_indices", "ground_truth_tables_number", "expected_metrics"),
[
([0, 1, 2], 3, (1, 1, 1)), # everything was predicted correctly
([2, 1, 0], 3, (1, 1, 1)), # everything was predicted correctly
(
[-1, 2, -1, 1, 0, -1],
3,
(1, 0.5, 0.66),
), # some false positives, all tables matched, too many predictions
([2, 2, 1, 1], 8, (0.25, 0.5, 0.33)),
# Some false negatives, all predictions matched with gt, not enough predictions
# The precision here is not 1, as only one of the predictions matched with index '1' can be correct
([1, -1], 2, (0.5, 0.5, 0.5)), # typical case with false positive and false negative
([-1, -1, -1], 2, (0, 0, 0)), # nothing was matched
([-1, -1, -1], 0, (0, 0, 0)), # there was no table in ground truth
([], 0, (0, 0, 0)), # just zeros to account for errors
],
)
def test_calculate_table_metrics(matched_indices, ground_truth_tables_number, expected_metrics):
expected_recall, expected_precision, expected_f1 = expected_metrics
pred_recall, pred_precision, pred_f1 = calculate_table_detection_metrics(
matched_indices=matched_indices, ground_truth_tables_number=ground_truth_tables_number
)

assert pred_recall == expected_recall
assert pred_precision == expected_precision
assert pred_f1 == pytest.approx(expected_f1, abs=0.01)
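
To spot-check one of the expected tuples outside the test harness, the new helper can be called directly; duplicate matched indices collapse to a single true positive, which is why precision is 0.5 in the `[2, 2, 1, 1]` case:

```python
from unstructured.metrics.table.table_eval import calculate_table_detection_metrics

# Four predictions, but only ground-truth tables 1 and 2 (out of 8) are covered.
recall, precision, f1 = calculate_table_detection_metrics(
    matched_indices=[2, 2, 1, 1], ground_truth_tables_number=8
)
print(recall, precision, f1)  # 0.25 0.5 0.333...
```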
2 changes: 1 addition & 1 deletion unstructured/__version__.py
@@ -1 +1 @@
__version__ = "0.14.10-dev10" # pragma: no cover
__version__ = "0.14.10-dev11" # pragma: no cover
3 changes: 3 additions & 0 deletions unstructured/metrics/evaluate.py
Expand Up @@ -184,6 +184,9 @@ def supported_metric_names(self):
return [
"total_tables",
"table_level_acc",
"table_detection_recall",
"table_detection_precision",
"table_detection_f1",
"composite_structure_acc",
"element_col_level_index_acc",
"element_row_level_index_acc",
55 changes: 55 additions & 0 deletions unstructured/metrics/table/table_eval.py
Expand Up @@ -42,6 +42,9 @@ class TableEvaluation:

total_tables: int
table_level_acc: float
table_detection_recall: float
table_detection_precision: float
table_detection_f1: float
element_col_level_index_acc: float
element_row_level_index_acc: float
element_col_level_content_acc: float
@@ -91,6 +94,42 @@ def _count_predicted_tables(matched_indices: List[int]) -> int:
return sum(1 for idx in matched_indices if idx >= 0)


def calculate_table_detection_metrics(
matched_indices: list[int], ground_truth_tables_number: int
) -> tuple[float, float, float]:
"""
Calculate the table detection metrics: recall, precision, and f1 score.
Args:
matched_indices:
List of indices indicating matches between predicted and ground truth tables
For example: matched_indices[i] = j means that the
i-th predicted table is matched with the j-th ground truth table.
ground_truth_tables_number: the number of ground truth tables.
Returns:
Tuple of recall, precision, and f1 scores
"""
predicted_tables_number = len(matched_indices)

matched_set = set(matched_indices)
if -1 in matched_set:
matched_set.remove(-1)

true_positive = len(matched_set)
false_positive = predicted_tables_number - true_positive
positive = ground_truth_tables_number

recall = true_positive / positive if positive > 0 else 0
precision = (
true_positive / (true_positive + false_positive)
if true_positive + false_positive > 0
else 0
)
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0

return recall, precision, f1


class TableEvalProcessor:
def __init__(
self,
@@ -209,6 +248,9 @@ def process_file(self) -> TableEvaluation:
return TableEvaluation(
total_tables=0,
table_level_acc=table_acc,
table_detection_recall=score,
table_detection_precision=score,
table_detection_f1=score,
element_col_level_index_acc=score,
element_row_level_index_acc=score,
element_col_level_content_acc=score,
@@ -218,6 +260,9 @@ def process_file(self) -> TableEvaluation:
return TableEvaluation(
total_tables=len(ground_truth_table_data),
table_level_acc=0,
table_detection_recall=0,
table_detection_precision=0,
table_detection_f1=0,
element_col_level_index_acc=0,
element_row_level_index_acc=0,
element_col_level_content_acc=0,
@@ -240,9 +285,19 @@ def process_file(self) -> TableEvaluation:
cutoff=self.cutoff,
)

table_detection_recall, table_detection_precision, table_detection_f1 = (
calculate_table_detection_metrics(
matched_indices=matched_indices,
ground_truth_tables_number=len(ground_truth_table_data),
)
)

evaluation = TableEvaluation(
total_tables=len(ground_truth_table_data),
table_level_acc=predicted_table_acc,
table_detection_recall=table_detection_recall,
table_detection_precision=table_detection_precision,
table_detection_f1=table_detection_f1,
element_col_level_index_acc=metrics.get("col_index_acc", 0),
element_row_level_index_acc=metrics.get("row_index_acc", 0),
element_col_level_content_acc=metrics.get("col_content_acc", 0),
