diff --git a/css/style.css b/css/style.css
index 010c8e7f6..c832a8537 100644
--- a/css/style.css
+++ b/css/style.css
@@ -197,7 +197,8 @@
display: none;
}
-#stylePreviewOverlay {
+#stylePreviewOverlay,
+#modelPreviewOverlay {
opacity: 0;
pointer-events: none;
width: 128px;
@@ -215,6 +216,20 @@
transition: transform 0.1s ease, opacity 0.3s ease;
}
-#stylePreviewOverlay.lower-half {
+#stylePreviewOverlay.lower-half,
+#modelPreviewOverlay.lower-half {
transform: translate(-140px, -140px);
}
+
+#modelPreviewOverlay {
+ z-index: 10000000000 !important;
+ justify-content: center;
+ display: flex;
+ align-items: center;
+ background-color: rgba(0, 0, 0, 0.7) !important;
+ color: white;
+ padding: 5px;
+ max-width: 100%;
+ overflow-wrap: break-word;
+ word-break: break-all;
+}
\ No newline at end of file
diff --git a/javascript/script.js b/javascript/script.js
index 8f4cac58f..864cedda6 100644
--- a/javascript/script.js
+++ b/javascript/script.js
@@ -120,6 +120,7 @@ document.addEventListener("DOMContentLoaded", function() {
});
mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
initStylePreviewOverlay();
+ initModelPreviewOverlay();
});
/**
@@ -146,38 +147,141 @@ document.addEventListener('keydown', function(e) {
}
});
-function initStylePreviewOverlay() {
- let overlayVisible = false;
- const samplesPath = document.querySelector("meta[name='samples-path']").getAttribute("content")
+// Utility functions
+function formatImagePath(name, templateImagePath, replacedValue = "fooocus_v2") {
+ return templateImagePath.replace(replacedValue, name.toLowerCase().replaceAll(" ", "_")).replaceAll("\\", "\\\\");
+}
+
+function createOverlay(id) {
const overlay = document.createElement('div');
- overlay.id = 'stylePreviewOverlay';
+ overlay.id = id;
document.body.appendChild(overlay);
- document.addEventListener('mouseover', function(e) {
- const label = e.target.closest('.style_selections label');
+ return overlay;
+}
+
+function setImageBackground(overlay, url) {
+ unsetOverlayAsTooltip(overlay)
+ overlay.style.backgroundImage = `url("${url}")`;
+}
+
+function setOverlayAsTooltip(overlay, altText) {
+ // Set the text content and any dynamic styles
+ overlay.textContent = altText;
+ overlay.style.width = 'fit-content';
+ overlay.style.height = 'fit-content';
+ // Note: Other styles are already set via CSS
+}
+
+function unsetOverlayAsTooltip(overlay) {
+ // Clear the text content and reset any dynamic styles
+ overlay.textContent = '';
+ overlay.style.width = '128px';
+ overlay.style.height = '128px';
+ // Note: Other styles are managed via CSS
+}
+
+function handleMouseMove(overlay) {
+ return function(e) {
+ if (overlay.style.opacity !== "1") return;
+ overlay.style.left = `${e.clientX}px`;
+ overlay.style.top = `${e.clientY}px`;
+ overlay.className = e.clientY > window.innerHeight / 2 ? "lower-half" : "upper-half";
+ };
+}
+
+// Image path retrieval for models
+const getModelImagePath = selectedItemText => {
+ selectedItemText = selectedItemText.replace("✓\n", "")
+
+ let imagePath = null;
+
+ if (previewsCheckpoint)
+ imagePath = previewsCheckpoint[selectedItemText]
+
+ if (previewsLora && !imagePath)
+ imagePath = previewsLora[selectedItemText]
+
+ return imagePath;
+};
+
+// Mouse over handlers for different overlays
+function handleMouseOverModelPreviewOverlay(overlay, elementSelector, templateImagePath) {
+ return function(e) {
+ const targetElement = e.target.closest(elementSelector);
+ if (!targetElement) return;
+
+ targetElement.removeEventListener("mouseout", onMouseLeave);
+ targetElement.addEventListener("mouseout", onMouseLeave);
+
+ overlay.style.opacity = "1";
+ const selectedItemText = targetElement.innerText;
+ if (selectedItemText) {
+ let imagePath = getModelImagePath(selectedItemText);
+ if (imagePath) {
+ imagePath = formatImagePath(imagePath, templateImagePath, "sdxl_styles/samples/fooocus_v2.jpg");
+ setImageBackground(overlay, imagePath);
+ } else {
+ setOverlayAsTooltip(overlay, selectedItemText);
+ }
+ }
+
+ function onMouseLeave() {
+ overlay.style.opacity = "0";
+ overlay.style.backgroundImage = "";
+ targetElement.removeEventListener("mouseout", onMouseLeave);
+ }
+ };
+}
+
+function handleMouseOverStylePreviewOverlay(overlay, elementSelector, templateImagePath) {
+ return function(e) {
+ const label = e.target.closest(elementSelector);
if (!label) return;
+
label.removeEventListener("mouseout", onMouseLeave);
label.addEventListener("mouseout", onMouseLeave);
- overlayVisible = true;
+
overlay.style.opacity = "1";
+
const originalText = label.querySelector("span").getAttribute("data-original-text");
- const name = originalText || label.querySelector("span").textContent;
- overlay.style.backgroundImage = `url("${samplesPath.replace(
- "fooocus_v2",
- name.toLowerCase().replaceAll(" ", "_")
- ).replaceAll("\\", "\\\\")}")`;
+ let name = originalText || label.querySelector("span").textContent;
+ let imagePath = formatImagePath(name, templateImagePath);
+
+ overlay.style.backgroundImage = `url("${imagePath}")`;
+
function onMouseLeave() {
- overlayVisible = false;
overlay.style.opacity = "0";
overlay.style.backgroundImage = "";
label.removeEventListener("mouseout", onMouseLeave);
}
- });
- document.addEventListener('mousemove', function(e) {
- if(!overlayVisible) return;
- overlay.style.left = `${e.clientX}px`;
- overlay.style.top = `${e.clientY}px`;
- overlay.className = e.clientY > window.innerHeight / 2 ? "lower-half" : "upper-half";
- });
+ };
+}
+
+// Initialization functions for different overlays
+function initModelPreviewOverlay() {
+ const templateImagePath = document.querySelector("meta[name='samples-path']").getAttribute("content");
+ const modelOverlay = createOverlay('modelPreviewOverlay');
+
+ document.addEventListener('mouseover', handleMouseOverModelPreviewOverlay(
+ modelOverlay,
+ '.model_selections .item',
+ templateImagePath
+ ));
+
+ document.addEventListener('mousemove', handleMouseMove(modelOverlay));
+}
+
+function initStylePreviewOverlay() {
+ const templateImagePath = document.querySelector("meta[name='samples-path']").getAttribute("content");
+ const styleOverlay = createOverlay('stylePreviewOverlay');
+
+ document.addEventListener('mouseover', handleMouseOverStylePreviewOverlay(
+ styleOverlay,
+ '.style_selections label',
+ templateImagePath
+ ));
+
+ document.addEventListener('mousemove', handleMouseMove(styleOverlay));
}
/**
diff --git a/modules/async_worker.py b/modules/async_worker.py
index b2af67126..15dbf0776 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -42,6 +42,7 @@ def worker():
from modules.util import remove_empty_str, HWC3, resize_image, \
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate
from modules.upscaler import perform_upscale
+ from modules.model_previewer import add_preview_by_attempt
try:
async_gradio_app = shared.gradio_root
@@ -799,7 +800,10 @@ def callback(step, x0, x, total_steps, y):
if n != 'None':
d.append((f'LoRA {li + 1}', f'{n} : {w}'))
d.append(('Version', 'v' + fooocus_version.version))
- log(x, d)
+ image_location = log(x, d)
+
+ if modules.config.use_add_model_previews:
+ add_preview_by_attempt(base_model_name, refiner_model_name, loras, image_location)
yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1)
except ldm_patched.modules.model_management.InterruptProcessingException as e:
diff --git a/modules/config.py b/modules/config.py
index 58107806c..45c591f92 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -8,7 +8,7 @@
from modules.model_loader import load_file_from_url
from modules.util import get_files_from_folder
-
+from modules.model_previewer import cleanup as cleanup_model_previews
config_path = os.path.abspath("./config.txt")
config_example_path = os.path.abspath("config_modification_tutorial.txt")
@@ -316,6 +316,16 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
default_value=-1,
validator=lambda x: isinstance(x, int)
)
+use_cleanup_model_previews = get_config_item_or_set_default(
+ key='use_cleanup_model_previews',
+ default_value=False,
+ validator=lambda x: x == False or x == True
+)
+use_add_model_previews = get_config_item_or_set_default(
+ key='use_add_model_previews',
+ default_value=True,
+ validator=lambda x: x == False or x == True
+)
example_inpaint_prompts = get_config_item_or_set_default(
key='example_inpaint_prompts',
default_value=[
@@ -342,6 +352,9 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
"default_prompt_negative",
"default_styles",
"default_aspect_ratio",
+ "default_aspect_ratio",
+ "use_cleanup_model_previews"
+ "use_add_model_previews",
"checkpoint_downloads",
"embeddings_downloads",
"lora_downloads",
@@ -514,5 +527,7 @@ def downloading_upscale_model():
)
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
-
update_all_model_names()
+
+if use_cleanup_model_previews:
+ cleanup_model_previews()
diff --git a/modules/model_previewer.py b/modules/model_previewer.py
new file mode 100644
index 000000000..a11b29e88
--- /dev/null
+++ b/modules/model_previewer.py
@@ -0,0 +1,148 @@
+import os
+import json
+
+# Constants
+CHECKPOINTS_DIR = 'models/checkpoints'
+LORAS_DIR = 'models/loras'
+OUTPUT_FOLDER = 'outputs'
+PREVIEW_LOG_FILE = 'preview_log.json'
+
+def read_json_file(file_path):
+ """ Reads a JSON file and returns its contents, or creates a new file if it doesn't exist. """
+ try:
+ if not file_exists(file_path):
+ with open(file_path, 'w') as file:
+ json.dump({}, file)
+ return {}
+
+ with open(file_path, 'r') as file:
+ return json.load(file)
+ except IOError as e:
+ print(f"Error reading file {file_path}: {e}")
+ return {}
+
+def update_json_file(file_path, data):
+ """ Writes updated data to a JSON file. """
+ try:
+ with open(file_path, 'w') as file:
+ json.dump(data, file, indent=4)
+ except IOError as e:
+ print(f"Error writing to file {file_path}: {e}")
+
+def file_exists(file_path):
+ """ Checks if a file exists at the given path. """
+ return os.path.exists(file_path)
+
+def verify_and_cleanup_data(json_data, base_folder):
+ """ Verifies the existence of files and cleans up JSON data. """
+ cleaned_data = {}
+ for safetensor, images in json_data.items():
+ safetensor_path = os.path.join(base_folder, safetensor)
+ if file_exists(safetensor_path):
+ existing_images = [img for img in images if file_exists(os.path.join(OUTPUT_FOLDER, img))]
+ if existing_images:
+ cleaned_data[safetensor] = existing_images
+ return cleaned_data
+
+def get_cleaned_data(json_path, base_folder):
+ data = read_json_file(json_path)
+ cleaned_data = verify_and_cleanup_data(data, base_folder)
+ return cleaned_data
+
+def process_directory(directory):
+ """ Process a single directory (checkpoints or loras). """
+ json_path = os.path.join(directory, PREVIEW_LOG_FILE)
+ cleaned_data = get_cleaned_data(json_path, directory)
+ try:
+ with open(json_path, 'w') as f:
+ json.dump(cleaned_data, f, indent=4)
+ except IOError as e:
+ print(f"Error writing to file {json_path}: {e}")
+
+def cleanup():
+ """ Cleans up the JSON files in both checkpoints and loras directories. """
+ process_directory(CHECKPOINTS_DIR)
+ process_directory(LORAS_DIR)
+
+def add_preview(model_name, image_location, directory):
+ """ Adds a new image location to the preview list of a given model file. """
+ print(f"Adding new preview '{image_location}' for '{directory}/{model_name}'")
+ json_path = os.path.join(directory, PREVIEW_LOG_FILE)
+ data = read_json_file(json_path)
+
+ if model_name not in data:
+ data[model_name] = []
+ if image_location not in data[model_name]:
+ data[model_name].append(image_location)
+ update_json_file(json_path, data)
+
+def add_preview_for_checkpoint(model_name, image_location):
+ """ Adds a new image location for the given model file in checkpoints. """
+ add_preview(model_name, image_location, CHECKPOINTS_DIR)
+
+def add_preview_image_for_lora(model_name, image_location):
+ """ Adds a new image location for the given model file in loras. """
+ add_preview(model_name, image_location, LORAS_DIR)
+
+def add_preview_by_attempt(base_model_name, refiner_model_name, loras, image_location):
+ print(f"Attempting to add new preview for base model '{base_model_name}', refiner model '{refiner_model_name}' or for lora model '{loras}' to image location '{image_location}'")
+
+ # Add preview based on the only one lora name
+ active_loras = [lora for lora in loras if lora[0] != 'None']
+ if len(active_loras) == 1:
+ active_lora_name = active_loras[0][0]
+ add_preview_image_for_lora(active_lora_name, image_location)
+
+ # Add preview based on only one model name if possible
+ if len(active_loras) == 0:
+ if refiner_model_name == "None":
+ add_preview_for_checkpoint(base_model_name, image_location)
+ elif "_SD_" in refiner_model_name:
+ add_preview_for_checkpoint(refiner_model_name, image_location)
+
+def get_preview(model_name, directory):
+ json_path = os.path.join(directory, PREVIEW_LOG_FILE)
+ cleaned_data = get_cleaned_data(json_path, directory)
+ return get_preview_from_data(model_name, cleaned_data)
+
+def get_preview_from_data(model_name, data):
+ """ Retrieves the latest available image for the given model file. """
+ images = data.get(model_name, [])
+ if images:
+ latest_image = sorted(images, reverse=True)[0]
+ latest_image_path = OUTPUT_FOLDER + "/" + latest_image
+ if file_exists(latest_image_path):
+ return latest_image_path
+ print(f"Verbose Debug: File exists for model '{model_name}' at path '{latest_image_path}'.")
+ else:
+ print(f"Verbose Debug: File does not exist for model '{model_name}' at path '{latest_image_path}'.")
+ else:
+ print(f"Verbose Debug: No images found for model '{model_name}' in data.")
+ return None
+
+def get_all_previews(directory):
+ """ Retrieves the latest available image for all. """
+ json_path = os.path.join(directory, PREVIEW_LOG_FILE)
+ print(f"Verbose Debug: Get previews from '{json_path}'.")
+ data = read_json_file(json_path)
+
+ valid_previews = {}
+
+ # Find all files in the specified directory (only first level)
+ for filename in os.listdir(directory):
+ image_path = get_preview_from_data(filename, data)
+ if image_path is not None:
+ print(f"Verbose Debug: Valid preview found for '{filename}'.")
+ valid_previews[filename] = image_path
+ else:
+ print(f"Verbose Debug: No valid preview found for '{filename}'.")
+
+ return valid_previews
+
+def get_all_previews_for_checkpoints():
+ """ Retrieves the available images for a list of all model names in checkpoints. """
+ return get_all_previews(CHECKPOINTS_DIR)
+
+def get_all_previews_for_loras():
+ """ Retrieves the available images for a list of all model names in loras. """
+ return get_all_previews(LORAS_DIR)
\ No newline at end of file
diff --git a/modules/private_logger.py b/modules/private_logger.py
index 968bd4f5d..8182d4ed6 100644
--- a/modules/private_logger.py
+++ b/modules/private_logger.py
@@ -105,4 +105,5 @@ def log(img, dic):
log_cache[html_name] = middle_part
- return
+ image_location = date_string + "/" + only_name
+ return image_location
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index bebf9f8ca..615ced201 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -3,6 +3,8 @@
import os
import gradio as gr
import args_manager
+import json
+from modules.model_previewer import get_all_previews_for_checkpoints, get_all_previews_for_loras
from modules.localization import localization_js
@@ -40,12 +42,29 @@ def javascript_html():
head += f'\n'
head += f'\n'
head += f'\n'
+
+ js_code = get_js_code_from_updated_previews()
+ head += f"\n"
if args_manager.args.theme:
head += f'\n'
return head
+def get_js_code_from_updated_previews():
+ # Fetch the updated previews data
+ updated_previews_checkpoint = get_all_previews_for_checkpoints()
+ updated_previews_lora = get_all_previews_for_loras()
+
+ # Convert to JSON strings
+ updated_previews_checkpoint_json = json.dumps(updated_previews_checkpoint)
+ updated_previews_lora_json = json.dumps(updated_previews_lora)
+
+ # Inject updated data into JavaScript
+ return f"""
+ previewsCheckpoint = {updated_previews_checkpoint_json};
+ previewsLora = {updated_previews_lora_json};
+ """
def css_html():
style_css_path = webpath('css/style.css')
diff --git a/webui.py b/webui.py
index fadd852af..835273694 100644
--- a/webui.py
+++ b/webui.py
@@ -294,8 +294,8 @@ def refresh_seed(r, seed_string):
with gr.Tab(label='Model'):
with gr.Group():
with gr.Row():
- base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
- refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
+ base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True, elem_classes=['model_selections'])
+ refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True, elem_classes=['model_selections'])
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
info='Use 0.4 for SD1.5 realistic models; '
@@ -314,7 +314,7 @@ def refresh_seed(r, seed_string):
for i, (n, v) in enumerate(modules.config.default_loras):
with gr.Row():
lora_model = gr.Dropdown(label=f'LoRA {i + 1}',
- choices=['None'] + modules.config.lora_filenames, value=n)
+ choices=['None'] + modules.config.lora_filenames, value=n, elem_classes=['model_selections'])
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v,
elem_classes='lora_weight')
lora_ctrls += [lora_model, lora_weight]