Skip to content

Commit 98675ae

Browse files
committed
refactor 2-cfg conditioning + better img_cond defaults
1 parent dd75fc0 commit 98675ae

File tree

1 file changed

+50
-44
lines changed

1 file changed

+50
-44
lines changed

stable-diffusion.cpp

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,7 @@ class StableDiffusionGGML {
11201120
ggml_tensor* noise,
11211121
SDCondition cond,
11221122
SDCondition uncond,
1123-
SDCondition img_cond,
1123+
SDCondition img_uncond,
11241124
ggml_tensor* control_hint,
11251125
float control_strength,
11261126
sd_guidance_params_t guidance,
@@ -1147,7 +1147,7 @@ class StableDiffusionGGML {
11471147

11481148
if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
11491149
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
1150-
img_cfg_scale = cfg_scale;
1150+
img_cfg_scale = 1.0f;
11511151
}
11521152

11531153
size_t steps = sigmas.size() - 1;
@@ -1159,10 +1159,11 @@ class StableDiffusionGGML {
11591159
}
11601160

11611161
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x);
1162-
1163-
bool has_unconditioned = img_cfg_scale != 1.0 && uncond.c_crossattn != nullptr;
1164-
bool has_img_cond = cfg_scale != img_cfg_scale && img_cond.c_crossattn != nullptr;
1162+
11651163
bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0;
1164+
bool has_conditionned = (has_skiplayer || cfg_scale != 0.0) && cond.c_crossattn != nullptr;
1165+
bool has_unconditioned = cfg_scale != img_cfg_scale && uncond.c_crossattn != nullptr;
1166+
bool has_img_uncond = img_cfg_scale != 1.0 && img_uncond.c_crossattn != nullptr;
11661167

