EasyCache support: "--easycache" CLI option, sampler integration and public API parameters.

Reviewer notes on this reconstruction of the patch:
  * The text this was recovered from had its line breaks and indentation collapsed, every
    angle-bracket span removed, and "&params" decoded as a pilcrow followed by "ms". Line
    structure was rebuilt so that each hunk's line counts match its "@@" header.
  * Template arguments inside added lines are restored from how the values are used
    (float buffers, a map from condition address to cache entry, size_t/float/int/double casts).
  * ASSUMPTION, please confirm against the original commit: the two added #include targets
    (<limits>, <unordered_map>) and the template arguments on context lines
    (std::vector<int>, std::vector<float>, std::vector<ggml_tensor*>) are inferred.
  * Indentation, column alignment and padding inside string literals could not be recovered;
    context lines may need whitespace-tolerant application.
  * Behavior change relative to the original patch: the CLI range checks are written as negated
    positive comparisons so that NaN values are rejected; accepted inputs are otherwise the same.
    Because added lines were edited, the post-image blob ids on the "index" lines are stale.

diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 8f938c9b4..c35ca306d 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -98,6 +98,9 @@ struct SDParams {
     std::vector<int> high_noise_skip_layers = {7, 8, 9};
     sd_sample_params_t high_noise_sample_params;
 
+    std::string easycache_option;  // raw "threshold,start_percent,end_percent" text; empty when the flag is absent
+    sd_easycache_params_t easycache_params;  // filled from easycache_option at the end of parse_args()
+
     float moe_boundary = 0.875f;
     int video_frames = 1;
     int fps = 16;
@@ -139,6 +142,7 @@ struct SDParams {
         sd_sample_params_init(&sample_params);
         sd_sample_params_init(&high_noise_sample_params);
         high_noise_sample_params.sample_steps = -1;
+        sd_easycache_params_init(&easycache_params);
     }
 };
 
@@ -208,6 +212,11 @@ void print_params(SDParams params) {
     printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
     printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
     printf(" video_frames: %d\n", params.video_frames);
+    printf(" easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
+           params.easycache_params.enabled ? "enabled" : "disabled",
+           params.easycache_params.reuse_threshold,
+           params.easycache_params.start_percent,
+           params.easycache_params.end_percent);
     printf(" vace_strength: %.2f\n", params.vace_strength);
     printf(" fps: %d\n", params.fps);
     free(sample_params_str);
@@ -593,6 +602,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
          "--upscale-model",
          "path to esrgan model.",
          &params.esrgan_path},
+        {"",
+         "--easycache",
+         "enable EasyCache for DiT models with \"threshold,start_percent,end_percent\" (example: 0.2,0.15,0.95)",
+         &params.easycache_option},
     };
 
     options.int_options = {
@@ -1117,6 +1130,59 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         exit(1);
     }
 
+    if (!params.easycache_option.empty()) {
+        float values[3] = {0.0f, 0.0f, 0.0f};  // threshold, start_percent, end_percent
+        std::stringstream ss(params.easycache_option);
+        std::string token;
+        int idx = 0;
+        while (std::getline(ss, token, ',')) {
+            auto trim = [](std::string& s) {  // strips leading/trailing blanks in place
+                const char* whitespace = " \t\r\n";
+                auto start = s.find_first_not_of(whitespace);
+                if (start == std::string::npos) {
+                    s.clear();
+                    return;
+                }
+                auto end = s.find_last_not_of(whitespace);
+                s = s.substr(start, end - start + 1);
+            };
+            trim(token);
+            if (token.empty()) {
+                fprintf(stderr, "error: invalid easycache option '%s'\n", params.easycache_option.c_str());
+                exit(1);
+            }
+            if (idx >= 3) {
+                fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
+                exit(1);
+            }
+            try {
+                values[idx] = std::stof(token);
+            } catch (const std::exception&) {
+                fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
+                exit(1);
+            }
+            idx++;
+        }
+        if (idx != 3) {
+            fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
+            exit(1);
+        }
+        if (!(values[0] >= 0.0f)) {  // negated comparison so that NaN is rejected as well as negatives
+            fprintf(stderr, "error: easycache threshold must be non-negative\n");
+            exit(1);
+        }
+        if (!(values[1] >= 0.0f && values[1] < values[2] && values[2] <= 1.0f)) {  // 0 <= start < end <= 1; false for NaN
+            fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
+            exit(1);
+        }
+        params.easycache_params.enabled = true;
+        params.easycache_params.reuse_threshold = values[0];
+        params.easycache_params.start_percent = values[1];
+        params.easycache_params.end_percent = values[2];
+    } else {
+        params.easycache_params.enabled = false;
+    }
+
     if (params.n_threads <= 0) {
         params.n_threads = get_num_physical_cores();
     }
