Skip to content

Commit 8693ada

Browse files
committed
Add alpha parameter to DiceFocalLoss
1 parent 4029c42 commit 8693ada

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

monai/losses/dice.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,7 @@ class DiceFocalLoss(_Loss):
811811
The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
812812
813813
``gamma`` and ``lambda_focal`` are only used for the focal loss.
814-
``include_background``, ``weight`` and ``reduction`` are used for both losses
814+
``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses,
815815
and other parameters are only used for dice loss.
816816
817817
"""
@@ -835,6 +835,7 @@ def __init__(
835835
gamma: float = 2.0,
836836
focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
837837
weight: Sequence[float] | float | int | torch.Tensor | None = None,
838+
alpha: float | None = None,
838839
lambda_dice: float = 1.0,
839840
lambda_focal: float = 1.0,
840841
) -> None:
@@ -867,6 +868,7 @@ def __init__(
867868
weight: weights to apply to the voxels of each class. If None no weights are applied.
868869
The input can be a single value (same weight for all classes), a sequence of values (the length
869870
of the sequence should be the same as the number of classes).
871+
alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in [0, 1]. Defaults to None.
870872
lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
871873
Defaults to 1.0.
872874
lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
@@ -890,7 +892,12 @@ def __init__(
890892
weight=weight,
891893
)
892894
self.focal = FocalLoss(
893-
include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction
895+
include_background=include_background,
896+
to_onehot_y=False,
897+
gamma=gamma,
898+
weight=weight,
899+
alpha=alpha,
900+
reduction=reduction
894901
)
895902
if lambda_dice < 0.0:
896903
raise ValueError("lambda_dice should be no less than 0.0.")

0 commit comments

Comments
 (0)