11671168
// denoise wrapper
11681169
struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x);
@@ -1181,7 +1182,7 @@ class StableDiffusionGGML {
11811182
LOG_WARN("SLG is incompatible with %s models", model_version_to_str[version]);
11821183
}
11831184
}
1184-
if (has_img_cond) {
1185+
if (has_img_uncond) {
11851186
out_img_cond = ggml_dup_tensor(work_ctx, x);
11861187
}
11871188
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
@@ -1244,21 +1245,23 @@ class StableDiffusionGGML {
12441245
diffusion_params.vace_context = vace_context;
12451246
diffusion_params.vace_strength = vace_strength;
12461247

1247-
if (start_merge_step == -1 || step <= start_merge_step) {
1248-
// cond
1249-
diffusion_params.context = cond.c_crossattn;
1250-
diffusion_params.c_concat = cond.c_concat;
1251-
diffusion_params.y = cond.c_vector;
1252-
work_diffusion_model->compute(n_threads,
1253-
diffusion_params,
1254-
&out_cond);
1255-
} else {
1256-
diffusion_params.context = id_cond.c_crossattn;
1257-
diffusion_params.c_concat = cond.c_concat;
1258-
diffusion_params.y = id_cond.c_vector;
1259-
work_diffusion_model->compute(n_threads,
1260-
diffusion_params,
1261-
&out_cond);
1248+
if (has_conditionned) {
1249+
if (start_merge_step == -1 || step <= start_merge_step) {
1250+
// cond
1251+
diffusion_params.context = cond.c_crossattn;
1252+
diffusion_params.c_concat = cond.c_concat;
1253+
diffusion_params.y = cond.c_vector;
1254+
work_diffusion_model->compute(n_threads,
1255+
diffusion_params,
1256+
&out_cond);
1257+
} else {
1258+
diffusion_params.context = id_cond.c_crossattn;
1259+
diffusion_params.c_concat = cond.c_concat;
1260+
diffusion_params.y = id_cond.c_vector;
1261+
work_diffusion_model->compute(n_threads,
1262+
diffusion_params,
1263+
&out_cond);
1264+
}
12621265
}
12631266

12641267
float* negative_data = nullptr;
@@ -1279,10 +1282,10 @@ class StableDiffusionGGML {
12791282
}
12801283

12811284
float* img_cond_data = nullptr;
1282-
if (has_img_cond) {
1283-
diffusion_params.context = img_cond.c_crossattn;
1284-
diffusion_params.c_concat = img_cond.c_concat;
1285-
diffusion_params.y = img_cond.c_vector;
1285+
if (has_img_uncond) {
1286+
diffusion_params.context = img_uncond.c_crossattn;
1287+
diffusion_params.c_concat = img_uncond.c_concat;
1288+
diffusion_params.y = img_uncond.c_vector;
12861289
work_diffusion_model->compute(n_threads,
12871290
diffusion_params,
12881291
&out_img_cond);
@@ -1325,19 +1328,19 @@ class StableDiffusionGGML {
13251328
float latent_result = positive_data[i];
13261329
if (has_unconditioned) {
13271330
// out_uncond + cfg_scale * (out_cond - out_uncond)
1328-
if (has_img_cond) {
1331+
if (has_img_uncond) {
13291332
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
13301333
latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]);
13311334
} else {
13321335
// img_cfg_scale == cfg_scale
13331336
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
13341337
}
1335-
} else if (has_img_cond) {
1338+
} else if (has_img_uncond) {
13361339
// img_cfg_scale == 1
13371340
latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]);
13381341
}
13391342
if (is_skiplayer_step) {
1340-
latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale;
1343+
latent_result = latent_result + slg_scale * (positive_data[i] - skip_layer_data[i]);
13411344
}
13421345
// v = latent_result, eps = latent_result
13431346
// denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
@@ -1981,7 +1984,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
19811984
sample_params->guidance.txt_cfg,
19821985
std::isfinite(sample_params->guidance.img_cfg)
19831986
? sample_params->guidance.img_cfg
1984-
: sample_params->guidance.txt_cfg,
1987+
: 1.0f,
19851988
sample_params->guidance.distilled_guidance,
19861989
sample_params->guidance.slg.layer_count,
19871990
sample_params->guidance.slg.layer_start,
@@ -2146,7 +2149,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
21462149
}
21472150

21482151
if (!std::isfinite(guidance.img_cfg)) {
2149-
guidance.img_cfg = guidance.txt_cfg;
2152+
// default to 1
2153+
guidance.img_cfg = 1.0f;
21502154
}
21512155

21522156
// for (auto v : sigmas) {
@@ -2254,7 +2258,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
22542258

22552259
SDCondition uncond;
22562260
if (guidance.txt_cfg != 1.0 ||
2257-
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) {
2261+
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != 1.0f)) {
22582262
bool zero_out_masked = false;
22592263
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
22602264
zero_out_masked = true;
@@ -2292,14 +2296,16 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
22922296
ggml_ext_tensor_scale_inplace(control_latent, control_strength);
22932297
}
22942298

2299+
struct ggml_tensor* empty_latent;
2300+
22952301
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
22962302
int64_t mask_channels = 1;
22972303
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
22982304
mask_channels = 8 * 8; // flatten the whole mask
22992305
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
23002306
mask_channels = 1 + init_latent->ne[2];
23012307
}
2302-
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
2308+
empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
23032309
// no mask, set the whole image as masked
23042310
for (int64_t x = 0; x < empty_latent->ne[0]; x++) {
23052311
for (int64_t y = 0; y < empty_latent->ne[1]; y++) {
@@ -2349,31 +2355,31 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
23492355
concat_latent = empty_latent;
23502356
}
23512357
cond.c_concat = concat_latent;
2352-
uncond.c_concat = empty_latent;
2358+
uncond.c_concat = concat_latent;
23532359
denoise_mask = nullptr;
23542360
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
2355-
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
2361+
empty_latent = ggml_dup_tensor(work_ctx, init_latent);
23562362
ggml_set_f32(empty_latent, 0);
2357-
uncond.c_concat = empty_latent;
2358-
cond.c_concat = ref_latents[0];
2363+
cond.c_concat = ref_latents[0];
23592364
if (cond.c_concat == nullptr) {
23602365
cond.c_concat = empty_latent;
23612366
}
2367+
uncond.c_concat = cond.c_concat;
23622368
} else if (sd_version_is_control(sd_ctx->sd->version)) {
2363-
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
2369+
empty_latent = ggml_dup_tensor(work_ctx, init_latent);
23642370
ggml_set_f32(empty_latent, 0);
2365-
uncond.c_concat = empty_latent;
23662371
if (sd_ctx->sd->control_net == nullptr) {
23672372
cond.c_concat = control_latent;
23682373
}
23692374
if (cond.c_concat == nullptr) {
23702375
cond.c_concat = empty_latent;
23712376
}
2377+
uncond.c_concat = cond.c_concat;
23722378
}
2373-
SDCondition img_cond;
2379+
SDCondition img_uncond = uncond;
23742380
if (uncond.c_crossattn != nullptr &&
2375-
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) {
2376-
img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat);
2381+
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.img_cfg != 1.0)) {
2382+
img_uncond = SDCondition(uncond.c_crossattn, uncond.c_vector, empty_latent);
23772383
}
23782384
for (int b = 0; b < batch_count; b++) {
23792385
int64_t sampling_start = ggml_time_ms();
@@ -2400,7 +2406,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
24002406
noise,
24012407
cond,
24022408
uncond,
2403-
img_cond,
2409+
img_uncond,
24042410
image_hint,
24052411
control_strength,
24062412
guidance,
@@ -3022,7 +3028,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
30223028
noise,
30233029
cond,
30243030
uncond,
3025-
{},
3031+
uncond,
30263032
nullptr,
30273033
0,
30283034
sd_vid_gen_params->high_noise_sample_params.guidance,
@@ -3058,7 +3064,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
30583064
noise,
30593065
cond,
30603066
uncond,
3061-
{},
3067+
uncond,
30623068
nullptr,
30633069
0,
30643070
sd_vid_gen_params->sample_params.guidance,

0 commit comments

Comments
 (0)