@@ -1716,6 +1782,7 @@ int main(int argc, const char* argv[]) {
                 params.pm_style_strength,
             },  // pm_params
             params.vae_tiling_params,
+            params.easycache_params,
         };
 
         results = generate_image(sd_ctx, &img_gen_params);
@@ -1738,6 +1805,7 @@ int main(int argc, const char* argv[]) {
             params.seed,
             params.video_frames,
             params.vace_strength,
+            params.easycache_params,
         };
 
         results = generate_video(sd_ctx, &vid_gen_params, &num_results);
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 9faba955a..97bf43d68 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -16,6 +16,9 @@
 #include "tae.hpp"
 #include "vae.hpp"
 
+#include <limits>
+#include <unordered_map>
+
 const char* model_version_to_str[] = {
     "SD 1.x",
     "SD 1.x Inpaint",
@@ -73,6 +76,260 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
     }
 }
 
+struct EasyCacheConfig {  // settings copied from sd_easycache_params_t inside sample()
+    bool enabled = false;
+    float reuse_threshold = 0.2f;  // a step is skipped while the accumulated estimated change stays below this
+    float start_percent = 0.15f;  // converted to start_sigma by EasyCacheState::init()
+    float end_percent = 0.95f;  // converted to end_sigma by EasyCacheState::init()
+};
+
+struct EasyCacheCacheEntry {
+    std::vector<float> diff;  // output[i] - input[i], written by EasyCacheState::update_cache()
+};
+
+struct EasyCacheState {  // bookkeeping for one sample() call: decides when a model evaluation is replaced by a cached diff
+    EasyCacheConfig config;
+    Denoiser* denoiser = nullptr;  // not owned; only used for t_to_sigma()
+    float start_sigma = std::numeric_limits<float>::max();  // steps with sigma above this are not active
+    float end_sigma = 0.0f;  // steps with sigma at or below this are not active
+    bool initialized = false;
+    bool initial_step = true;  // true until the first condition is seen on an active step
+    bool skip_current_step = false;
+    bool step_active = false;
+    const SDCondition* anchor_condition = nullptr;  // the only condition that drives skip decisions
+    std::unordered_map<const SDCondition*, EasyCacheCacheEntry> cache_diffs;  // keyed by condition address
+    std::vector<float> prev_input;
+    std::vector<float> prev_output;
+    float output_prev_norm = 0.0f;  // mean |output| of the last computed anchor output
+    bool has_prev_input = false;
+    bool has_prev_output = false;
+    bool has_output_prev_norm = false;
+    bool has_relative_transformation_rate = false;
+    float relative_transformation_rate = 0.0f;  // output_change / last_input_change from after_condition()
+    float cumulative_change_rate = 0.0f;
+    float last_input_change = 0.0f;  // mean |input - prev_input| from before_condition()
+    bool has_last_input_change = false;
+    int total_steps_skipped = 0;
+    int current_step_index = -1;
+
+    void reset_runtime() {  // clears per-run state; config, denoiser, initialized and the sigma window are left alone
+        initial_step = true;
+        skip_current_step = false;
+        step_active = false;
+        anchor_condition = nullptr;
+        cache_diffs.clear();
+        prev_input.clear();
+        prev_output.clear();
+        output_prev_norm = 0.0f;
+        has_prev_input = false;
+        has_prev_output = false;
+        has_output_prev_norm = false;
+        has_relative_transformation_rate = false;
+        relative_transformation_rate = 0.0f;
+        cumulative_change_rate = 0.0f;
+        last_input_change = 0.0f;
+        has_last_input_change = false;
+        total_steps_skipped = 0;
+        current_step_index = -1;
+    }
+
+    void init(const EasyCacheConfig& cfg, Denoiser* d) {
+        config = cfg;
+        denoiser = d;
+        initialized = cfg.enabled && d != nullptr;
+        reset_runtime();
+        if (initialized) {
+            start_sigma = percent_to_sigma(config.start_percent);
+            end_sigma = percent_to_sigma(config.end_percent);
+        }
+    }
+
+    bool enabled() const {
+        return initialized && config.enabled;
+    }
+
+    float percent_to_sigma(float percent) const {
+        if (!denoiser) {
+            return 0.0f;
+        }
+        if (percent <= 0.0f) {
+            return std::numeric_limits<float>::max();
+        }
+        if (percent >= 1.0f) {
+            return 0.0f;
+        }
+        float t = (1.0f - percent) * (TIMESTEPS - 1);  // larger percent -> smaller t
+        return denoiser->t_to_sigma(t);
+    }
+
+    void begin_step(int step_index, float sigma) {
+        if (!enabled()) {
+            return;
+        }
+        if (step_index == current_step_index) {  // already prepared for this step
+            return;
+        }
+        current_step_index = step_index;
+        skip_current_step = false;
+        has_last_input_change = false;
+        step_active = false;
+        if (sigma > start_sigma) {
+            return;
+        }
+        if (!(sigma > end_sigma)) {
+            return;
+        }
+        step_active = true;  // end_sigma < sigma <= start_sigma
+    }
+
+    bool step_is_active() const {
+        return enabled() && step_active;
+    }
+
+    bool has_cache(const SDCondition* cond) const {
+        auto it = cache_diffs.find(cond);
+        return it != cache_diffs.end() && !it->second.diff.empty();
+    }
+
+    void update_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
+        EasyCacheCacheEntry& entry = cache_diffs[cond];
+        size_t ne = static_cast<size_t>(ggml_nelements(output));  // NOTE(review): input is read with the same count; assumes equal sizes
+        entry.diff.resize(ne);
+        float* out_data = (float*)output->data;
+        float* in_data = (float*)input->data;
+        for (size_t i = 0; i < ne; ++i) {
+            entry.diff[i] = out_data[i] - in_data[i];
+        }
+    }
+
+    void apply_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
+        auto it = cache_diffs.find(cond);
+        if (it == cache_diffs.end() || it->second.diff.empty()) {
+            return;
+        }
+        copy_ggml_tensor(output, input);  // output = input + cached diff
+        float* out_data = (float*)output->data;
+        const std::vector<float>& diff = it->second.diff;
+        for (size_t i = 0; i < diff.size(); ++i) {
+            out_data[i] += diff[i];
+        }
+    }
+
+    bool before_condition(const SDCondition* cond,  // returns true when output was filled from the cache
+                          ggml_tensor* input,
+                          ggml_tensor* output,
+                          float sigma,
+                          int step_index) {
+        if (!enabled() || step_index < 0) {
+            return false;
+        }
+        if (step_index != current_step_index) {
+            begin_step(step_index, sigma);
+        }
+        if (!step_active) {
+            return false;
+        }
+        if (initial_step) {
+            anchor_condition = cond;
+            initial_step = false;
+        }
+        bool is_anchor = (cond == anchor_condition);
+        if (skip_current_step) {  // the anchor already decided to skip this step
+            if (has_cache(cond)) {
+                apply_cache(cond, input, output);
+                return true;
+            }
+            return false;
+        }
+        if (!is_anchor) {
+            return false;
+        }
+        if (!has_prev_input || !has_prev_output || !has_cache(cond)) {
+            return false;
+        }
+        size_t ne = static_cast<size_t>(ggml_nelements(input));
+        if (prev_input.size() != ne) {
+            return false;
+        }
+        float* input_data = (float*)input->data;
+        last_input_change = 0.0f;
+        for (size_t i = 0; i < ne; ++i) {
+            last_input_change += std::fabs(input_data[i] - prev_input[i]);
+        }
+        if (ne > 0) {
+            last_input_change /= static_cast<float>(ne);
+        }
+        has_last_input_change = true;
+
+        if (has_output_prev_norm && has_relative_transformation_rate && last_input_change > 0.0f && output_prev_norm > 0.0f) {
+            float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
+            cumulative_change_rate += approx_output_change_rate;
+            if (cumulative_change_rate < config.reuse_threshold) {
+                skip_current_step = true;
+                total_steps_skipped++;
+                apply_cache(cond, input, output);
+                return true;
+            } else {
+                cumulative_change_rate = 0.0f;
+            }
+        }
+
+        return false;
+    }
+
+    void after_condition(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {  // called after a real model evaluation
+        if (!step_is_active()) {
+            return;
+        }
+        update_cache(cond, input, output);
+        if (cond != anchor_condition) {
+            return;
+        }
+
+        size_t ne = static_cast<size_t>(ggml_nelements(input));
+        float* in_data = (float*)input->data;
+        prev_input.resize(ne);
+        for (size_t i = 0; i < ne; ++i) {
+            prev_input[i] = in_data[i];
+        }
+        has_prev_input = true;
+
+        float* out_data = (float*)output->data;
+        float output_change = 0.0f;  // mean |output - prev_output|, stays 0 without a comparable previous output
+        if (has_prev_output && prev_output.size() == ne) {
+            for (size_t i = 0; i < ne; ++i) {
+                output_change += std::fabs(out_data[i] - prev_output[i]);
+            }
+            if (ne > 0) {
+                output_change /= static_cast<float>(ne);
+            }
+        }
+
+        prev_output.resize(ne);
+        for (size_t i = 0; i < ne; ++i) {
+            prev_output[i] = out_data[i];
+        }
+        has_prev_output = true;
+
+        float mean_abs = 0.0f;
+        for (size_t i = 0; i < ne; ++i) {
+            mean_abs += std::fabs(out_data[i]);
+        }
+        output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
+        has_output_prev_norm = output_prev_norm > 0.0f;
+
+        if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
+            float rate = output_change / last_input_change;
+            if (std::isfinite(rate)) {
+                relative_transformation_rate = rate;
+                has_relative_transformation_rate = true;
+            }
+        }
+        cumulative_change_rate = 0.0f;
+        has_last_input_change = false;
+    }
+};
+
 /*=============================================== StableDiffusionGGML ================================================*/
 
 class StableDiffusionGGML {
@@ -1129,7 +1386,8 @@ class StableDiffusionGGML {
                         bool increase_ref_index = false,
                         ggml_tensor* denoise_mask = nullptr,
                         ggml_tensor* vace_context = nullptr,
-                        float vace_strength = 1.f) {
+                        float vace_strength = 1.f,
+                        const sd_easycache_params_t* easycache_params = nullptr) {
         if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) {
             LOG_WARN("timestep shifting is only supported for SDXL models!");
             shifted_timestep = 0;
@@ -1145,6 +1403,42 @@ class StableDiffusionGGML {
             img_cfg_scale = cfg_scale;
         }
 
+        EasyCacheState easycache_state;  // fresh state for every sample() call
+        bool easycache_enabled = false;
+        if (easycache_params != nullptr && easycache_params->enabled) {
+            bool easycache_supported = sd_version_is_dit(version);
+            if (!easycache_supported) {
+                LOG_WARN("EasyCache requested but not supported for this model type");
+            } else {
+                EasyCacheConfig easycache_config;
+                easycache_config.enabled = true;
+                easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold);
+                easycache_config.start_percent = easycache_params->start_percent;
+                easycache_config.end_percent = easycache_params->end_percent;
+                bool percent_valid = easycache_config.start_percent >= 0.0f &&
+                                     easycache_config.start_percent < 1.0f &&
+                                     easycache_config.end_percent > 0.0f &&
+                                     easycache_config.end_percent <= 1.0f &&
+                                     easycache_config.start_percent < easycache_config.end_percent;
+                if (!percent_valid) {
+                    LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)",
+                             easycache_config.start_percent,
+                             easycache_config.end_percent);
+                } else {
+                    easycache_state.init(easycache_config, denoiser.get());
+                    if (easycache_state.enabled()) {
+                        easycache_enabled = true;
+                        LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
+                                 easycache_config.reuse_threshold,
+                                 easycache_config.start_percent,
+                                 easycache_config.end_percent);
+                    } else {
+                        LOG_WARN("EasyCache requested but could not be initialized for this run");
+                    }
+                }
+            }
+        }
+
         size_t steps = sigmas.size() - 1;
         struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
         copy_ggml_tensor(x, init_latent);
