diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 8f938c9b4..a57fc3b9b 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -67,7 +67,7 @@ struct SDParams {
     std::string diffusion_model_path;
     std::string high_noise_diffusion_model_path;
     std::string vae_path;
-    std::string taesd_path;
+    std::string tae_path;
     std::string esrgan_path;
     std::string control_net_path;
     std::string embedding_dir;
@@ -159,7 +159,7 @@ void print_params(SDParams params) {
     printf("    diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
     printf("    high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
     printf("    vae_path: %s\n", params.vae_path.c_str());
-    printf("    taesd_path: %s\n", params.taesd_path.c_str());
+    printf("    tae_path: %s\n", params.tae_path.c_str());
     printf("    esrgan_path: %s\n", params.esrgan_path.c_str());
     printf("    control_net_path: %s\n", params.control_net_path.c_str());
     printf("    embedding_dir: %s\n", params.embedding_dir.c_str());
@@ -523,10 +523,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
          "--vae",
          "path to standalone vae model",
          &params.vae_path},
-        {"",
+        {"--tae",
          "--taesd",
-         "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
-         &params.taesd_path},
+         "path to taesd or taehv. Using Tiny AutoEncoder for fast decoding (low quality)",
+         &params.tae_path},
         {"",
          "--control-net",
          "path to control net model",
@@ -1638,7 +1638,7 @@ int main(int argc, const char* argv[]) {
                                 params.diffusion_model_path.c_str(),
                                 params.high_noise_diffusion_model_path.c_str(),
                                 params.vae_path.c_str(),
-                                params.taesd_path.c_str(),
+                                params.tae_path.c_str(),
                                 params.control_net_path.c_str(),
                                 params.lora_model_dir.c_str(),
                                 params.embedding_dir.c_str(),
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 9faba955a..ae588161f 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -346,8 +346,8 @@ class StableDiffusionGGML {
                                                              offload_params_to_cpu,
                                                              tensor_storage_map);
             diffusion_model = std::make_shared(backend,
-                                              offload_params_to_cpu,
-                                              tensor_storage_map);
+                                               offload_params_to_cpu,
+                                               tensor_storage_map);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;
             for (auto pair : tensor_storage_map) {
@@ -389,10 +389,10 @@ class StableDiffusionGGML {
                                                                 1,
                                                                 true);
             diffusion_model = std::make_shared(backend,
-                                              offload_params_to_cpu,
-                                              tensor_storage_map,
-                                              "model.diffusion_model",
-                                              version);
+                                               offload_params_to_cpu,
+                                               tensor_storage_map,
+                                               "model.diffusion_model",
+                                               version);
             if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
                 high_noise_diffusion_model = std::make_shared(backend,
                                                               offload_params_to_cpu,
@@ -418,10 +418,10 @@ class StableDiffusionGGML {
                                                                  "",
                                                                  enable_vision);
             diffusion_model = std::make_shared(backend,
-                                              offload_params_to_cpu,
-                                              tensor_storage_map,
-                                              "model.diffusion_model",
-                                              version);
+                                               offload_params_to_cpu,
+                                               tensor_storage_map,
+                                               "model.diffusion_model",
+                                               version);
         } else {  // SD1.x SD2.x SDXL
             if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
                 cond_stage_model = std::make_shared(clip_backend,
@@ -475,14 +475,27 @@ class StableDiffusionGGML {
             }
 
             if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
-                first_stage_model = std::make_shared(vae_backend,
-                                                     offload_params_to_cpu,
-                                                     tensor_storage_map,
-                                                     "first_stage_model",
-                                                     vae_decode_only,
-                                                     version);
-                first_stage_model->alloc_params_buffer();
-                first_stage_model->get_param_tensors(tensors, "first_stage_model");
+                if (!use_tiny_autoencoder) {
+                    first_stage_model = std::make_shared(vae_backend,
+                                                         offload_params_to_cpu,
+                                                         tensor_storage_map,
+                                                         "first_stage_model",
+                                                         vae_decode_only,
+                                                         version);
+                    first_stage_model->alloc_params_buffer();
+                    first_stage_model->get_param_tensors(tensors, "first_stage_model");
+                } else {
+                    tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
+                                                                             offload_params_to_cpu,
+                                                                             tensor_storage_map,
+                                                                             "decoder",
+                                                                             vae_decode_only,
+                                                                             version);
+                    if (sd_ctx_params->vae_conv_direct) {
+                        LOG_INFO("Using Conv2d direct in the tae model");
+                        tae_first_stage->set_conv2d_direct_enabled(true);
+                    }
+                }
             } else if (version == VERSION_CHROMA_RADIANCE) {
                 first_stage_model = std::make_shared(vae_backend,
                                                      offload_params_to_cpu);
@@ -510,12 +523,12 @@ class StableDiffusionGGML {
                 first_stage_model->alloc_params_buffer();
                 first_stage_model->get_param_tensors(tensors, "first_stage_model");
             } else {
-                tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
-                                                                    offload_params_to_cpu,
-                                                                    tensor_storage_map,
-                                                                    "decoder.layers",
-                                                                    vae_decode_only,
-                                                                    version);
+                tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
+                                                                         offload_params_to_cpu,
+                                                                         tensor_storage_map,
+                                                                         "decoder.layers",
+                                                                         vae_decode_only,
+                                                                         version);
                 if (sd_ctx_params->vae_conv_direct) {
                     LOG_INFO("Using Conv2d direct in the tae model");
                     tae_first_stage->set_conv2d_direct_enabled(true);
@@ -625,12 +638,15 @@ class StableDiffusionGGML {
             unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
         }
         size_t vae_params_mem_size = 0;
+        LOG_DEBUG("Here");
         if (!use_tiny_autoencoder) {
             vae_params_mem_size = first_stage_model->get_params_buffer_size();
         } else {
+            LOG_DEBUG("Here");
             if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
                 return false;
             }
+            LOG_DEBUG("Here");
             vae_params_mem_size = tae_first_stage->get_params_buffer_size();
         }
         size_t control_net_params_mem_size = 0;
@@ -1428,12 +1444,12 @@ class StableDiffusionGGML {
                 -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
                 0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
             latents_std_vec = {
-                 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
-                 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
-                 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
-                 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
-                 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
-                 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
+                0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
+                0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
+                0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
+                0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
+                0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
+                0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
         }
         for (int i = 0; i < latent->ne[3]; i++) {
             float mean = latents_mean_vec[i];
@@ -1474,12 +1490,12 @@ class StableDiffusionGGML {
                 -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
                 0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
             latents_std_vec = {
-                 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
-                 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
-                 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
-                 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
-                 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
-                 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
+                0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
+                0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
+                0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
+                0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
+                0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
+                0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
         }
         for (int i = 0; i < latent->ne[3]; i++) {
             float mean = latents_mean_vec[i];
@@ -1706,6 +1722,10 @@ class StableDiffusionGGML {
             first_stage_model->free_compute_buffer();
             process_vae_output_tensor(result);
         } else {
+            if (sd_version_is_wan(version)) {
+                x = ggml_permute(work_ctx, x, 0, 1, 3, 2);
+            }
+
             if (vae_tiling_params.enabled && !decode_video) {
                 // split latent in 64x64 tiles and compute in several steps
                 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
@@ -1716,6 +1736,7 @@ class StableDiffusionGGML {
                 tae_first_stage->compute(n_threads, x, true, &result);
             }
             tae_first_stage->free_compute_buffer();
+
         }
 
         int64_t t1 = ggml_time_ms();
@@ -3104,7 +3125,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
     struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
     int64_t t5 = ggml_time_ms();
     LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
-    if (sd_ctx->sd->free_params_immediately) {
+    if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
        sd_ctx->sd->first_stage_model->free_params_buffer();
     }
 
diff --git a/tae.hpp b/tae.hpp
index 14cdb578d..ad0bd373e 100644
--- a/tae.hpp
+++ b/tae.hpp
@@ -162,6 +162,227 @@ class TinyDecoder : public UnaryBlock {
     }
 };
 
+class TPool : public UnaryBlock {
+    int stride;
+
+public:
+    TPool(int channels, int stride) : stride(stride) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * stride, channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
+        auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
+        auto h = x;
+        if (stride != 1) {
+            h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] * stride, h->ne[3] / stride);
+        }
+        h = conv->forward(ctx, h);
+        return h;
+    }
+};
+
+class TGrow : public UnaryBlock {
+    int stride;
+
+public:
+    TGrow(int channels, int stride) : stride(stride) {
+        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels * stride, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
+        auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
+        auto h = conv->forward(ctx, x);
+        if (stride != 1) {
+            h = ggml_reshape_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2] / stride, h->ne[3] * stride);
+        }
+        return h;
+    }
+};
+
+class MemBlock : public GGMLBlock {
+    bool has_skip_conv = false;
+
+public:
+    MemBlock(int channels, int out_channels) : has_skip_conv(channels != out_channels) {
+        blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
+        if (has_skip_conv) {
+            blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+        }
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past) {
+        // x: [n, channels, h, w]
+        auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
+        auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
+        auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
+
+        auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
+        h = conv0->forward(ctx, h);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        h = conv1->forward(ctx, h);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        h = conv2->forward(ctx, h);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+
+        auto skip = x;
+        if (has_skip_conv) {
+            auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
+            skip = skip_conv->forward(ctx, x);
+        }
+        h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        return h;
+    }
+};
+
+class TinyVideoEncoder : public UnaryBlock {
+    int in_channels = 3;
+    int hidden = 64;
+    int z_channels = 4;
+    int num_blocks = 3;
+    int num_layers = 3;
+    int patch_size = 1;
+
+public:
+    TinyVideoEncoder(int z_channels = 4, int patch_size = 1)
+        : z_channels(z_channels), patch_size(patch_size) {
+        int index = 0;
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
+        index++;  // nn.ReLU()
+        for (int i = 0; i < num_layers; i++) {
+            int stride = i == num_layers - 1 ? 1 : 2;
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
+            for (int j = 0; j < num_blocks; j++) {
+                blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(hidden, hidden));
+            }
+        }
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1}));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
+        // return z;
+        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["0"]);
+        auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(num_layers * (num_blocks + 2) + 1)]);
+        auto h = first_conv->forward(ctx, z);
+
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+
+        for (int i = 2; i < num_layers * (num_blocks + 2) + 2; i++) {
+            if (blocks.find(std::to_string(i)) == blocks.end()) {
+                continue;
+            }
+            auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(i)]);
+            h = block->forward(ctx, h);
+        }
+        h = last_conv->forward(ctx, h);
+        return h;
+    }
+};
+
+class TinyVideoDecoder : public UnaryBlock {
+    int z_channels = 4;
+    int out_channels = 3;
+    int num_blocks = 3;
+    static const int num_layers = 3;
+    int channels[num_layers + 1] = {256, 128, 64, 64};
+
+public:
+    TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels) {
+        int index = 1;  // Clamp()
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
+        index++;  // nn.ReLU()
+        for (int i = 0; i < num_layers; i++) {
+            int stride = i == 0 ? 1 : 2;
+            for (int j = 0; j < num_blocks; j++) {
+                blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
+            }
+            index++;  // nn.Upsample()
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
+            LOG_DEBUG("Create Conv2d %d shape = %d %d", index, channels[i], channels[i + 1]);
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
+        }
+        index++;  // nn.ReLU()
+        blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[num_layers], out_channels * patch_size * patch_size, {3, 3}, {1, 1}, {1, 1}));
+    }
+
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
+        auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
+
+        // Clamp()
+        auto h = ggml_scale_inplace(ctx->ggml_ctx,
+                                    ggml_tanh_inplace(ctx->ggml_ctx,
+                                                      ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
+                                    3.0f);
+
+        h = first_conv->forward(ctx, h);
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+        int index = 3;
+        for (int i = 0; i < num_layers; i++) {
+            for (int j = 0; j < num_blocks; j++) {
+                auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
+                auto mem = ggml_pad(ctx->ggml_ctx, h, 0, 0, 0, 1);
+                mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
+                h = block->forward(ctx, h, mem);
+            }
+            // upsample
+            index++;
+            h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST);
+            auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+            h = block->forward(ctx, h);
+            block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
+            h = block->forward(ctx, h);
+        }
+        h = ggml_relu_inplace(ctx->ggml_ctx, h);
+
+        auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(++index)]);
+        h = last_conv->forward(ctx, h);
+
+        // shape(W, H, 3, T+3) => shape(W, H, 3, T)
+        h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 0);
+        return h;
+    }
+};
+
+class TAEHV : public GGMLBlock {
+protected:
+    bool decode_only;
+    SDVersion version;
+
+public:
+    TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
+        : decode_only(decode_only), version(version) {
+        int z_channels = 16;
+        int patch = 1;
+        if (version == VERSION_WAN2_2_TI2V) {
+            z_channels = 48;
+            patch = 2;
+        }
+        blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch));
+        if (!decode_only) {
+            blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch));
+        }
+    }
+
+    struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
+        auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
+        auto result = decoder->forward(ctx, z);
+        if (sd_version_is_wan(version)) {
+            // (W, H, C, T) -> (W, H, T, C)
+            result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
+        }
+        return result;
+    }
+
+    struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+        return nullptr;
+        // auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
+        // return encoder->forward(ctx, x);
+    }
+};
+
 class TAESD : public GGMLBlock {
 protected:
     bool decode_only;
@@ -192,18 +413,30 @@ class TAESD : public GGMLBlock {
 };
 
 struct TinyAutoEncoder : public GGMLRunner {
+    TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu)
+        : GGMLRunner(backend, offload_params_to_cpu) {}
+
+    virtual void compute(const int n_threads,
+                         struct ggml_tensor* z,
+                         bool decode_graph,
+                         struct ggml_tensor** output,
+                         struct ggml_context* output_ctx = nullptr) = 0;
+
+    virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
+};
+
+struct TinyImageAutoEncoder : public TinyAutoEncoder {
     TAESD taesd;
     bool decode_only = false;
 
-    TinyAutoEncoder(ggml_backend_t backend,
-                    bool offload_params_to_cpu,
-                    const String2TensorStorage& tensor_storage_map,
-                    const std::string prefix,
-                    bool decoder_only = true,
-                    SDVersion version = VERSION_SD1)
+    TinyImageAutoEncoder(ggml_backend_t backend,
+                         bool offload_params_to_cpu,
+                         const String2TensorStorage& tensor_storage_map,
+                         const std::string prefix,
+                         bool decoder_only = true,
+                         SDVersion version = VERSION_SD1)
         : decode_only(decoder_only),
          taesd(decoder_only, version),
-          GGMLRunner(backend, offload_params_to_cpu) {
+          TinyAutoEncoder(backend, offload_params_to_cpu) {
         taesd.init(params_ctx, tensor_storage_map, prefix);
     }
 
@@ -260,4 +493,73 @@ struct TinyAutoEncoder : public GGMLRunner {
     }
 };
 
+struct TinyVideoAutoEncoder : public TinyAutoEncoder {
+    TAEHV taehv;
+    bool decode_only = false;
+
+    TinyVideoAutoEncoder(ggml_backend_t backend,
+                         bool offload_params_to_cpu,
+                         const String2TensorStorage& tensor_storage_map,
+                         const std::string prefix,
+                         bool decoder_only = true,
+                         SDVersion version = VERSION_WAN2)
+        : decode_only(decoder_only),
+          taehv(decoder_only, version),
+          TinyAutoEncoder(backend, offload_params_to_cpu) {
+        taehv.init(params_ctx, tensor_storage_map, prefix);
+    }
+
+    std::string get_desc() override {
+        return "taehv";
+    }
+
+    bool load_from_file(const std::string& file_path, int n_threads) {
+        LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
+        alloc_params_buffer();
+        std::map<std::string, struct ggml_tensor*> taehv_tensors;
+        taehv.get_param_tensors(taehv_tensors);
+        std::set<std::string> ignore_tensors;
+        if (decode_only) {
+            ignore_tensors.insert("encoder.");
+        }
+
+        ModelLoader model_loader;
+        if (!model_loader.init_from_file(file_path)) {
+            LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str());
+            return false;
+        }
+
+        bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads);
+
+        if (!success) {
+            LOG_ERROR("load tae tensors from model loader failed");
+            return false;
+        }
+
+        LOG_INFO("taehv model loaded");
+        return success;
+    }
+
+    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
+        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
+        z = to_backend(z);
+        auto runner_ctx = get_context();
+        struct ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
+        ggml_build_forward_expand(gf, out);
+        return gf;
+    }
+
+    void compute(const int n_threads,
+                 struct ggml_tensor* z,
+                 bool decode_graph,
+                 struct ggml_tensor** output,
+                 struct ggml_context* output_ctx = nullptr) {
+        auto get_graph = [&]() -> struct ggml_cgraph* {
+            return build_graph(z, decode_graph);
+        };
+
+        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
+    }
+};
+
 #endif  // __TAE_HPP__
\ No newline at end of file
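Note on the decoder layout in the patch above: TinyVideoDecoder and TinyVideoEncoder key their sub-modules by stringified sequential indices and deliberately skip the indices that belong to parameter-free layers of the reference TAEHV graph (Clamp, nn.ReLU, nn.Upsample), so the generated tensor names still line up with the original checkpoint. The standalone sketch below is illustrative only and not part of the patch; the description strings are hypothetical, and num_layers/num_blocks simply mirror the values used in tae.hpp. It prints the index layout that the TinyVideoDecoder constructor logic produces.

// Illustrative sketch only (not part of the patch): reproduces the numeric
// block-key layout used by TinyVideoDecoder. Keys for parameter-free layers
// (Clamp, nn.ReLU, nn.Upsample) are skipped so the remaining keys line up
// with the checkpoint tensor names. Description strings are hypothetical.
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
    const int num_layers = 3;
    const int num_blocks = 3;
    // (key, layer description) pairs in construction order
    std::vector<std::pair<std::string, std::string>> blocks;

    int index = 1;  // index 0 would be Clamp(), which holds no parameters
    blocks.emplace_back(std::to_string(index++), "Conv2d(z_channels -> 256)");
    index++;  // nn.ReLU(), no parameters
    for (int i = 0; i < num_layers; i++) {
        for (int j = 0; j < num_blocks; j++) {
            blocks.emplace_back(std::to_string(index++), "MemBlock");
        }
        index++;  // nn.Upsample(), no parameters
        blocks.emplace_back(std::to_string(index++), "TGrow");
        blocks.emplace_back(std::to_string(index++), "Conv2d(channels[i] -> channels[i + 1])");
    }
    index++;  // nn.ReLU(), no parameters
    blocks.emplace_back(std::to_string(index++), "Conv2d(64 -> out_channels)");

    // Keys 0, 2, 6, ... never appear: no tensors are looked up for them.
    for (const auto& kv : blocks) {
        std::printf("blocks[\"%s\"] = %s\n", kv.first.c_str(), kv.second.c_str());
    }
    return 0;
}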