|
17 | 17 | from abc import abstractmethod
|
18 | 18 | from collections.abc import Callable, Iterable, Sequence
|
19 | 19 | from functools import partial
|
20 |
| -from typing import Any |
| 20 | +from typing import Any, Tuple, Literal |
21 | 21 | from warnings import warn
|
22 | 22 |
|
23 | 23 | import numpy as np
|
|
37 | 37 | from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype
|
38 | 38 |
|
39 | 39 | skimage, _ = optional_import("skimage", "0.19.0", min_version)
|
| 40 | +cv2, _ = optional_import("cv2") |
| 41 | +Oct2Py, _ = optional_import("oct2py", "5.6.0", min_version, "Oct2Py") |
| 42 | +csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix") |
| 43 | +hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert") |
40 | 44 |
|
41 | 45 | __all__ = [
|
42 | 46 | "RandGaussianNoise",
|
|
77 | 81 | "RandIntensityRemap",
|
78 | 82 | "ForegroundMask",
|
79 | 83 | "ComputeHoVerMaps",
|
| 84 | + "UltrasoundConfidenceMap", |
80 | 85 | ]
|
81 | 86 |
|
82 | 87 |
|
@@ -2577,3 +2582,306 @@ def __call__(self, mask: NdarrayOrTensor):
|
2577 | 2582 |
|
2578 | 2583 | hv_maps = convert_to_tensor(np.concatenate([h_map, v_map]), track_meta=get_track_meta())
|
2579 | 2584 | return hv_maps
|
| 2585 | + |
| 2586 | +class UltrasoundConfidenceMap(Transform): |
| 2587 | + """Compute confidence map from an ultrasound image. |
| 2588 | + This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005. |
| 2589 | + It generates a confidence map by setting source and sink points in the image and computing the probability |
| 2590 | + for random walks to reach the source for each pixel. |
| 2591 | +
|
| 2592 | + Args: |
| 2593 | + alpha (float, optional): Alpha parameter. Defaults to 2.0. |
| 2594 | + beta (float, optional): Beta parameter. Defaults to 90.0. |
| 2595 | + gamma (float, optional): Gamma parameter. Defaults to 0.05. |
| 2596 | + mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'. |
| 2597 | + """ |
| 2598 | + |
| 2599 | + def __init__( |
| 2600 | + self, |
| 2601 | + alpha: float = 2.0, |
| 2602 | + beta: float = 90.0, |
| 2603 | + gamma: float = 0.05, |
| 2604 | + mode: Literal["RF", "B"] = "B", |
| 2605 | + ): |
| 2606 | + |
| 2607 | + self.alpha = alpha |
| 2608 | + self.beta = beta |
| 2609 | + self.gamma = gamma |
| 2610 | + self.mode = mode |
| 2611 | + |
| 2612 | + # The precision to use for all computations |
| 2613 | + self.eps = np.finfo("float64").eps |
| 2614 | + |
| 2615 | + # Octave instance for computing the confidence map |
| 2616 | + self.oc = Oct2Py() |
| 2617 | + |
| 2618 | + def sub2ind(self, size: Tuple[int], rows: np.ndarray, cols: np.ndarray) -> np.ndarray: |
| 2619 | + """Converts row and column subscripts into linear indices, |
| 2620 | + basically the copy of the MATLAB function of the same name. |
| 2621 | + https://www.mathworks.com/help/matlab/ref/sub2ind.html |
| 2622 | +
|
| 2623 | + This function is Pythonic so the indices start at 0. |
| 2624 | +
|
| 2625 | + Args: |
| 2626 | + size Tuple[int]: Size of the matrix |
| 2627 | + rows (np.ndarray): Row indices |
| 2628 | + cols (np.ndarray): Column indices |
| 2629 | +
|
| 2630 | + Returns: |
| 2631 | + indices (np.ndarray): 1-D array of linear indices |
| 2632 | + """ |
| 2633 | + indices = rows + cols * size[0] |
| 2634 | + return indices |
| 2635 | + |
| 2636 | + def get_seed_and_labels(self, data : np.ndarray) -> Tuple[np.ndarray, np.ndarray]: |
| 2637 | + """Get the seed and label arrays for the max-flow algorithm |
| 2638 | +
|
| 2639 | + Args: |
| 2640 | + data: Input array |
| 2641 | +
|
| 2642 | + Returns: |
| 2643 | + Tuple[np.ndarray, np.ndarray]: Seed and label arrays |
| 2644 | + """ |
| 2645 | + |
| 2646 | + # Seeds and labels (boundary conditions) |
| 2647 | + seeds = np.array([], dtype="float64") |
| 2648 | + labels = np.array([], dtype="float64") |
| 2649 | + |
| 2650 | + # Indices for all columns |
| 2651 | + sc = np.arange(data.shape[1], dtype="float64") |
| 2652 | + |
| 2653 | + # SOURCE ELEMENTS - 1st matrix row |
| 2654 | + # Indices for 1st row, it will be broadcasted with sc |
| 2655 | + sr_up = np.array([0]) |
| 2656 | + seed = self.sub2ind(data.shape, sr_up, sc).astype("float64") |
| 2657 | + seed = np.unique(seed) |
| 2658 | + seeds = np.concatenate((seeds, seed)) |
| 2659 | + |
| 2660 | + # Label 1 |
| 2661 | + label = np.ones_like(seed) |
| 2662 | + labels = np.concatenate((labels, label)) |
| 2663 | + |
| 2664 | + # SINK ELEMENTS - last image row |
| 2665 | + sr_down = np.ones_like(sc) * (data.shape[0] - 1) |
| 2666 | + seed = self.sub2ind(data.shape, sr_down, sc).astype("float64") |
| 2667 | + |
| 2668 | + seed = np.unique(seed) |
| 2669 | + seeds = np.concatenate((seeds, seed)) |
| 2670 | + |
| 2671 | + # Label 2 |
| 2672 | + label = np.ones_like(seed) * 2 |
| 2673 | + labels = np.concatenate((labels, label)) |
| 2674 | + |
| 2675 | + return seeds, labels |
| 2676 | + |
| 2677 | + def normalize(self, inp: np.ndarray) -> np.ndarray: |
| 2678 | + """Normalize an array to [0, 1]""" |
| 2679 | + return (inp - np.min(inp)) / (np.ptp(inp) + self.eps) |
| 2680 | + |
| 2681 | + def attenuation_weighting(self, A: np.ndarray, alpha: float) -> np.ndarray: |
| 2682 | + """Compute attenuation weighting |
| 2683 | +
|
| 2684 | + Args: |
| 2685 | + A (np.ndarray): Image |
| 2686 | + alpha: Attenuation coefficient (see publication) |
| 2687 | +
|
| 2688 | + Returns: |
| 2689 | + W (np.ndarray): Weighting expressing depth-dependent attenuation |
| 2690 | + """ |
| 2691 | + |
| 2692 | + # Create depth vector and repeat it for each column |
| 2693 | + Dw = np.linspace(0, 1, A.shape[0], dtype="float64") |
| 2694 | + Dw = np.tile(Dw.reshape(-1, 1), (1, A.shape[1])) |
| 2695 | + |
| 2696 | + W = 1.0 - np.exp(-alpha * Dw) # Compute exp inline |
| 2697 | + |
| 2698 | + return W |
| 2699 | + |
| 2700 | + def confidence_laplacian( |
| 2701 | + self, P: np.ndarray, A: np.ndarray, beta: float, gamma: float |
| 2702 | + ) -> csc_matrix: # type: ignore |
| 2703 | + """Compute 6-Connected Laplacian for confidence estimation problem |
| 2704 | +
|
| 2705 | + Args: |
| 2706 | + P (np.ndarray): The index matrix of the image with boundary padding. |
| 2707 | + A (np.ndarray): The padded image. |
| 2708 | + beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function. |
| 2709 | + gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian. |
| 2710 | +
|
| 2711 | + Returns: |
| 2712 | + L (csc_matrix): The 6-connected Laplacian matrix used for confidence map estimation. |
| 2713 | + """ |
| 2714 | + |
| 2715 | + m, _ = P.shape |
| 2716 | + |
| 2717 | + P = P.T.flatten() |
| 2718 | + A = A.T.flatten() |
| 2719 | + |
| 2720 | + p = np.where(P > 0)[0] |
| 2721 | + |
| 2722 | + i = P[p] - 1 # Index vector |
| 2723 | + j = P[p] - 1 # Index vector |
| 2724 | + # Entries vector, initially for diagonal |
| 2725 | + s = np.zeros_like(p, dtype="float64") |
| 2726 | + |
| 2727 | + vl = 0 # Vertical edges length |
| 2728 | + |
| 2729 | + edge_templates = [ |
| 2730 | + -1, # Vertical edges |
| 2731 | + 1, |
| 2732 | + m - 1, # Diagonal edges |
| 2733 | + m + 1, |
| 2734 | + -m - 1, |
| 2735 | + -m + 1, |
| 2736 | + m, # Horizontal edges |
| 2737 | + -m, |
| 2738 | + ] |
| 2739 | + |
| 2740 | + vertical_end = None |
| 2741 | + diagonal_end = None |
| 2742 | + |
| 2743 | + for iter_idx, k in enumerate(edge_templates): |
| 2744 | + |
| 2745 | + Q = P[p + k] |
| 2746 | + |
| 2747 | + q = np.where(Q > 0)[0] |
| 2748 | + |
| 2749 | + ii = P[p[q]] - 1 |
| 2750 | + i = np.concatenate((i, ii)) |
| 2751 | + jj = Q[q] - 1 |
| 2752 | + j = np.concatenate((j, jj)) |
| 2753 | + W = np.abs(A[p[ii]] - A[p[jj]]) # Intensity derived weight |
| 2754 | + s = np.concatenate((s, W)) |
| 2755 | + |
| 2756 | + if iter_idx == 1: |
| 2757 | + vertical_end = s.shape[0] # Vertical edges length |
| 2758 | + elif iter_idx == 5: |
| 2759 | + diagonal_end = s.shape[0] # Diagonal edges length |
| 2760 | + |
| 2761 | + # Normalize weights |
| 2762 | + s = self.normalize(s) |
| 2763 | + |
| 2764 | + # Horizontal penalty |
| 2765 | + s[:vertical_end] += gamma |
| 2766 | + #s[vertical_end:diagonal_end] += gamma * np.sqrt(2) # --> In the paper it is sqrt(2) since the diagonal edges are longer yet does not exist in the original code |
| 2767 | + |
| 2768 | + # Normalize differences |
| 2769 | + s = self.normalize(s) |
| 2770 | + |
| 2771 | + # Gaussian weighting function |
| 2772 | + s = -( |
| 2773 | + (np.exp(-beta * s, dtype="float64")) + 1.0e-6 |
| 2774 | + ) # --> This epsilon changes results drastically default: 1.e-6 |
| 2775 | + |
| 2776 | + # Create Laplacian, diagonal missing |
| 2777 | + L = csc_matrix((s, (i, j))) |
| 2778 | + |
| 2779 | + # Reset diagonal weights to zero for summing |
| 2780 | + # up the weighted edge degree in the next step |
| 2781 | + L.setdiag(0) |
| 2782 | + |
| 2783 | + # Weighted edge degree |
| 2784 | + D = np.abs(L.sum(axis=0).A)[0] |
| 2785 | + |
| 2786 | + # Finalize Laplacian by completing the diagonal |
| 2787 | + L.setdiag(D) |
| 2788 | + |
| 2789 | + return L |
| 2790 | + |
| 2791 | + def confidence_estimation(self, A, seeds, labels, beta, gamma): |
| 2792 | + """Compute confidence map |
| 2793 | +
|
| 2794 | + Args: |
| 2795 | + A (np.ndarray): Processed image. |
| 2796 | + seeds (np.ndarray): Seeds for the random walks framework. These are indices of the source and sink nodes. |
| 2797 | + labels (np.ndarray): Labels for the random walks framework. These represent the classes or groups of the seeds. |
| 2798 | + beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function. |
| 2799 | + gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian. |
| 2800 | +
|
| 2801 | + Returns: |
| 2802 | + map: Confidence map which shows the probability of each pixel belonging to the source or sink group. |
| 2803 | + """ |
| 2804 | + |
| 2805 | + # Index matrix with boundary padding |
| 2806 | + G = np.arange(1, A.shape[0] * A.shape[1] + 1).reshape(A.shape[1], A.shape[0]).T |
| 2807 | + pad = 1 |
| 2808 | + |
| 2809 | + G = np.pad(G, (pad, pad), "constant", constant_values=(0, 0)) |
| 2810 | + B = np.pad(A, (pad, pad), "constant", constant_values=(0, 0)) |
| 2811 | + |
| 2812 | + # Laplacian |
| 2813 | + D = self.confidence_laplacian(G, B, beta, gamma) |
| 2814 | + |
| 2815 | + # Select marked columns from Laplacian to create L_M and B^T |
| 2816 | + B = D[:, seeds] |
| 2817 | + |
| 2818 | + # Select marked nodes to create B^T |
| 2819 | + N = np.sum(G > 0).item() |
| 2820 | + i_U = np.setdiff1d(np.arange(N), seeds.astype(int)) # Index of unmarked nodes |
| 2821 | + B = B[i_U, :] |
| 2822 | + |
| 2823 | + # Remove marked nodes from Laplacian by deleting rows and cols |
| 2824 | + keep_indices = np.setdiff1d(np.arange(D.shape[0]), seeds) |
| 2825 | + D = csc_matrix(D[keep_indices, :][:, keep_indices]) |
| 2826 | + |
| 2827 | + # Define M matrix |
| 2828 | + M = np.zeros((seeds.shape[0], 1), dtype="float64") |
| 2829 | + M[:, 0] = labels == 1 |
| 2830 | + |
| 2831 | + # Right-handside (-B^T*M) |
| 2832 | + rhs = -B @ M # type: ignore |
| 2833 | + |
| 2834 | + # Solve system exactly |
| 2835 | + x = self.oc.mldivide(D, rhs)[:, 0] |
| 2836 | + |
| 2837 | + # Prepare output |
| 2838 | + probabilities = np.zeros((N,), dtype="float64") |
| 2839 | + # Probabilities for unmarked nodes |
| 2840 | + probabilities[i_U] = x |
| 2841 | + # Max probability for marked node |
| 2842 | + probabilities[seeds[labels == 1].astype(int)] = 1.0 |
| 2843 | + |
| 2844 | + # Final reshape with same size as input image (no padding) |
| 2845 | + probabilities = probabilities.reshape((A.shape[1], A.shape[0])).T |
| 2846 | + |
| 2847 | + return probabilities |
| 2848 | + |
| 2849 | + def __call__(self, data: np.ndarray, downsample=None) -> np.ndarray: |
| 2850 | + """Compute the confidence map |
| 2851 | +
|
| 2852 | + Args: |
| 2853 | + data (np.ndarray): RF ultrasound data (one scanline per column) |
| 2854 | +
|
| 2855 | + Returns: |
| 2856 | + map (np.ndarray): Confidence map |
| 2857 | + """ |
| 2858 | + |
| 2859 | + # Normalize data |
| 2860 | + data = data.astype("float64") |
| 2861 | + data = self.normalize(data) |
| 2862 | + |
| 2863 | + if self.mode == "RF": |
| 2864 | + # MATLAB hilbert applies the Hilbert transform to columns |
| 2865 | + data = np.abs(hilbert(data, axis=0)).astype("float64") # type: ignore |
| 2866 | + |
| 2867 | + org_H, org_W = data.shape |
| 2868 | + if downsample is not None: |
| 2869 | + data = cv2.resize(data, (org_W // downsample, org_H // downsample), interpolation=cv2.INTER_CUBIC) |
| 2870 | + |
| 2871 | + seeds, labels = self.get_seed_and_labels(data) |
| 2872 | + |
| 2873 | + # Attenuation with Beer-Lambert |
| 2874 | + W = self.attenuation_weighting(data, self.alpha) |
| 2875 | + |
| 2876 | + # Apply weighting directly to image |
| 2877 | + # Same as applying it individually during the formation of the |
| 2878 | + # Laplacian |
| 2879 | + data = data * W |
| 2880 | + |
| 2881 | + # Find condidence values |
| 2882 | + map_ = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma) |
| 2883 | + |
| 2884 | + if downsample is not None: |
| 2885 | + map_ = cv2.resize(map_, (org_W, org_H), interpolation=cv2.INTER_CUBIC) |
| 2886 | + |
| 2887 | + return map_ |
0 commit comments