@@ -1188,6 +1482,12 @@ class StableDiffusionGGML {
                 pretty_progress(0, (int)steps, 0);
             }
 
+            const bool easycache_step_active = easycache_enabled && step > 0;  // calls with step <= 0 bypass the cache
+            int easycache_step_index = easycache_step_active ? (step - 1) : -1;
+            if (easycache_step_active) {
+                easycache_state.begin_step(easycache_step_index, sigma);
+            }
+
             std::vector<float> scaling = denoiser->get_scalings(sigma);
             GGML_ASSERT(scaling.size() == 3);
             float c_skip = scaling[0];
@@ -1239,21 +1539,38 @@ class StableDiffusionGGML {
             diffusion_params.vace_context = vace_context;
             diffusion_params.vace_strength = vace_strength;
 
+            const SDCondition* active_condition = nullptr;
+            struct ggml_tensor** active_output = &out_cond;
             if (start_merge_step == -1 || step <= start_merge_step) {
                 // cond
                 diffusion_params.context = cond.c_crossattn;
                 diffusion_params.c_concat = cond.c_concat;
                 diffusion_params.y = cond.c_vector;
-                work_diffusion_model->compute(n_threads,
-                                              diffusion_params,
-                                              &out_cond);
+                active_condition = &cond;
             } else {
                 diffusion_params.context = id_cond.c_crossattn;
                 diffusion_params.c_concat = cond.c_concat;
                 diffusion_params.y = id_cond.c_vector;
+                active_condition = &id_cond;
+            }
+
+            bool skip_model = false;
+            if (easycache_step_active && active_condition != nullptr) {
+                skip_model = easycache_state.before_condition(active_condition,
+                                                              diffusion_params.x,
+                                                              *active_output,
+                                                              sigma,
+                                                              easycache_step_index);
+            }
+            if (!skip_model) {
                 work_diffusion_model->compute(n_threads,
                                               diffusion_params,
-                                              &out_cond);
+                                              active_output);
+                if (easycache_step_active && active_condition != nullptr) {
+                    easycache_state.after_condition(active_condition,
+                                                    diffusion_params.x,
+                                                    *active_output);
+                }
             }
 
             float* negative_data = nullptr;
