@@ -811,7 +811,7 @@ class DiceFocalLoss(_Loss):
811
811
The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
812
812
813
813
``gamma`` and ``lambda_focal`` are only used for the focal loss.
814
- ``include_background``, ``weight`` and ``reduction `` are used for both losses
814
+ ``include_background``, ``weight``, ``reduction``, and ``alpha `` are used for both losses,
815
815
and other parameters are only used for dice loss.
816
816
817
817
"""
@@ -835,6 +835,7 @@ def __init__(
835
835
gamma : float = 2.0 ,
836
836
focal_weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
837
837
weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
838
+ alpha : float | None = None ,
838
839
lambda_dice : float = 1.0 ,
839
840
lambda_focal : float = 1.0 ,
840
841
) -> None :
@@ -867,6 +868,7 @@ def __init__(
867
868
weight: weights to apply to the voxels of each class. If None no weights are applied.
868
869
The input can be a single value (same weight for all classes), a sequence of values (the length
869
870
of the sequence should be the same as the number of classes).
871
+ alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in [0, 1]. Defaults to None.
870
872
lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
871
873
Defaults to 1.0.
872
874
lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
@@ -890,7 +892,12 @@ def __init__(
890
892
weight = weight ,
891
893
)
892
894
self .focal = FocalLoss (
893
- include_background = include_background , to_onehot_y = False , gamma = gamma , weight = weight , reduction = reduction
895
+ include_background = include_background ,
896
+ to_onehot_y = False ,
897
+ gamma = gamma ,
898
+ weight = weight ,
899
+ alpha = alpha ,
900
+ reduction = reduction
894
901
)
895
902
if lambda_dice < 0.0 :
896
903
raise ValueError ("lambda_dice should be no less than 0.0." )
0 commit comments