@@ -2532,14 +2532,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
         sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img);
         sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img);

-        init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {
             int64_t mask_channels = 1;
             if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
-                mask_channels = 8 * 8;  // flatten the whole mask
+                mask_channels = vae_scale_factor * vae_scale_factor;  // flatten the whole mask
             } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
-                mask_channels = 1 + init_latent->ne[2];
+                mask_channels = 1 + sd_ctx->sd->get_latent_channel();
             }
             ggml_tensor* masked_latent = nullptr;

@@ -2548,8 +2546,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                 ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
                 ggml_ext_tensor_apply_mask(init_img, mask_img, masked_img);
                 masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
             } else {
                 // mask after vae
+                init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
                 masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
                 ggml_ext_tensor_apply_mask(init_latent, mask_img, masked_latent, 0.);
             }
@@ -2590,9 +2590,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                             for (int k = 0; k < masked_latent->ne[2]; k++) {
                                 ggml_ext_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
                             }
+                        } else {
+                            float m = ggml_ext_tensor_get_f32(mask_img, mx, my);
+                            ggml_ext_tensor_set_f32(concat_latent, m, ix, iy, 0);
+                            for (int k = 0; k < masked_latent->ne[2]; k++) {
+                                float v = ggml_ext_tensor_get_f32(masked_latent, ix, iy, k);
+                                ggml_ext_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
+                            }
                         }
                     }
                 }
+        } else {
+            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
         }

         {
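
For readers following the new generic-inpaint branch added above: channel 0 of concat_latent receives the mask value sampled at the latent coordinate mapped back to full resolution, and the channels from mask_channels onward receive the masked latent. Below is a minimal, self-contained sketch of that layout; it is not part of the commit, uses plain arrays instead of ggml tensors, and all names, sizes, and the indexing scheme are illustrative assumptions.

// Illustrative sketch of the concat_latent layout written by the generic
// inpaint branch above. Plain arrays stand in for ggml tensors; the sizes,
// vae_scale_factor, and the indexing helper are assumptions for the example.
#include <cstdio>
#include <vector>

int main() {
    const int latent_w = 4, latent_h = 4;  // latent spatial size (example)
    const int latent_c = 4;                // latent channels (example)
    const int mask_channels = 1;           // generic inpaint model: a single mask channel
    const int vae_scale_factor = 8;        // latent-to-pixel downscale (assumed)
    const int mask_w = latent_w * vae_scale_factor;

    std::vector<float> mask_img(mask_w * latent_h * vae_scale_factor, 1.0f);
    std::vector<float> masked_latent(latent_w * latent_h * latent_c, 0.5f);
    std::vector<float> concat_latent(latent_w * latent_h * (mask_channels + latent_c), 0.0f);

    // (x, y, c) -> flat index for a w*h tensor stored channel-major
    auto idx = [](int x, int y, int c, int w, int h) { return (c * h + y) * w + x; };

    for (int iy = 0; iy < latent_h; iy++) {
        for (int ix = 0; ix < latent_w; ix++) {
            int mx = ix * vae_scale_factor;  // map the latent coordinate back to the mask
            int my = iy * vae_scale_factor;
            // channel 0: downsampled mask value
            concat_latent[idx(ix, iy, 0, latent_w, latent_h)] = mask_img[my * mask_w + mx];
            // channels mask_channels .. mask_channels + latent_c - 1: masked latent
            for (int k = 0; k < latent_c; k++) {
                concat_latent[idx(ix, iy, k + mask_channels, latent_w, latent_h)] =
                    masked_latent[idx(ix, iy, k, latent_w, latent_h)];
            }
        }
    }
    printf("concat_latent channels per pixel: %d\n", mask_channels + latent_c);
    return 0;
}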