Skip to content

Commit 11a1399

Browse files
committed
Add ultrasound confidence map to transforms
1 parent 17c1e3a commit 11a1399

File tree

1 file changed

+309
-1
lines changed

1 file changed

+309
-1
lines changed

monai/transforms/intensity/array.py

Lines changed: 309 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from abc import abstractmethod
1818
from collections.abc import Callable, Iterable, Sequence
1919
from functools import partial
20-
from typing import Any
20+
from typing import Any, Tuple, Literal
2121
from warnings import warn
2222

2323
import numpy as np
@@ -37,6 +37,10 @@
3737
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype
3838

3939
skimage, _ = optional_import("skimage", "0.19.0", min_version)
40+
cv2, _ = optional_import("cv2")
41+
Oct2Py, _ = optional_import("oct2py", "5.6.0", min_version, "Oct2Py")
42+
csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix")
43+
hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert")
4044

4145
__all__ = [
4246
"RandGaussianNoise",
@@ -77,6 +81,7 @@
7781
"RandIntensityRemap",
7882
"ForegroundMask",
7983
"ComputeHoVerMaps",
84+
"UltrasoundConfidenceMap",
8085
]
8186

8287

@@ -2577,3 +2582,306 @@ def __call__(self, mask: NdarrayOrTensor):
25772582

