Skip to content

Commit f26f115

Browse files
bhashemiannsrivathsa
authored andcommitted
Stain normalization (Project-MONAI#2666)
* added stain norm and tests Signed-off-by: Neha Srivathsa <[email protected]> * import changes Signed-off-by: Neha Srivathsa <[email protected]> * changed stain extraction tests Signed-off-by: Neha Srivathsa <[email protected]> * edited stain norm tests Signed-off-by: Neha Srivathsa <[email protected]> * convert floats to float32 Signed-off-by: Neha Srivathsa <[email protected]> * added uint8 assumption to docstring Signed-off-by: Neha Srivathsa <[email protected]> * add error case Signed-off-by: Neha Srivathsa <[email protected]> * formatting change Signed-off-by: Neha Srivathsa <[email protected]> * modify tests wrt cupy import Signed-off-by: Neha Srivathsa <[email protected]> * minor change to pass lint test Signed-off-by: Neha Srivathsa <[email protected]> * import changes Signed-off-by: Neha Srivathsa <[email protected]> * refactored classes Signed-off-by: Neha Srivathsa <[email protected]> * Restructure and rename transforms Signed-off-by: Behrooz <[email protected]> * added dict transform Signed-off-by: Neha Srivathsa <[email protected]> * Move stain_extractor to init Signed-off-by: Behrooz <[email protected]> * Exclude pathology transform tests from mini tests Signed-off-by: Behrooz <[email protected]> * Fix type checking for cupy ndarray Signed-off-by: Behrooz <[email protected]> * Include pathology transform tests Signed-off-by: Behrooz <[email protected]> * Update to cupy 9.0.0 Signed-off-by: Behrooz <[email protected]> * Remove exact version for cupy Signed-off-by: Behrooz <[email protected]> * add to docs Signed-off-by: Neha Srivathsa <[email protected]> * Organize into stain dir Signed-off-by: Behrooz <[email protected]> * Add/update init files Signed-off-by: Behrooz <[email protected]> * Transit all from cupy to numpy Signed-off-by: Behrooz <[email protected]> * Update imports Signed-off-by: Behrooz <[email protected]> * Update test cases for numpy Signed-off-by: Behrooz <[email protected]> * Rename to NormalizeHEStains and NormalizeHEStainsD Signed-off-by: Behrooz <[email protected]> * Add dictionary variant names Signed-off-by: Behrooz <[email protected]> * Fix typing and formatting Signed-off-by: Behrooz <[email protected]> * Fix docs Signed-off-by: Behrooz <[email protected]> * Update test cases Signed-off-by: Behrooz <[email protected]> * Fix clip max Signed-off-by: Behrooz <[email protected]> * Fix var typing Signed-off-by: Behrooz <[email protected]> * Fix a typing issue Signed-off-by: Behrooz <[email protected]> * Update default values, and change D to d Signed-off-by: Behrooz <[email protected]> * Update docs Signed-off-by: Behrooz <[email protected]> * Add image value check Signed-off-by: Behrooz <[email protected]> * Add test cases for negative and invalid values Signed-off-by: Behrooz <[email protected]> Co-authored-by: Neha Srivathsa <[email protected]> Co-authored-by: nsrivathsa <[email protected]>
1 parent c82b5cf commit f26f115

File tree

8 files changed

+838
-0
lines changed

8 files changed

+838
-0
lines changed

docs/source/apps.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,15 @@ Clara MMARs
9898
.. autofunction:: compute_isolated_tumor_cells
9999
.. autoclass:: PathologyProbNMS
100100
:members:
101+
102+
.. automodule:: monai.apps.pathology.transforms.stain.array
103+
.. autoclass:: ExtractHEStains
104+
:members:
105+
.. autoclass:: NormalizeHEStains
106+
:members:
107+
108+
.. automodule:: monai.apps.pathology.transforms.stain.dictionary
109+
.. autoclass:: ExtractHEStainsd
110+
:members:
111+
.. autoclass:: NormalizeHEStainsd
112+
:members:

monai/apps/pathology/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,13 @@
1212
from .datasets import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCacheDataset
1313
from .handlers import ProbMapProducer
1414
from .metrics import LesionFROC
15+
from .transforms.stain.array import ExtractHEStains, NormalizeHEStains
16+
from .transforms.stain.dictionary import (
17+
ExtractHEStainsd,
18+
ExtractHEStainsD,
19+
ExtractHEStainsDict,
20+
NormalizeHEStainsd,
21+
NormalizeHEStainsD,
22+
NormalizeHEStainsDict,
23+
)
1524
from .utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from .stain.array import ExtractHEStains, NormalizeHEStains
13+
from .stain.dictionary import (
14+
ExtractHEStainsd,
15+
ExtractHEStainsD,
16+
ExtractHEStainsDict,
17+
NormalizeHEStainsd,
18+
NormalizeHEStainsD,
19+
NormalizeHEStainsDict,
20+
)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from .array import ExtractHEStains, NormalizeHEStains
13+
from .dictionary import (
14+
ExtractHEStainsd,
15+
ExtractHEStainsD,
16+
ExtractHEStainsDict,
17+
NormalizeHEStainsd,
18+
NormalizeHEStainsD,
19+
NormalizeHEStainsDict,
20+
)
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Union
13+
14+
import numpy as np
15+
16+
from monai.transforms.transform import Transform
17+
18+
19+
class ExtractHEStains(Transform):
20+
"""Class to extract a target stain from an image, using stain deconvolution (see Note).
21+
22+
Args:
23+
tli: transmitted light intensity. Defaults to 240.
24+
alpha: tolerance in percentile for the pseudo-min (alpha percentile)
25+
and pseudo-max (100 - alpha percentile). Defaults to 1.
26+
beta: absorbance threshold for transparent pixels. Defaults to 0.15
27+
max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).
28+
Defaults to (1.9705, 1.0308).
29+
30+
Note:
31+
For more information refer to:
32+
- the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf
33+
- the previous implementations:
34+
35+
- MATLAB: https://github.com/mitkovetta/staining-normalization
36+
- Python: https://github.com/schaugf/HEnorm_python
37+
"""
38+
39+
def __init__(
40+
self,
41+
tli: float = 240,
42+
alpha: float = 1,
43+
beta: float = 0.15,
44+
max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308),
45+
) -> None:
46+
self.tli = tli
47+
self.alpha = alpha
48+
self.beta = beta
49+
self.max_cref = np.array(max_cref)
50+
51+
def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray:
52+
"""Perform Stain Deconvolution and return stain matrix for the image.
53+
54+
Args:
55+
img: uint8 RGB image to perform stain deconvolution on
56+
57+
Return:
58+
he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values)
59+
"""
60+
# check image type and vlues
61+
if not isinstance(image, np.ndarray):
62+
raise TypeError("Image must be of type numpy.ndarray.")
63+
if image.min() < 0:
64+
raise ValueError("Image should not have negative values.")
65+
if image.max() > 255:
66+
raise ValueError("Image should not have values greater than 255.")
67+
68+
# reshape image and calculate absorbance
69+
image = image.reshape((-1, 3))
70+
image = image.astype(np.float32) + 1.0
71+
absorbance = -np.log(image.clip(max=self.tli) / self.tli)
72+
73+
# remove transparent pixels
74+
absorbance_hat = absorbance[np.all(absorbance > self.beta, axis=1)]
75+
if len(absorbance_hat) == 0:
76+
raise ValueError("All pixels of the input image are below the absorbance threshold.")
77+
78+
# compute eigenvectors
79+
_, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32))
80+
81+
# project on the plane spanned by the eigenvectors corresponding to the two largest eigenvalues
82+
t_hat = absorbance_hat.dot(eigvecs[:, 1:3])
83+
84+
# find the min and max vectors and project back to absorbance space
85+
phi = np.arctan2(t_hat[:, 1], t_hat[:, 0])
86+
min_phi = np.percentile(phi, self.alpha)
87+
max_phi = np.percentile(phi, 100 - self.alpha)
88+
v_min = eigvecs[:, 1:3].dot(np.array([(np.cos(min_phi), np.sin(min_phi))], dtype=np.float32).T)
89+
v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T)
90+
91+
# a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second
92+
if v_min[0] > v_max[0]:
93+
he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T
94+
else:
95+
he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T
96+
97+
return he
98+
99+
def __call__(self, image: np.ndarray) -> np.ndarray:
100+
"""Perform stain extraction.
101+
102+
Args:
103+
image: uint8 RGB image to extract stain from
104+
105+
return:
106+
target_he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values)
107+
"""
108+
if not isinstance(image, np.ndarray):
109+
raise TypeError("Image must be of type numpy.ndarray.")
110+
111+
target_he = self._deconvolution_extract_stain(image)
112+
return target_he
113+
114+
115+
class NormalizeHEStains(Transform):
116+
"""Class to normalize patches/images to a reference or target image stain (see Note).
117+
118+
Performs stain deconvolution of the source image using the ExtractHEStains
119+
class, to obtain the stain matrix and calculate the stain concentration matrix
120+
for the image. Then, performs the inverse Beer-Lambert transform to recreate the
121+
patch using the target H&E stain matrix provided. If no target stain provided, a default
122+
reference stain is used. Similarly, if no maximum stain concentrations are provided, a
123+
reference maximum stain concentrations matrix is used.
124+
125+
Args:
126+
tli: transmitted light intensity. Defaults to 240.
127+
alpha: tolerance in percentile for the pseudo-min (alpha percentile) and
128+
pseudo-max (100 - alpha percentile). Defaults to 1.
129+
beta: absorbance threshold for transparent pixels. Defaults to 0.15.
130+
target_he: target stain matrix. Defaults to ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)).
131+
max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).
132+
Defaults to [1.9705, 1.0308].
133+
134+
Note:
135+
For more information refer to:
136+
- the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf
137+
- the previous implementations:
138+
139+
- MATLAB: https://github.com/mitkovetta/staining-normalization
140+
- Python: https://github.com/schaugf/HEnorm_python
141+
"""
142+
143+
def __init__(
144+
self,
145+
tli: float = 240,
146+
alpha: float = 1,
147+
beta: float = 0.15,
148+
target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)),
149+
max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308),
150+
) -> None:
151+
self.tli = tli
152+
self.target_he = np.array(target_he)
153+
self.max_cref = np.array(max_cref)
154+
self.stain_extractor = ExtractHEStains(tli=self.tli, alpha=alpha, beta=beta, max_cref=self.max_cref)
155+
156+
def __call__(self, image: np.ndarray) -> np.ndarray:
157+
"""Perform stain normalization.
158+
159+
Args:
160+
image: uint8 RGB image/patch to be stain normalized, pixel values between 0 and 255
161+
162+
Return:
163+
image_norm: stain normalized image/patch
164+
"""
165+
# check image type and vlues
166+
if not isinstance(image, np.ndarray):
167+
raise TypeError("Image must be of type numpy.ndarray.")
168+
if image.min() < 0:
169+
raise ValueError("Image should not have negative values.")
170+
if image.max() > 255:
171+
raise ValueError("Image should not have values greater than 255.")
172+
173+
# extract stain of the image
174+
he = self.stain_extractor(image)
175+
176+
# reshape image and calculate absorbance
177+
h, w, _ = image.shape
178+
image = image.reshape((-1, 3))
179+
image = image.astype(np.float32) + 1.0
180+
absorbance = -np.log(image.clip(max=self.tli) / self.tli)
181+
182+
# rows correspond to channels (RGB), columns to absorbance values
183+
y = np.reshape(absorbance, (-1, 3)).T
184+
185+
# determine concentrations of the individual stains
186+
conc = np.linalg.lstsq(he, y, rcond=None)[0]
187+
188+
# normalize stain concentrations
189+
max_conc = np.array([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32)
190+
tmp = np.divide(max_conc, self.max_cref, dtype=np.float32)
191+
image_c = np.divide(conc, tmp[:, np.newaxis], dtype=np.float32)
192+
193+
image_norm: np.ndarray = np.multiply(self.tli, np.exp(-self.target_he.dot(image_c)), dtype=np.float32)
194+
image_norm[image_norm > 255] = 254
195+
image_norm = np.reshape(image_norm.T, (h, w, 3)).astype(np.uint8)
196+
return image_norm
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
A collection of dictionary-based wrappers around the pathology transforms
13+
defined in :py:class:`monai.apps.pathology.transforms.array`.
14+
15+
Class names are ended with 'd' to denote dictionary-based transforms.
16+
"""
17+
18+
from typing import Dict, Hashable, Mapping, Union
19+
20+
import numpy as np
21+
22+
from monai.config import KeysCollection
23+
from monai.transforms.transform import MapTransform
24+
25+
from .array import ExtractHEStains, NormalizeHEStains
26+
27+
28+
class ExtractHEStainsd(MapTransform):
29+
"""Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.ExtractHEStains`.
30+
Class to extract a target stain from an image, using stain deconvolution.
31+
32+
Args:
33+
keys: keys of the corresponding items to be transformed.
34+
See also: :py:class:`monai.transforms.compose.MapTransform`
35+
tli: transmitted light intensity. Defaults to 240.
36+
alpha: tolerance in percentile for the pseudo-min (alpha percentile)
37+
and pseudo-max (100 - alpha percentile). Defaults to 1.
38+
beta: absorbance threshold for transparent pixels. Defaults to 0.15
39+
max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).
40+
Defaults to (1.9705, 1.0308).
41+
allow_missing_keys: don't raise exception if key is missing.
42+
43+
"""
44+
45+
def __init__(
46+
self,
47+
keys: KeysCollection,
48+
tli: float = 240,
49+
alpha: float = 1,
50+
beta: float = 0.15,
51+
max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308),
52+
allow_missing_keys: bool = False,
53+
) -> None:
54+
super().__init__(keys, allow_missing_keys)
55+
self.extractor = ExtractHEStains(tli=tli, alpha=alpha, beta=beta, max_cref=max_cref)
56+
57+
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
58+
d = dict(data)
59+
for key in self.key_iterator(d):
60+
d[key] = self.extractor(d[key])
61+
return d
62+
63+
64+
class NormalizeHEStainsd(MapTransform):
65+
"""Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.NormalizeHEStains`.
66+
67+
Class to normalize patches/images to a reference or target image stain.
68+
69+
Performs stain deconvolution of the source image using the ExtractHEStains
70+
class, to obtain the stain matrix and calculate the stain concentration matrix
71+
for the image. Then, performs the inverse Beer-Lambert transform to recreate the
72+
patch using the target H&E stain matrix provided. If no target stain provided, a default
73+
reference stain is used. Similarly, if no maximum stain concentrations are provided, a
74+
reference maximum stain concentrations matrix is used.
75+
76+
Args:
77+
keys: keys of the corresponding items to be transformed.
78+
See also: :py:class:`monai.transforms.compose.MapTransform`
79+
tli: transmitted light intensity. Defaults to 240.
80+
alpha: tolerance in percentile for the pseudo-min (alpha percentile) and
81+
pseudo-max (100 - alpha percentile). Defaults to 1.
82+
beta: absorbance threshold for transparent pixels. Defaults to 0.15.
83+
target_he: target stain matrix. Defaults to None.
84+
max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).
85+
Defaults to None.
86+
allow_missing_keys: don't raise exception if key is missing.
87+
88+
"""
89+
90+
def __init__(
91+
self,
92+
keys: KeysCollection,
93+
tli: float = 240,
94+
alpha: float = 1,
95+
beta: float = 0.15,
96+
target_he: Union[tuple, np.ndarray] = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)),
97+
max_cref: Union[tuple, np.ndarray] = (1.9705, 1.0308),
98+
allow_missing_keys: bool = False,
99+
) -> None:
100+
super().__init__(keys, allow_missing_keys)
101+
self.normalizer = NormalizeHEStains(tli=tli, alpha=alpha, beta=beta, target_he=target_he, max_cref=max_cref)
102+
103+
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
104+
d = dict(data)
105+
for key in self.key_iterator(d):
106+
d[key] = self.normalizer(d[key])
107+
return d
108+
109+
110+
ExtractHEStainsDict = ExtractHEStainsD = ExtractHEStainsd
111+
NormalizeHEStainsDict = NormalizeHEStainsD = NormalizeHEStainsd

0 commit comments

Comments
 (0)