Skip to content

Add parse blocks to allow setting values per block #2032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: sd3
Choose a base branch
from

Conversation

rockerBOO
Copy link
Contributor

@rockerBOO rockerBOO commented Apr 5, 2025

This could allow setting values per block. Handles a lot of variations.

I tried to implement this but it required a lot of changes. So providing the function to parse the block values.

For example would handle all these options for values:

1:1.0,2:0.5,10-18:0.4
1-18:0.4
1.0,0:0.5
[0.0,0.0,1.0,1.0,0.9,0.8,0.6,0.0,0.0,1.0,1.0,0.9,0.8,0.6,0.0,0.0,1.0,1.0,0.9]
1.0
1-18:cos
0-10:cos(0,1.0),10-18:sin
1-10:sin
1:0.4,10-18:linear
5-15:reverse_linear

Then create a list of values for each block. Would be similar to behavior in networks.lora but be more flexible in what sort of values it handles. Could work for block alpha, dim/rank, dropout (neuron, rank, module), block LR, LoRA+ LR.

Some issues with integration is changing a single value to a list of values, but also handling non-block values like embedders, which would require more indices or a different way to handle it.

Some of the math for the functions may not be accurate and we can add further tests for those cases.

Pattern Pattern
Pattern_0-10_cos_0_1_0__10-18_sin Pattern_1-18_0_4
Pattern__0_0_0_0_1_0_1_0_0_9_0_8_0_6_0_0_0_0_1_0_1_0_0_9_0_8_0_6_0_0_0_0_1_0_1_0_0_9_ Pattern_1-18_cos

Plotting the blocks:

import matplotlib.pyplot as plt

def plot_blocks(input_str, length=19, title=None, save_dir="plots"):
    """
    Parse the input string and plot the resulting values on a line graph.

    Args:
        input_str (str): The input string after the '=' sign
        length (int): The desired length of the output list (default: 19)
        title (str): Optional title for the plot
    """
    # Parse the input string
    values = parse_blocks(input_str, length)

    # Create the plot
    plt.figure(figsize=(10, 6))
    plt.plot(range(length), values, marker="o", linestyle="-", markersize=5)

    # Add labels and title
    plt.xlabel("Index")
    plt.ylabel("Value")
    if title:
        plt.title(title)
    else:
        plt.title(f"Pattern: {input_str}")

    # Add grid for better readability
    plt.grid(True, linestyle="--", alpha=0.7)

    # Set y-axis limits
    plt.ylim(-0.1, 1.1)

    # Mark the indices on the x-axis
    plt.xticks(range(0, length, 1))

    # Create save directory if it doesn't exist
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)
    
    # Generate a filename from the title or input string
    if title:
        filename = title
    else:
        filename = f'Pattern_{input_str}'
    
    # Clean the filename to remove invalid characters
    filename = re.sub(r'[^\w\-_]', '_', filename)
    
    # Save the plot
    filepath = save_path / f"{filename}.png"
    plt.savefig(filepath, dpi=300, bbox_inches='tight')
    print(f"Plot saved to: {filepath}")

    # Show the plot
    plt.tight_layout()
    plt.show()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant