Skip to content

Commit c42826b

Browse files
authored
fix: resolve multiple inpainting issues (#926)
* Fix inpainting masked image being broken by side effect
* Fix unet inpainting concat not being set
* Fix Flex.2 inpaint mode crash (+ use scale factor)
1 parent: 945d9a9 · commit: c42826b

File tree

1 file changed: +13 / -4 lines changed

1 file changed: +13 / -4 lines changed

stable-diffusion.cpp

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2532,14 +2532,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
25322532
sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img);
25332533
sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img);
25342534

2535-
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
2536-
25372535
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
25382536
int64_t mask_channels = 1;
25392537
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
2540-
mask_channels = 8 * 8; // flatten the whole mask
2538+
mask_channels = vae_scale_factor * vae_scale_factor; // flatten the whole mask
25412539
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2542-
mask_channels = 1 + init_latent->ne[2];
2540+
mask_channels = 1 + sd_ctx->sd->get_latent_channel();
25432541
}
25442542
ggml_tensor* masked_latent = nullptr;
25452543

@@ -2548,8 +2546,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
25482546
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
25492547
ggml_ext_tensor_apply_mask(init_img, mask_img, masked_img);
25502548
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
2549+
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
25512550
} else {
25522551
// mask after vae
2552+
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
25532553
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
25542554
ggml_ext_tensor_apply_mask(init_latent, mask_img, masked_latent, 0.);
25552555
}
@@ -2590,9 +2590,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
25902590
for (int k = 0; k < masked_latent->ne[2]; k++) {
25912591
ggml_ext_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
25922592
}
2593+
} else {
2594+
float m = ggml_ext_tensor_get_f32(mask_img, mx, my);
2595+
ggml_ext_tensor_set_f32(concat_latent, m, ix, iy, 0);
2596+
for (int k = 0; k < masked_latent->ne[2]; k++) {
2597+
float v = ggml_ext_tensor_get_f32(masked_latent, ix, iy, k);
2598+
ggml_ext_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
2599+
}
25932600
}
25942601
}
25952602
}
2603+
} else {
2604+
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
25962605
}
25972606

25982607
{

0 commit comments

Comments (0)