Potential bug in get_position_from_periods when the iteration exceeds all periods (as used by CosineAnnealingRestartLR) #728

Open
@lsyzc

Issue: Bug in get_position_from_periods Function When Iteration Exceeds All Periods

Description

There is a potential bug in get_position_from_periods, the helper function used by the CosineAnnealingRestartLR scheduler. The function is designed to return the index of the right-closest number in the cumulative period list for the current iteration. However, if the iteration exceeds every value in the cumulative period list, no branch inside the loop returns, so the function falls through and returns None instead of an index.

Code Snippet

def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.

    It will return the index of the right-closest number in the period list.
    For example, the cumulative_period = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.
    """
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i
    # No return after the loop: if iteration exceeds every period, the function implicitly returns None.

Problem

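If the iteration value exceeds all values in the cumulative_period list, the condition inside the loop is never true, so the function reaches the end without an explicit return and yields None. Any caller in the learning rate scheduler that uses the result as a list index will then fail with a TypeError.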
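The failure mode can be reproduced with a minimal sketch. The function name `get_position_from_periods_unpatched` and the `restart_weights` values are hypothetical and chosen only for illustration; the loop is the one described in this issue, with no return statement after it.

```python
def get_position_from_periods_unpatched(iteration, cumulative_period):
    # Loop as described in this issue, with no return after it.
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i


cumulative_period = [100, 200, 300, 400]
restart_weights = [1.0, 0.5, 0.5, 0.5]  # hypothetical values for illustration

inside = get_position_from_periods_unpatched(400, cumulative_period)
beyond = get_position_from_periods_unpatched(401, cumulative_period)

try:
    restart_weights[beyond]  # indexing a list with None
    error_name = None
except TypeError as exc:
    error_name = type(exc).__name__

print(inside, beyond, error_name)
```

An iteration equal to the last cumulative period still resolves to the last index; only iterations beyond it produce None.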

Proposed Solution

To handle this case, the function should return the index of the last period whenever the iteration exceeds all values in the cumulative_period list, so the scheduler keeps using the final period's settings. This can be achieved by adding a default return statement after the loop.

Suggested Fix

Add a default return statement to handle cases where the iteration exceeds all periods:

def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.

    It will return the index of the right-closest number in the period list.
    For example, the cumulative_period = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.
    """
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i
    return len(cumulative_period) - 1  # If iteration exceeds all periods, return the last period index

Impact

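This fix ensures that the function always returns a valid index for a non-empty cumulative period list, even when the iteration exceeds all of its values. Iterations past the last period are mapped to the last index instead of producing None, which prevents the failure in the learning rate scheduler.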
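As an alternative sketch (not the fix proposed in this issue), the same clamped lookup can be written with the standard library's `bisect` module. The name `get_position_from_periods_bisect` is hypothetical, and the sketch assumes cumulative_period is sorted in ascending order and non-empty, which should hold for a cumulative sum of positive periods.

```python
from bisect import bisect_left


def get_position_from_periods_bisect(iteration, cumulative_period):
    """Hypothetical bisect-based variant of the clamped lookup.

    Assumes cumulative_period is sorted ascending and non-empty.
    """
    # bisect_left finds the first index whose value is >= iteration,
    # which matches the `iteration <= period` test in the loop version.
    idx = bisect_left(cumulative_period, iteration)
    # Clamp so iterations past the last period map to the last index.
    return min(idx, len(cumulative_period) - 1)


cumulative_period = [100, 200, 300, 400]
positions = [
    get_position_from_periods_bisect(it, cumulative_period)
    for it in (50, 210, 300, 400, 401, 10_000)
]
print(positions)  # [0, 2, 2, 3, 3, 3]
```

The first three results match the docstring examples, and the last two show the clamping behaviour for iterations beyond the final period.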
