Skip to content

Commit

Permalink
Merge pull request #3 from seung-lab/wms_grey
Browse files Browse the repository at this point in the history
feat: add grey dilation and erosion operators
  • Loading branch information
william-silversmith committed Dec 31, 2023
2 parents 4715659 + aa73645 commit e8162fc
Show file tree
Hide file tree
Showing 4 changed files with 447 additions and 12 deletions.
65 changes: 63 additions & 2 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_spherical_close():
assert res[5,5,5] == True


def test_dilate():
def test_multilabel_dilate():
labels = np.zeros((3,3,3), dtype=bool)

out = fastmorph.dilate(labels)
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_dilate():
assert np.all(ans == out)


def test_erode():
def test_multilabel_erode():
labels = np.ones((3,3,3), dtype=bool)
out = fastmorph.erode(labels)
assert np.sum(out) == 1 and out[1,1,1] == True
Expand Down Expand Up @@ -148,4 +148,65 @@ def test_erode():
assert np.sum(out) == 27


def test_grey_erode():
labels = np.arange(27, dtype=int).reshape((3,3,3), order="F")
out = fastmorph.erode(labels, mode=fastmorph.Mode.grey)

ans = np.array([
[
[0, 0, 1],
[0, 0, 1],
[3, 3, 4],
],
[
[0, 0, 1],
[0, 0, 1],
[3, 3, 4],
],
[
[9, 9, 10],
[9, 9, 10],
[12, 12, 13],
],
]).T

assert np.all(out == ans)

out = fastmorph.erode(out, mode=fastmorph.Mode.grey)
assert np.all(out == 0)


def test_grey_dilate():
L = 100
H = 200

labels = np.zeros((3,3,3), dtype=int)
labels[0,0,0] = L
labels[2,2,2] = H

out = fastmorph.dilate(labels, mode=fastmorph.Mode.grey)

ans = np.array([
[
[L, L, 0],
[L, L, 0],
[0, 0, 0],
],
[
[L, L, 0],
[L, H, H],
[0, H, H],
],
[
[0, 0, 0],
[0, H, H],
[0, H, H],
],
]).T

assert np.all(out == ans)

out = fastmorph.dilate(out, mode=fastmorph.Mode.grey)
assert np.all(out == H)


29 changes: 24 additions & 5 deletions fastmorph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Optional, Sequence
import numpy as np
import edt
Expand All @@ -10,10 +11,15 @@

AnisotropyType = Optional[Sequence[int]]

class Mode(Enum):
multilabel = 1
grey = 2

def dilate(
labels:np.ndarray,
background_only:bool = True,
parallel:int = 1,
mode:Mode = Mode.multilabel,
) -> np.ndarray:
"""
Dilate forground labels using a 3x3x3 stencil with
Expand All @@ -37,10 +43,17 @@ def dilate(
labels = np.asfortranarray(labels)
while labels.ndim < 3:
labels = labels[..., np.newaxis]
output = fastmorphops.dilate(labels, background_only, parallel)
if mode == Mode.multilabel:
output = fastmorphops.multilabel_dilate(labels, background_only, parallel)
else:
output = fastmorphops.grey_dilate(labels, parallel)
return output.view(labels.dtype)

def erode(labels:np.ndarray, parallel:int = 1) -> np.ndarray:
def erode(
labels:np.ndarray,
parallel:int = 1,
mode:Mode = Mode.multilabel,
) -> np.ndarray:
"""
Erodes forground labels using a 3x3x3 stencil with
all elements "on".
Expand All @@ -55,13 +68,18 @@ def erode(labels:np.ndarray, parallel:int = 1) -> np.ndarray:
labels = np.asfortranarray(labels)
while labels.ndim < 3:
labels = labels[..., np.newaxis]
output = fastmorphops.erode(labels, parallel)

if mode == Mode.multilabel:
output = fastmorphops.multilabel_erode(labels, parallel)
else:
output = fastmorphops.grey_erode(labels, parallel)
return output.view(labels.dtype)

def opening(
labels:np.ndarray,
background_only:bool = True,
parallel:int = 1,
mode:Mode = Mode.multilabel,
) -> np.ndarray:
"""Performs morphological opening of labels.
Expand All @@ -70,11 +88,12 @@ def opening(
False: Allow labels to erode each other as they grow.
parallel: how many pthreads to use in a threadpool
"""
return dilate(erode(labels, parallel), background_only, parallel)
return dilate(erode(labels, parallel, mode), background_only, parallel, mode)

def closing(
labels:np.ndarray,
background_only:bool = True,
mode:Mode = Mode.multilabel,
) -> np.ndarray:
"""Performs morphological closing of labels.
Expand All @@ -83,7 +102,7 @@ def closing(
False: Allow labels to erode each other as they grow.
parallel: how many pthreads to use in a threadpool
"""
return erode(dilate(labels, background_only, parallel), parallel)
return erode(dilate(labels, background_only, parallel, mode), parallel, mode)

def spherical_dilate(
labels:np.ndarray,
Expand Down
Loading

0 comments on commit e8162fc

Please sign in to comment.