@@ -1992,17 +1992,63 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
19921992 return wtype_stat;
19931993}
19941994
1995- void ModelLoader::set_wtype_override (ggml_type wtype, std::string prefix) {
1995+ static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules (const std::string& tensor_type_rules) {
1996+ std::vector<std::pair<std::string, ggml_type>> result;
1997+ for (const auto & item : split_string (tensor_type_rules, ' ,' )) {
1998+ if (item.size () == 0 )
1999+ continue ;
2000+ std::string::size_type pos = item.find (' =' );
2001+ if (pos == std::string::npos) {
2002+ LOG_WARN (" ignoring invalid quant override \" %s\" " , item.c_str ());
2003+ continue ;
2004+ }
2005+ std::string tensor_pattern = item.substr (0 , pos);
2006+ std::string type_name = item.substr (pos + 1 );
2007+
2008+ ggml_type tensor_type = GGML_TYPE_COUNT;
2009+
2010+ if (type_name == " f32" ) {
2011+ tensor_type = GGML_TYPE_F32;
2012+ } else {
2013+ for (size_t i = 0 ; i < GGML_TYPE_COUNT; i++) {
2014+ auto trait = ggml_get_type_traits ((ggml_type)i);
2015+ if (trait->to_float && trait->type_size && type_name == trait->type_name ) {
2016+ tensor_type = (ggml_type)i;
2017+ }
2018+ }
2019+ }
2020+
2021+ if (tensor_type != GGML_TYPE_COUNT) {
2022+ result.emplace_back (tensor_pattern, tensor_type);
2023+ } else {
2024+ LOG_WARN (" ignoring invalid quant override \" %s\" " , item.c_str ());
2025+ }
2026+ }
2027+ return result;
2028+ }
2029+
2030+ void ModelLoader::set_wtype_override (ggml_type wtype, std::string tensor_type_rules) {
2031+ auto map_rules = parse_tensor_type_rules (tensor_type_rules);
19962032 for (auto & pair : tensor_storages_types) {
1997- if (prefix.size () < 1 || pair.first .substr (0 , prefix.size ()) == prefix) {
2033+ ggml_type dst_type = wtype;
2034+
2035+ for (const auto & tensor_type_rule : map_rules) {
2036+ std::regex pattern (tensor_type_rule.first );
2037+ if (std::regex_search (pair.first , pattern)) {
2038+ dst_type = tensor_type_rule.second ;
2039+ break ;
2040+ }
2041+ }
2042+
2043+ if (dst_type != GGML_TYPE_COUNT) {
19982044 bool found = false ;
19992045 for (auto & tensor_storage : tensor_storages) {
20002046 std::map<std::string, ggml_type> temp;
20012047 add_preprocess_tensor_storage_types (temp, tensor_storage.name , tensor_storage.type );
20022048 for (auto & preprocessed_name : temp) {
20032049 if (preprocessed_name.first == pair.first ) {
2004- if (tensor_should_be_converted (tensor_storage, wtype )) {
2005- pair.second = wtype ;
2050+ if (tensor_should_be_converted (tensor_storage, dst_type )) {
2051+ pair.second = dst_type ;
20062052 }
20072053 found = true ;
20082054 break ;
@@ -2449,41 +2495,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
24492495 return true ;
24502496}
24512497
2452- std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules (const std::string& tensor_type_rules) {
2453- std::vector<std::pair<std::string, ggml_type>> result;
2454- for (const auto & item : split_string (tensor_type_rules, ' ,' )) {
2455- if (item.size () == 0 )
2456- continue ;
2457- std::string::size_type pos = item.find (' =' );
2458- if (pos == std::string::npos) {
2459- LOG_WARN (" ignoring invalid quant override \" %s\" " , item.c_str ());
2460- continue ;
2461- }
2462- std::string tensor_pattern = item.substr (0 , pos);
2463- std::string type_name = item.substr (pos + 1 );
2464-
2465- ggml_type tensor_type = GGML_TYPE_COUNT;
2466-
2467- if (type_name == " f32" ) {
2468- tensor_type = GGML_TYPE_F32;
2469- } else {
2470- for (size_t i = 0 ; i < GGML_TYPE_COUNT; i++) {
2471- auto trait = ggml_get_type_traits ((ggml_type)i);
2472- if (trait->to_float && trait->type_size && type_name == trait->type_name ) {
2473- tensor_type = (ggml_type)i;
2474- }
2475- }
2476- }
2477-
2478- if (tensor_type != GGML_TYPE_COUNT) {
2479- result.emplace_back (tensor_pattern, tensor_type);
2480- } else {
2481- LOG_WARN (" ignoring invalid quant override \" %s\" " , item.c_str ());
2482- }
2483- }
2484- return result;
2485- }
2486-
24872498bool ModelLoader::tensor_should_be_converted (const TensorStorage& tensor_storage, ggml_type type) {
24882499 const std::string& name = tensor_storage.name ;
24892500 if (type != GGML_TYPE_COUNT) {
0 commit comments