diff --git a/python_coreml_stable_diffusion/controlnet.py b/python_coreml_stable_diffusion/controlnet.py index d13c13f6..28fa9e8c 100644 --- a/python_coreml_stable_diffusion/controlnet.py +++ b/python_coreml_stable_diffusion/controlnet.py @@ -12,41 +12,42 @@ from .unet import Timesteps, TimestepEmbedding, get_down_block, UNetMidBlock2DCrossAttn, linear_to_conv2d_map + class ControlNetConditioningEmbedding(nn.Module): + """ + Embeds conditioning input into a feature space suitable for ControlNet. + """ - def __init__( - self, - conditioning_embedding_channels, - conditioning_channels=3, - block_out_channels=(16, 32, 96, 256), - ): + def __init__(self, conditioning_embedding_channels, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)): super().__init__() - + # Initial convolution self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + # Convolutional blocks for progressive embedding + self.blocks = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + if i % 2 == 0 + else nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2) + for i, (in_channels, out_channels) in enumerate(zip(block_out_channels[:-1], block_out_channels[1:])) + ] + ) + # Final embedding convolution self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - + # Process the conditioning input through the embedding layers + embedding = F.silu(self.conv_in(conditioning)) for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) + embedding = F.silu(block(embedding)) + return self.conv_out(embedding) - return embedding class ControlNetModel(ModelMixin, ConfigMixin): + """ + Implements a ControlNet model with flexible configuration for conditioning, downsampling, and cross-attention blocks. + """ @register_to_config def __init__( @@ -54,12 +55,7 @@ def __init__( in_channels=4, flip_sin_to_cos=True, freq_shift=0, - down_block_types=( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), only_cross_attention=False, block_out_channels=(320, 640, 1280, 1280), layers_per_block=2, @@ -79,66 +75,42 @@ def __init__( ): super().__init__() - # Check inputs + # Validate inputs if len(block_out_channels) != len(down_block_types): raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + f"`block_out_channels` length must match `down_block_types` length. Received {len(block_out_channels)} and {len(down_block_types)}." ) - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) + # Convert scalar parameters into lists if needed + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # Register pre-hook for state dict mapping self._register_load_state_dict_pre_hook(linear_to_conv2d_map) - # input - conv_in_kernel = 3 - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) + # Initial convolution + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1) - # time + # Time embedding time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - ) + self.time_embedding = TimestepEmbedding(block_out_channels[0], time_embed_dim) - # control net conditioning embedding + # ControlNet conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, ) - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) + # Down blocks + self.down_blocks = nn.ModuleList() + self.controlnet_down_blocks = nn.ModuleList([nn.Conv2d(block_out_channels[0], block_out_channels[0], kernel_size=1)]) - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - # down output_channel = block_out_channels[0] - - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - self.controlnet_down_blocks.append(controlnet_block) - for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] @@ -160,22 +132,14 @@ def __init__( ) self.down_blocks.append(down_block) - for _ in range(layers_per_block): - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - self.controlnet_down_blocks.append(controlnet_block) - - if not is_final_block: - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - self.controlnet_down_blocks.append(controlnet_block) - - # mid - mid_block_channel = block_out_channels[-1] - - controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) - self.controlnet_mid_block = controlnet_block + # Add corresponding ControlNet blocks + for _ in range(layers_per_block + (0 if is_final_block else 1)): + self.controlnet_down_blocks.append(nn.Conv2d(output_channel, output_channel, kernel_size=1)) + # Mid block + self.controlnet_mid_block = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], kernel_size=1) self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=mid_block_channel, + in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, @@ -189,62 +153,48 @@ def __init__( ) def get_num_residuals(self): - num_res = 2 # initial sample + mid block + """ + Returns the total number of residual connections. + """ + num_res = 2 # Includes initial sample and mid block for down_block in self.down_blocks: num_res += len(down_block.resnets) if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None: num_res += len(down_block.downsamplers) return num_res - def forward( - self, - sample, - timestep, - encoder_hidden_states, - controlnet_cond, - ): - # 1. time + def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond): + """ + Forward pass through the ControlNet model. + """ + # Time embedding t_emb = self.time_proj(timestep) emb = self.time_embedding(t_emb) - # 2. pre-process + # Input convolution and conditioning sample = self.conv_in(sample) - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample += controlnet_cond - # 3. down + # Down blocks down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, + hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - down_block_res_samples += res_samples - # 4. mid + # Mid block if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - ) + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) - # 5. Control net blocks + # ControlNet-specific processing controlnet_down_block_res_samples = () - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(sample) + controlnet_down_block_res_samples += (controlnet_block(down_block_res_sample),) - return down_block_res_samples, mid_block_res_sample \ No newline at end of file + # Return results + return controlnet_down_block_res_samples, self.controlnet_mid_block(sample) diff --git a/swift/StableDiffusion/tokenizer/BPETokenizer.swift b/swift/StableDiffusion/tokenizer/BPETokenizer.swift index 3f7ed9d2..13c5cc2b 100644 --- a/swift/StableDiffusion/tokenizer/BPETokenizer.swift +++ b/swift/StableDiffusion/tokenizer/BPETokenizer.swift @@ -1,5 +1,9 @@ -// For licensing see accompanying LICENSE.md file. -// Copyright (C) 2022 Apple Inc. All Rights Reserved. +// +// BPETokenizer.swift +// +// For licensing see accompanying LICENSE.md file. +// Copyright (C) 2022 Apple Inc. All Rights Reserved. +// import Foundation @@ -7,12 +11,12 @@ import Foundation @available(iOS 16.2, macOS 13.1, *) public struct BPETokenizer { /// A dictionary that maps pairs of tokens to the rank/order of the merge. - let merges: [TokenPair : Int] + let merges: [TokenPair: Int] - /// A dictionary from of tokens to identifiers. + /// A dictionary from tokens to identifiers. let vocabulary: [String: Int] - /// The token used for padding + /// The token used for padding. let padToken: String /// The start token. @@ -24,6 +28,7 @@ public struct BPETokenizer { /// The unknown token. let unknownToken: String = "<|endoftext|>" + /// The ID of the unknown token, or 0 by default. var unknownTokenID: Int { vocabulary[unknownToken, default: 0] } @@ -32,7 +37,7 @@ public struct BPETokenizer { /// /// - Parameters: /// - merges: A dictionary that maps pairs of tokens to the rank/order of the merge. - /// - vocabulary: A dictionary from of tokens to identifiers. + /// - vocabulary: A dictionary from tokens to identifiers. public init(merges: [TokenPair: Int], vocabulary: [String: Int], padToken: String = "<|endoftext|>") { self.merges = merges self.vocabulary = vocabulary @@ -45,8 +50,9 @@ public struct BPETokenizer { /// - mergesURL: The URL of a text file containing merges. /// - vocabularyURL: The URL of a JSON file containing the vocabulary. public init(mergesAt mergesURL: URL, vocabularyAt vocabularyURL: URL, padToken: String = "<|endoftext|>") throws { + // Improved error handling for file reading self.merges = try Self.readMerges(url: mergesURL) - self.vocabulary = try! Self.readVocabulary(url: vocabularyURL) + self.vocabulary = try Self.readVocabulary(url: vocabularyURL) self.padToken = padToken } @@ -57,18 +63,14 @@ public struct BPETokenizer { /// - minCount: The minimum number of tokens to return. /// - Returns: An array of tokens and an array of token identifiers. public func tokenize(input: String, minCount: Int? = nil) -> (tokens: [String], tokenIDs: [Int]) { - var tokens: [String] = [] + var tokens: [String] = [startToken] + encode(input: input) + [endToken] - tokens.append(startToken) - tokens.append(contentsOf: encode(input: input)) - tokens.append(endToken) - - // Pad if there was a min length specified + // Pad if there was a minimum length specified if let minLen = minCount, minLen > tokens.count { tokens.append(contentsOf: repeatElement(padToken, count: minLen - tokens.count)) } - let ids = tokens.map({ vocabulary[$0, default: unknownTokenID] }) + let ids = tokens.map { vocabulary[$0, default: unknownTokenID] } return (tokens: tokens, tokenIDs: ids) } @@ -82,97 +84,95 @@ public struct BPETokenizer { vocabulary.first(where: { $0.value == id })?.key } - /// Decodes a sequence of tokens into a fully formed string + /// Decodes a sequence of tokens into a fully formed string. public func decode(tokens: [String]) -> String { - String(tokens.joined()) + tokens.joined() .replacingOccurrences(of: "", with: " ") .replacingOccurrences(of: startToken, with: "") .replacingOccurrences(of: endToken, with: "") } - /// Encode an input string to a sequence of tokens + /// Encodes an input string into a sequence of tokens. func encode(input: String) -> [String] { let normalized = input.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() - let words = normalized.split(separator: " ") - return words.flatMap({ encode(word: $0) }) + return normalized.split(separator: " ").flatMap { encode(word: $0) } } - /// Encode a single word into a sequence of tokens + /// Encodes a single word into a sequence of tokens. func encode(word: Substring) -> [String] { var tokens = word.map { String($0) } if let last = tokens.indices.last { - tokens[last] = tokens[last] + "" + tokens[last] += "" } while true { let pairs = pairs(for: tokens) - let canMerge = pairs.filter { merges[$0] != nil } + let canMerge = pairs.compactMap { merges[$0] } if canMerge.isEmpty { break } - // If multiple merges are found, use the one with the lowest rank - let shouldMerge = canMerge.min { merges[$0]! < merges[$1]! }! + // Select the pair with the lowest rank + let shouldMerge = canMerge.min()! tokens = update(tokens, merging: shouldMerge) } return tokens } - /// Get the set of adjacent pairs / bigrams from a sequence of tokens + /// Gets the set of adjacent pairs/bigrams from a sequence of tokens. func pairs(for tokens: [String]) -> Set { - guard tokens.count > 1 else { - return Set() - } - - var pairs = Set(minimumCapacity: tokens.count - 1) - var prev = tokens.first! - for current in tokens.dropFirst() { - pairs.insert(TokenPair(prev, current)) - prev = current - } - return pairs + guard tokens.count > 1 else { return [] } + return Set(zip(tokens, tokens.dropFirst()).map { TokenPair($0.0, $0.1) }) } - /// Update the sequence of tokens by greedily merging instance of a specific bigram + /// Updates the sequence of tokens by greedily merging instances of a specific bigram. func update(_ tokens: [String], merging bigram: TokenPair) -> [String] { - guard tokens.count > 1 else { - return [] - } + guard tokens.count > 1 else { return tokens } var newTokens = [String]() - newTokens.reserveCapacity(tokens.count - 1) - - var index = 0 - while index < tokens.count { - let remainingTokens = tokens[index...] - if let startMatchIndex = remainingTokens.firstIndex(of: bigram.first) { - // Found a possible match, append everything before it - newTokens.append(contentsOf: tokens[index.. [TokenPair: Int] { + let data = try Data(contentsOf: url) + let lines = String(data: data, encoding: .utf8)!.split(separator: "\n") + var merges = [TokenPair: Int]() + for (index, line) in lines.enumerated() { + let tokens = line.split(separator: " ") + if tokens.count == 2 { + merges[TokenPair(String(tokens[0]), String(tokens[1]))] = index + } + } + return merges + } + + /// Reads vocabulary from a file. + static func readVocabulary(url: URL) throws -> [String: Int] { + let data = try Data(contentsOf: url) + return try JSONDecoder().decode([String: Int].self, from: data) + } } @available(iOS 16.2, macOS 13.1, *) extension BPETokenizer { - - /// A hashable tuple of strings + /// A hashable tuple of strings representing a token pair. public struct TokenPair: Hashable { let first: String let second: String