25782583
hv_maps = convert_to_tensor(np.concatenate([h_map, v_map]), track_meta=get_track_meta())
25792584
return hv_maps
2585+
2586+
class UltrasoundConfidenceMap(Transform):
2587+
"""Compute confidence map from an ultrasound image.
2588+
This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005.
2589+
It generates a confidence map by setting source and sink points in the image and computing the probability
2590+
for random walks to reach the source for each pixel.
2591+
2592+
Args:
2593+
alpha (float, optional): Alpha parameter. Defaults to 2.0.
2594+
beta (float, optional): Beta parameter. Defaults to 90.0.
2595+
gamma (float, optional): Gamma parameter. Defaults to 0.05.
2596+
mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
2597+
"""
2598+
2599+
def __init__(
2600+
self,
2601+
alpha: float = 2.0,
2602+
beta: float = 90.0,
2603+
gamma: float = 0.05,
2604+
mode: Literal["RF", "B"] = "B",
2605+
):
2606+
2607+
self.alpha = alpha
2608+
self.beta = beta
2609+
self.gamma = gamma
2610+
self.mode = mode
2611+
2612+
# The precision to use for all computations
2613+
self.eps = np.finfo("float64").eps
2614+
2615+
# Octave instance for computing the confidence map
2616+
self.oc = Oct2Py()
2617+
2618+
def sub2ind(self, size: Tuple[int], rows: np.ndarray, cols: np.ndarray) -> np.ndarray:
2619+
"""Converts row and column subscripts into linear indices,
2620+
basically the copy of the MATLAB function of the same name.
2621+
https://www.mathworks.com/help/matlab/ref/sub2ind.html
2622+
2623+
This function is Pythonic so the indices start at 0.
2624+
2625+
Args:
2626+
size Tuple[int]: Size of the matrix
2627+
rows (np.ndarray): Row indices
2628+
cols (np.ndarray): Column indices
2629+
2630+
Returns:
2631+
indices (np.ndarray): 1-D array of linear indices
2632+
"""
2633+
indices = rows + cols * size[0]
2634+
return indices
2635+
2636+
def get_seed_and_labels(self, data : np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
2637+
"""Get the seed and label arrays for the max-flow algorithm
2638+
2639+
Args:
2640+
data: Input array
2641+
2642+
Returns:
2643+
Tuple[np.ndarray, np.ndarray]: Seed and label arrays
2644+
"""
2645+
2646+
# Seeds and labels (boundary conditions)
2647+
seeds = np.array([], dtype="float64")
2648+
labels = np.array([], dtype="float64")
2649+
2650+
# Indices for all columns
2651+
sc = np.arange(data.shape[1], dtype="float64")
2652+
2653+
# SOURCE ELEMENTS - 1st matrix row
2654+
# Indices for 1st row, it will be broadcasted with sc
2655+
sr_up = np.array([0])
2656+
seed = self.sub2ind(data.shape, sr_up, sc).astype("float64")
2657+
seed = np.unique(seed)
2658+
seeds = np.concatenate((seeds, seed))
2659+
2660+
# Label 1
2661+
label = np.ones_like(seed)
2662+
labels = np.concatenate((labels, label))
2663+
2664+
# SINK ELEMENTS - last image row
2665+
sr_down = np.ones_like(sc) * (data.shape[0] - 1)
2666+
seed = self.sub2ind(data.shape, sr_down, sc).astype("float64")
2667+
2668+
seed = np.unique(seed)
2669+
seeds = np.concatenate((seeds, seed))
2670+
2671+
# Label 2
2672+
label = np.ones_like(seed) * 2
2673+
labels = np.concatenate((labels, label))
2674+
2675+
return seeds, labels
2676+
2677+
def normalize(self, inp: np.ndarray) -> np.ndarray:
2678+
"""Normalize an array to [0, 1]"""
2679+
return (inp - np.min(inp)) / (np.ptp(inp) + self.eps)
2680+
2681+
def attenuation_weighting(self, A: np.ndarray, alpha: float) -> np.ndarray:
2682+
"""Compute attenuation weighting
2683+
2684+
Args:
2685+
A (np.ndarray): Image
2686+
alpha: Attenuation coefficient (see publication)
2687+
2688+
Returns:
2689+
W (np.ndarray): Weighting expressing depth-dependent attenuation
2690+
"""
2691+
2692+
# Create depth vector and repeat it for each column
2693+
Dw = np.linspace(0, 1, A.shape[0], dtype="float64")
2694+
Dw = np.tile(Dw.reshape(-1, 1), (1, A.shape[1]))
2695+
2696+
W = 1.0 - np.exp(-alpha * Dw) # Compute exp inline
2697+
2698+
return W
2699+
2700+
def confidence_laplacian(
2701+
self, P: np.ndarray, A: np.ndarray, beta: float, gamma: float
2702+
) -> csc_matrix: # type: ignore
2703+
"""Compute 6-Connected Laplacian for confidence estimation problem
2704+
2705+
Args:
2706+
P (np.ndarray): The index matrix of the image with boundary padding.
2707+
A (np.ndarray): The padded image.
2708+
beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function.
2709+
gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.
2710+
2711+
Returns:
2712+
L (csc_matrix): The 6-connected Laplacian matrix used for confidence map estimation.
2713+
"""
2714+
2715+
m, _ = P.shape
2716+
2717+
P = P.T.flatten()
2718+
A = A.T.flatten()
2719+
2720+
p = np.where(P > 0)[0]
2721+
2722+
i = P[p] - 1 # Index vector
2723+
j = P[p] - 1 # Index vector
2724+
# Entries vector, initially for diagonal
2725+
s = np.zeros_like(p, dtype="float64")
2726+
2727+
vl = 0 # Vertical edges length
2728+
2729+
edge_templates = [
2730+
-1, # Vertical edges
2731+
1,
2732+
m - 1, # Diagonal edges
2733+
m + 1,
2734+
-m - 1,
2735+
-m + 1,
2736+
m, # Horizontal edges
2737+
-m,
2738+
]
2739+
2740+
vertical_end = None
2741+
diagonal_end = None
2742+
2743+
for iter_idx, k in enumerate(edge_templates):
2744+
2745+
Q = P[p + k]
2746+
2747+
q = np.where(Q > 0)[0]
2748+
2749+
ii = P[p[q]] - 1
2750+
i = np.concatenate((i, ii))
2751+
jj = Q[q] - 1
2752+
j = np.concatenate((j, jj))
2753+
W = np.abs(A[p[ii]] - A[p[jj]]) # Intensity derived weight
2754+
s = np.concatenate((s, W))
2755+
2756+
if iter_idx == 1:
2757+
vertical_end = s.shape[0] # Vertical edges length
2758+
elif iter_idx == 5:
2759+
diagonal_end = s.shape[0] # Diagonal edges length
2760+
2761+
# Normalize weights
2762+
s = self.normalize(s)
2763+
2764+
# Horizontal penalty
2765+
s[:vertical_end] += gamma
2766+
#s[vertical_end:diagonal_end] += gamma * np.sqrt(2) # --> In the paper it is sqrt(2) since the diagonal edges are longer yet does not exist in the original code
2767+
2768+
# Normalize differences
2769+
s = self.normalize(s)
2770+
2771+
# Gaussian weighting function
2772+
s = -(
2773+
(np.exp(-beta * s, dtype="float64")) + 1.0e-6
2774+
) # --> This epsilon changes results drastically default: 1.e-6
2775+
2776+
# Create Laplacian, diagonal missing
2777+
L = csc_matrix((s, (i, j)))
2778+
2779+
# Reset diagonal weights to zero for summing
2780+
# up the weighted edge degree in the next step
2781+
L.setdiag(0)
2782+
2783+
# Weighted edge degree
2784+
D = np.abs(L.sum(axis=0).A)[0]
2785+
2786+
# Finalize Laplacian by completing the diagonal
2787+
L.setdiag(D)
2788+
2789+
return L
2790+
2791+
def confidence_estimation(self, A, seeds, labels, beta, gamma):
2792+
"""Compute confidence map
2793+
2794+
Args:
2795+
A (np.ndarray): Processed image.
2796+
seeds (np.ndarray): Seeds for the random walks framework. These are indices of the source and sink nodes.
2797+
labels (np.ndarray): Labels for the random walks framework. These represent the classes or groups of the seeds.
2798+
beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function.
2799+
gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.
2800+
2801+
Returns:
2802+
map: Confidence map which shows the probability of each pixel belonging to the source or sink group.
2803+
"""
2804+
2805+
# Index matrix with boundary padding
2806+
G = np.arange(1, A.shape[0] * A.shape[1] + 1).reshape(A.shape[1], A.shape[0]).T
2807+
pad = 1
2808+
2809+
G = np.pad(G, (pad, pad), "constant", constant_values=(0, 0))
2810+
B = np.pad(A, (pad, pad), "constant", constant_values=(0, 0))
2811+
2812+
# Laplacian
2813+
D = self.confidence_laplacian(G, B, beta, gamma)
2814+
2815+
# Select marked columns from Laplacian to create L_M and B^T
2816+
B = D[:, seeds]
2817+
2818+
# Select marked nodes to create B^T
2819+
N = np.sum(G > 0).item()
2820+
i_U = np.setdiff1d(np.arange(N), seeds.astype(int)) # Index of unmarked nodes
2821+
B = B[i_U, :]
2822+
2823+
# Remove marked nodes from Laplacian by deleting rows and cols
2824+
keep_indices = np.setdiff1d(np.arange(D.shape[0]), seeds)
2825+
D = csc_matrix(D[keep_indices, :][:, keep_indices])
2826+
2827+
# Define M matrix
2828+
M = np.zeros((seeds.shape[0], 1), dtype="float64")
2829+
M[:, 0] = labels == 1
2830+
2831+
# Right-handside (-B^T*M)
2832+
rhs = -B @ M # type: ignore
2833+
2834+
# Solve system exactly
2835+
x = self.oc.mldivide(D, rhs)[:, 0]
2836+
2837+
# Prepare output
2838+
probabilities = np.zeros((N,), dtype="float64")
2839+
# Probabilities for unmarked nodes
2840+
probabilities[i_U] = x
2841+
# Max probability for marked node
2842+
probabilities[seeds[labels == 1].astype(int)] = 1.0
2843+
2844+
# Final reshape with same size as input image (no padding)
2845+
probabilities = probabilities.reshape((A.shape[1], A.shape[0])).T
2846+
2847+
return probabilities
2848+
2849+
def __call__(self, data: np.ndarray, downsample=None) -> np.ndarray:
2850+
"""Compute the confidence map
2851+
2852+
Args:
2853+
data (np.ndarray): RF ultrasound data (one scanline per column)
2854+
2855+
Returns:
2856+
map (np.ndarray): Confidence map
2857+
"""
2858+
2859+
# Normalize data
2860+
data = data.astype("float64")
2861+
data = self.normalize(data)
2862+
2863+
if self.mode == "RF":
2864+
# MATLAB hilbert applies the Hilbert transform to columns
2865+
data = np.abs(hilbert(data, axis=0)).astype("float64") # type: ignore
2866+
2867+
org_H, org_W = data.shape
2868+
if downsample is not None:
2869+
data = cv2.resize(data, (org_W // downsample, org_H // downsample), interpolation=cv2.INTER_CUBIC)
2870+
2871+
seeds, labels = self.get_seed_and_labels(data)
2872+
2873+
# Attenuation with Beer-Lambert
2874+
W = self.attenuation_weighting(data, self.alpha)
2875+
2876+
# Apply weighting directly to image
2877+
# Same as applying it individually during the formation of the
2878+
# Laplacian
2879+
data = data * W
2880+
2881+
# Find condidence values
2882+
map_ = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma)
2883+
2884+
if downsample is not None:
2885+
map_ = cv2.resize(map_, (org_W, org_H), interpolation=cv2.INTER_CUBIC)
2886+
2887+
return map_

0 commit comments

Comments
 (0)