Skip to content

Commit c7e0a65

Browse files
fix: Implement log streaming for extract_lora_from_models_new_gui
This commit resolves the AttributeError related to a non-existent `custom_logging.LogStreaming` class in `kohya_gui/extract_lora_from_models_new_gui.py`. The fix includes: - Modifying the `extract_lora_new` function to become a generator. - Capturing `stdout` and `stderr` directly from the `subprocess.Popen` instance that runs the `extract_lora_from_models-new.py` script. - Iterating over the output lines from the subprocess and yielding them to update a `gr.Textbox` in the UI, providing real-time log feedback. - Ensuring the Gradio button's `.click()` event is configured to pipe the yielded output to the designated log Textbox.
1 parent 3aefbdf commit c7e0a65

File tree

1 file changed

+43
-47
lines changed

1 file changed

+43
-47
lines changed

kohya_gui/extract_lora_from_models_new_gui.py

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import subprocess
33
import os
44
import sys
5-
from kohya_gui import common_gui, custom_logging
5+
from kohya_gui import common_gui
66

77
# Define the script directory
88
scriptdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -30,28 +30,6 @@ def extract_lora_new(
3030
verbose,
3131
no_metadata,
3232
):
33-
# This function will be implemented in a subsequent step
34-
print("Extract LoRA button clicked. Functionality to be implemented.")
35-
print(f"Model Tuned: {model_tuned}")
36-
print(f"Model Original: {model_org}")
37-
print(f"Save To: {save_to}")
38-
print(f"Save Precision: {save_precision}")
39-
print(f"Load Precision: {load_precision}")
40-
print(f"Dimension: {dim}")
41-
print(f"Conv Dimension: {conv_dim}")
42-
print(f"Device: {device}")
43-
print(f"SDXL: {sdxl}")
44-
print(f"v2: {v2}")
45-
print(f"v_parameterization: {v_parameterization}")
46-
print(f"Clamp Quantile: {clamp_quantile}")
47-
print(f"Min Diff: {min_diff}")
48-
print(f"Load Original Model To: {load_original_model_to}")
49-
print(f"Load Tuned Model To: {load_tuned_model_to}")
50-
print(f"Dynamic Method: {dynamic_method}")
51-
print(f"Dynamic Param: {dynamic_param}")
52-
print(f"Verbose: {verbose}")
53-
print(f"No Metadata: {no_metadata}")
54-
5533
# Construct the command
5634
command = [
5735
PYTHON,
@@ -67,13 +45,9 @@ def extract_lora_new(
6745
command.extend(["--save_to", save_to])
6846
if save_precision:
6947
command.extend(["--save_precision", save_precision])
70-
if load_precision and load_precision != "None": # Handle None case
48+
if load_precision and load_precision != "None":
7149
command.extend(["--load_precision", load_precision])
7250
command.extend(["--dim", str(dim)])
73-
# conv_dim defaults to dim if not provided or 0 in the script,
74-
# so we only add it if it's explicitly set to a non-zero value by the user that is different from dim,
75-
# or if the script requires it even if it's the same as dim (need to check script logic)
76-
# For now, pass it if it's > 0. The script itself handles the default.
7751
if conv_dim > 0:
7852
command.extend(["--conv_dim", str(conv_dim)])
7953
if device:
@@ -82,7 +56,7 @@ def extract_lora_new(
8256
command.append("--sdxl")
8357
if v2:
8458
command.append("--v2")
85-
if v_parameterization: # Only relevant if v2 is true, but script might handle it
59+
if v_parameterization:
8660
command.append("--v_parameterization")
8761
command.extend(["--clamp_quantile", str(clamp_quantile)])
8862
command.extend(["--min_diff", str(min_diff)])
@@ -95,9 +69,6 @@ def extract_lora_new(
9569

9670
if dynamic_method and dynamic_method != "None":
9771
command.extend(["--dynamic_method", dynamic_method])
98-
# dynamic_param is only needed for certain methods, script should handle if it's missing
99-
# but we should only pass it if a method requiring it is selected.
100-
# This requires knowing which methods need params. Assuming for now all non-"None" methods might use it if provided.
10172
command.extend(["--dynamic_param", str(dynamic_param)])
10273

10374
if verbose:
@@ -106,36 +77,61 @@ def extract_lora_new(
10677
command.append("--no_metadata")
10778

10879
# Run the script
109-
print(f"Running command: {' '.join(command)}")
110-
111-
log_stream = custom_logging.LogStreaming()
80+
print(f"Running command: {' '.join(command)}") # Log to console
11281

82+
all_logs = ""
11383
try:
84+
# Use Popen to capture stdout and stderr
11485
process = subprocess.Popen(
11586
command,
11687
stdout=subprocess.PIPE,
117-
stderr=subprocess.STDOUT,
88+
stderr=subprocess.PIPE, # Capture stderr separately
11889
text=True,
11990
bufsize=1,
12091
universal_newlines=True,
12192
)
12293

123-
for line in iter(process.stdout.readline, ''):
124-
log_stream.log(line.strip())
125-
# Optionally print to console during development/debugging
126-
# print(line.strip())
127-
process.wait()
94+
# Stream stdout
95+
if process.stdout:
96+
for line in iter(process.stdout.readline, ''):
97+
line = line.strip()
98+
if line:
99+
print(line) # Log to console
100+
all_logs += line + '\n'
101+
yield all_logs # Yield accumulated logs
128102

103+
# Stream stderr
104+
# After stdout is exhausted, check stderr
105+
stderr_output = ""
106+
if process.stderr:
107+
for line in iter(process.stderr.readline, ''):
108+
line = line.strip()
109+
if line:
110+
print(f"Error: {line}") # Log to console
111+
stderr_output += f"ERROR: {line}\n"
112+
113+
process.wait() # Wait for the process to complete
114+
115+
# Append any stderr output to all_logs after stdout and process completion
116+
if stderr_output:
117+
all_logs += "\n--- Errors/Warnings ---\n" + stderr_output
118+
yield all_logs
119+
129120
if process.returncode == 0:
130-
return "LoRA extraction completed successfully.", log_stream.get_logs()
121+
all_logs += "\nLoRA extraction completed successfully."
131122
else:
132-
return f"Error during LoRA extraction. Return code: {process.returncode}", log_stream.get_logs()
123+
all_logs += f"\nError during LoRA extraction. Return code: {process.returncode}"
124+
125+
yield all_logs
133126

134127
except Exception as e:
135-
return f"Failed to run script: {e}", log_stream.get_logs()
128+
all_logs += f"\nFailed to run script: {e}"
129+
yield all_logs
136130
finally:
137-
log_stream.close()
138-
131+
if process.stdout:
132+
process.stdout.close()
133+
if process.stderr:
134+
process.stderr.close()
139135

140136
# Gradio UI function
141137
def gradio_extract_lora_new_tab(headless=False):
@@ -364,7 +360,7 @@ def gradio_extract_lora_new_tab(headless=False):
364360
verbose,
365361
no_metadata,
366362
],
367-
outputs=[gr.Textbox(label="Status", interactive=False), output_logs], # Two outputs: status and logs
363+
outputs=[output_logs], # Output to the log textbox
368364
show_progress="full"
369365
)
370366

0 commit comments

Comments
 (0)