@@ -1354,6 +1671,26 @@ class StableDiffusionGGML {
 
         sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
 
+        if (easycache_enabled) {
+            size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
+            if (easycache_state.total_steps_skipped > 0 && total_steps > 0) {
+                if (easycache_state.total_steps_skipped < static_cast<int>(total_steps)) {  // keeps the divisor below positive
+                    double speedup = static_cast<double>(total_steps) /
+                                     static_cast<double>(total_steps - easycache_state.total_steps_skipped);
+                    LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
+                             easycache_state.total_steps_skipped,
+                             total_steps,
+                             speedup);
+                } else {
+                    LOG_INFO("EasyCache skipped %d/%zu steps",
+                             easycache_state.total_steps_skipped,
+                             total_steps);
+                }
+            } else if (total_steps > 0) {
+                LOG_INFO("EasyCache completed without skipping steps");
+            }
+        }
+
         if (inverse_noise_scaling) {
             x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
         }
@@ -1852,6 +2189,14 @@ enum prediction_t str_to_prediction(const char* str) {
     return PREDICTION_COUNT;
 }
 
+void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {  // disabled, threshold 0.2, window 0.15..0.95
+    *easycache_params = {};
+    easycache_params->enabled = false;
+    easycache_params->reuse_threshold = 0.2f;
+    easycache_params->start_percent = 0.15f;
+    easycache_params->end_percent = 0.95f;
+}
+
 void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
     *sd_ctx_params = {};
     sd_ctx_params->vae_decode_only = true;
