From 9ac48e80484c484a9a5c972eb266505bbbdaca0b Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Mon, 3 Nov 2025 20:22:45 -0300 Subject: [PATCH] feat: support --tensor-type-rules on generation modes --- examples/cli/main.cpp | 5 +-- model.cpp | 87 ++++++++++++++++++++++++------------------- model.h | 2 +- stable-diffusion.cpp | 7 +++- stable-diffusion.h | 1 + 5 files changed, 56 insertions(+), 46 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 8f938c9b4..a7003b430 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1163,10 +1163,6 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } - if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) { - fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n"); - } - if (params.mode == VID_GEN && params.video_frames <= 0) { fprintf(stderr, "warning: --video-frames must be at least 1\n"); exit(1); @@ -1643,6 +1639,7 @@ int main(int argc, const char* argv[]) { params.lora_model_dir.c_str(), params.embedding_dir.c_str(), params.photo_maker_path.c_str(), + params.tensor_type_rules.c_str(), vae_decode_only, true, params.n_threads, diff --git a/model.cpp b/model.cpp index cec696632..6660bb9fc 100644 --- a/model.cpp +++ b/model.cpp @@ -1877,15 +1877,59 @@ std::map ModelLoader::get_vae_wtype_stat() { return wtype_stat; } -void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) { +static std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { + std::vector> result; + for (const auto& item : split_string(tensor_type_rules, ',')) { + if (item.size() == 0) + continue; + std::string::size_type pos = item.find('='); + if (pos == std::string::npos) { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + continue; + } + std::string tensor_pattern = item.substr(0, pos); + std::string type_name = item.substr(pos + 1); + + ggml_type tensor_type = GGML_TYPE_COUNT; + + if (type_name == "f32") { + tensor_type = GGML_TYPE_F32; + } else { + for (size_t i = 0; i < GGML_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + if (trait->to_float && trait->type_size && type_name == trait->type_name) { + tensor_type = (ggml_type)i; + } + } + } + + if (tensor_type != GGML_TYPE_COUNT) { + result.emplace_back(tensor_pattern, tensor_type); + } else { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + } + } + return result; +} + +void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) { + auto map_rules = parse_tensor_type_rules(tensor_type_rules); for (auto& [name, tensor_storage] : tensor_storage_map) { - if (!starts_with(name, prefix)) { + ggml_type dst_type = wtype; + for (const auto& tensor_type_rule : map_rules) { + std::regex pattern(tensor_type_rule.first); + if (std::regex_search(name, pattern)) { + dst_type = tensor_type_rule.second; + break; + } + } + if (dst_type == GGML_TYPE_COUNT) { continue; } - if (!tensor_should_be_converted(tensor_storage, wtype)) { + if (!tensor_should_be_converted(tensor_storage, dst_type)) { continue; } - tensor_storage.expected_type = wtype; + tensor_storage.expected_type = dst_type; } } @@ -2226,41 +2270,6 @@ bool ModelLoader::load_tensors(std::map& tenso return true; } -std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { - std::vector> result; - for (const auto& item : split_string(tensor_type_rules, ',')) { - if (item.size() == 0) - continue; - std::string::size_type pos = item.find('='); - if (pos == std::string::npos) { - LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); - continue; - } - std::string tensor_pattern = item.substr(0, pos); - std::string type_name = item.substr(pos + 1); - - ggml_type tensor_type = GGML_TYPE_COUNT; - - if (type_name == "f32") { - tensor_type = GGML_TYPE_F32; - } else { - for (size_t i = 0; i < GGML_TYPE_COUNT; i++) { - auto trait = ggml_get_type_traits((ggml_type)i); - if (trait->to_float && trait->type_size && type_name == trait->type_name) { - tensor_type = (ggml_type)i; - } - } - } - - if (tensor_type != GGML_TYPE_COUNT) { - result.emplace_back(tensor_pattern, tensor_type); - } else { - LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); - } - } - return result; -} - bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) { const std::string& name = tensor_storage.name; if (type != GGML_TYPE_COUNT) { diff --git a/model.h b/model.h index a29160cf0..30fbd3e23 100644 --- a/model.h +++ b/model.h @@ -281,7 +281,7 @@ class ModelLoader { std::map get_diffusion_model_wtype_stat(); std::map get_vae_wtype_stat(); String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; } - void set_wtype_override(ggml_type wtype, std::string prefix = ""); + void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = ""); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0); bool load_tensors(std::map& tensors, std::set ignore_tensors = {}, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 9faba955a..ac5000b5a 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -286,8 +286,9 @@ class StableDiffusionGGML { ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) ? (ggml_type)sd_ctx_params->wtype : GGML_TYPE_COUNT; - if (wtype != GGML_TYPE_COUNT) { - model_loader.set_wtype_override(wtype); + std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules); + if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) { + model_loader.set_wtype_override(wtype, tensor_type_rules); } std::map wtype_stat = model_loader.get_wtype_stat(); @@ -1893,6 +1894,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "lora_model_dir: %s\n" "embedding_dir: %s\n" "photo_maker_path: %s\n" + "tensor_type_rules: %s\n" "vae_decode_only: %s\n" "free_params_immediately: %s\n" "n_threads: %d\n" @@ -1922,6 +1924,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->lora_model_dir), SAFE_STR(sd_ctx_params->embedding_dir), SAFE_STR(sd_ctx_params->photo_maker_path), + SAFE_STR(sd_ctx_params->tensor_type_rules), BOOL_STR(sd_ctx_params->vae_decode_only), BOOL_STR(sd_ctx_params->free_params_immediately), sd_ctx_params->n_threads, diff --git a/stable-diffusion.h b/stable-diffusion.h index f618d457b..397f838f8 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -151,6 +151,7 @@ typedef struct { const char* lora_model_dir; const char* embedding_dir; const char* photo_maker_path; + const char* tensor_type_rules; bool vae_decode_only; bool free_params_immediately; int n_threads;