
Commit 655ee12

feat: support for --tensor-type-rules on generation modes
1 parent 6103d86 commit 655ee12

5 files changed: +58 -46 lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 4 deletions
@@ -1163,10 +1163,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         exit(1);
     }
 
-    if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
-        fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
-    }
-
     if (params.mode == VID_GEN && params.video_frames <= 0) {
         fprintf(stderr, "warning: --video-frames must be at least 1\n");
         exit(1);
@@ -1643,6 +1639,7 @@ int main(int argc, const char* argv[]) {
                               params.lora_model_dir.c_str(),
                               params.embedding_dir.c_str(),
                               params.photo_maker_path.c_str(),
+                              params.tensor_type_rules.c_str(),
                               vae_decode_only,
                               true,
                               params.n_threads,

model.cpp

Lines changed: 50 additions & 39 deletions
@@ -1992,17 +1992,63 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
     return wtype_stat;
 }
 
-void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
+static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name      = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
+void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
+    auto map_rules = parse_tensor_type_rules(tensor_type_rules);
     for (auto& pair : tensor_storages_types) {
-        if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
+        ggml_type dst_type = wtype;
+
+        for (const auto& tensor_type_rule : map_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(pair.first, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+
+        if (dst_type != GGML_TYPE_COUNT) {
             bool found = false;
             for (auto& tensor_storage : tensor_storages) {
                 std::map<std::string, ggml_type> temp;
                 add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
                 for (auto& preprocessed_name : temp) {
                     if (preprocessed_name.first == pair.first) {
-                        if (tensor_should_be_converted(tensor_storage, wtype)) {
-                            pair.second = wtype;
+                        if (tensor_should_be_converted(tensor_storage, dst_type)) {
+                            pair.second = dst_type;
                         }
                         found = true;
                         break;
@@ -2449,41 +2495,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
     return true;
 }
 
-std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
-    std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : split_string(tensor_type_rules, ',')) {
-        if (item.size() == 0)
-            continue;
-        std::string::size_type pos = item.find('=');
-        if (pos == std::string::npos) {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-            continue;
-        }
-        std::string tensor_pattern = item.substr(0, pos);
-        std::string type_name      = item.substr(pos + 1);
-
-        ggml_type tensor_type = GGML_TYPE_COUNT;
-
-        if (type_name == "f32") {
-            tensor_type = GGML_TYPE_F32;
-        } else {
-            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
-                auto trait = ggml_get_type_traits((ggml_type)i);
-                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
-                    tensor_type = (ggml_type)i;
-                }
-            }
-        }
-
-        if (tensor_type != GGML_TYPE_COUNT) {
-            result.emplace_back(tensor_pattern, tensor_type);
-        } else {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-        }
-    }
-    return result;
-}
-
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {
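
For orientation, here is a minimal, self-contained sketch of how the rule string handled above is interpreted: a comma-separated list of pattern=type entries, where each pattern is treated as a regex and tested against tensor names with std::regex_search, the first matching rule wins, and tensors that match no rule fall back to the global wtype (which may be GGML_TYPE_COUNT, i.e. left unchanged). The sketch substitutes std::getline and plain type-name strings for the real split_string/ggml_get_type_traits machinery, and the rule string and tensor names are made-up examples, not taken from the commit.

// Illustrative sketch only (not part of the commit): mirrors the rule-string
// handling in ModelLoader::set_wtype_override, with ggml types as plain strings.
#include <cstdio>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// "pattern=type,pattern=type,..." -> (pattern, type) pairs; entries without '='
// are skipped (the real parse_tensor_type_rules also validates the type name).
static std::vector<std::pair<std::string, std::string>> parse_rules(const std::string& rules) {
    std::vector<std::pair<std::string, std::string>> result;
    std::stringstream ss(rules);
    std::string item;
    while (std::getline(ss, item, ',')) {
        std::string::size_type pos = item.find('=');
        if (item.empty() || pos == std::string::npos) {
            continue;
        }
        result.emplace_back(item.substr(0, pos), item.substr(pos + 1));
    }
    return result;
}

int main() {
    // Hypothetical rule string and tensor names, purely for illustration.
    std::string rules = "vae\\.=f16,model\\.diffusion_model\\.=q8_0";
    std::vector<std::string> names = {
        "model.diffusion_model.input_blocks.0.0.weight",
        "vae.decoder.conv_in.weight",
        "cond_stage_model.transformer.text_model.token_embedding.weight",
    };

    auto parsed = parse_rules(rules);
    for (const auto& name : names) {
        std::string dst = "(unchanged / global wtype)";
        for (const auto& rule : parsed) {
            // Each rule pattern is matched as a regex against the tensor name.
            if (std::regex_search(name, std::regex(rule.first))) {
                dst = rule.second;  // first matching rule wins
                break;
            }
        }
        std::printf("%-60s -> %s\n", name.c_str(), dst.c_str());
    }
    return 0;
}

Because the rule loop breaks on the first match, more specific patterns should be listed before broader ones in the rule string.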

model.h

Lines changed: 1 addition & 1 deletion
@@ -271,7 +271,7 @@ class ModelLoader {
     std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
     std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
     std::map<ggml_type, uint32_t> get_vae_wtype_stat();
-    void set_wtype_override(ggml_type wtype, std::string prefix = "");
+    void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
     bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
     bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                       std::set<std::string> ignore_tensors = {},

stable-diffusion.cpp

Lines changed: 5 additions & 2 deletions
@@ -286,8 +286,9 @@ class StableDiffusionGGML {
         ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
                               ? (ggml_type)sd_ctx_params->wtype
                               : GGML_TYPE_COUNT;
-        if (wtype != GGML_TYPE_COUNT) {
-            model_loader.set_wtype_override(wtype);
+        std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
+        if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
+            model_loader.set_wtype_override(wtype, tensor_type_rules);
         }
 
         std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
@@ -1894,6 +1895,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                          "lora_model_dir: %s\n"
                          "embedding_dir: %s\n"
                          "photo_maker_path: %s\n"
+                         "tensor_type_rules: %s\n"
                          "vae_decode_only: %s\n"
                          "free_params_immediately: %s\n"
                          "n_threads: %d\n"
@@ -1923,6 +1925,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                          SAFE_STR(sd_ctx_params->lora_model_dir),
                          SAFE_STR(sd_ctx_params->embedding_dir),
                          SAFE_STR(sd_ctx_params->photo_maker_path),
+                         SAFE_STR(sd_ctx_params->tensor_type_rules),
                          BOOL_STR(sd_ctx_params->vae_decode_only),
                          BOOL_STR(sd_ctx_params->free_params_immediately),
                          sd_ctx_params->n_threads,

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ typedef struct {
     const char* lora_model_dir;
     const char* embedding_dir;
     const char* photo_maker_path;
+    const char* tensor_type_rules;
     bool vae_decode_only;
     bool free_params_immediately;
     int n_threads;
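
On the library side, callers opt in by filling the new field on sd_ctx_params_t before creating the context. A minimal sketch under stated assumptions: only the fields visible in this header diff are shown, everything else (including the context-creation call) is assumed to be set up as before, and the rule string is a made-up example.

#include "stable-diffusion.h"

// Sketch only (not from the commit): populate the new tensor_type_rules field
// next to its neighbours from this header; remaining fields are elided.
static sd_ctx_params_t make_ctx_params() {
    sd_ctx_params_t p = {};  // assume the rest of the struct is initialized as usual
    p.photo_maker_path  = "";
    p.tensor_type_rules = "vae\\.=f16,model\\.diffusion_model\\.=q8_0";  // hypothetical rules
    p.vae_decode_only   = false;
    p.n_threads         = 4;
    return p;
}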