@@ -2004,6 +2349,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
     sd_img_gen_params->control_strength = 0.9f;
     sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
     sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
+    sd_easycache_params_init(&sd_img_gen_params->easycache);
 }
 
 char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
@@ -2047,6 +2393,12 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
              sd_img_gen_params->pm_params.id_images_count,
              SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
              BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
+    snprintf(buf + strlen(buf), 4096 - strlen(buf),
+             "easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
+             sd_img_gen_params->easycache.enabled ? "enabled" : "disabled",
+             sd_img_gen_params->easycache.reuse_threshold,
+             sd_img_gen_params->easycache.start_percent,
+             sd_img_gen_params->easycache.end_percent);
     free(sample_params_str);
     return buf;
 }
@@ -2063,6 +2415,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
     sd_vid_gen_params->video_frames = 6;
     sd_vid_gen_params->moe_boundary = 0.875f;
     sd_vid_gen_params->vace_strength = 1.f;
+    sd_easycache_params_init(&sd_vid_gen_params->easycache);
 }
 
 struct sd_ctx_t {
@@ -2131,7 +2484,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                                     std::vector<ggml_tensor*> ref_latents,
                                     bool increase_ref_index,
                                     ggml_tensor* concat_latent = nullptr,
-                                    ggml_tensor* denoise_mask = nullptr) {
+                                    ggml_tensor* denoise_mask = nullptr,
+                                    const sd_easycache_params_t* easycache_params = nullptr) {
     if (seed < 0) {
         // Generally, when using the provided command line, the seed is always >0.
         // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -2407,7 +2761,10 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                                                  id_cond,
                                                  ref_latents,
                                                  increase_ref_index,
-                                                 denoise_mask);
+                                                 denoise_mask,
+                                                 nullptr,
+                                                 1.0f,
+                                                 easycache_params);
             // print_ggml_tensor(x_0);
             int64_t sampling_end = ggml_time_ms();
             LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@@ -2720,7 +3077,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                                                         ref_latents,
                                                         sd_img_gen_params->increase_ref_index,
                                                         concat_latent,
-                                                        denoise_mask);
+                                                        denoise_mask,
+                                                        &sd_img_gen_params->easycache);
 
     size_t t2 = ggml_time_ms();
 
@@ -3040,7 +3398,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                           false,
                                           denoise_mask,
                                           vace_context,
-                                          sd_vid_gen_params->vace_strength);
+                                          sd_vid_gen_params->vace_strength,
+                                          &sd_vid_gen_params->easycache);
 
             int64_t sampling_end = ggml_time_ms();
             LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@@ -3076,7 +3435,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                           false,
                                           denoise_mask,
                                           vace_context,
-                                          sd_vid_gen_params->vace_strength);
+                                          sd_vid_gen_params->vace_strength,
+                                          &sd_vid_gen_params->easycache);
 
             int64_t sampling_end = ggml_time_ms();
             LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
diff --git a/stable-diffusion.h b/stable-diffusion.h
index f618d457b..86487c4cd 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -209,6 +209,13 @@ typedef struct {
     float style_strength;
 } sd_pm_params_t;  // photo maker
 
+typedef struct {
+    bool enabled;
+    float reuse_threshold;
+    float start_percent;
+    float end_percent;
+} sd_easycache_params_t;  // initialize with sd_easycache_params_init()
+
 typedef struct {
     const char* prompt;
     const char* negative_prompt;
@@ -229,6 +236,7 @@ typedef struct {
     float control_strength;
     sd_pm_params_t pm_params;
     sd_tiling_params_t vae_tiling_params;
+    sd_easycache_params_t easycache;
 } sd_img_gen_params_t;
 
 typedef struct {
@@ -248,6 +256,7 @@ typedef struct {
     int64_t seed;
     int video_frames;
     float vace_strength;
+    sd_easycache_params_t easycache;
 } sd_vid_gen_params_t;
 
 typedef struct sd_ctx_t sd_ctx_t;
@@ -271,6 +280,8 @@ SD_API enum scheduler_t str_to_schedule(const char* str);
 SD_API const char* sd_prediction_name(enum prediction_t prediction);
 SD_API enum prediction_t str_to_prediction(const char* str);
 
+SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params);
+
 SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
 SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
 