diff --git a/conditioner.hpp b/conditioner.hpp
index cfd2b4ca7..0f8edac1d 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -196,27 +196,91 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
     }
 
     std::vector<int> convert_token_to_id(std::string text) {
+        size_t search_pos = 0;
         auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
-            size_t word_end       = str.find(",");
-            std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
-            embd_name             = trim(embd_name);
-            std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
-            if (embd_path.size() == 0) {
-                embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
+            std::string token_str;
+            size_t consumed_len = 0;
+            bool is_embed_tag   = false;
+
+            // The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
+            std::string trimmed_str = trim(str);
+            size_t leading_spaces   = str.length() - trimmed_str.length();
+
+            if (starts_with(trimmed_str, "<embd:")) {
+                size_t tag_end = trimmed_str.find(">");
+                if (tag_end == std::string::npos) {
+                    return false;  // Incomplete tag.
+                }
+                std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
+                token_str             = lower_tag;  // Fallback to lowercased version
+
+                if (text.length() >= lower_tag.length()) {
+                    for (size_t i = search_pos; i <= text.length() - lower_tag.length(); ++i) {
+                        bool match = true;
+                        for (size_t j = 0; j < lower_tag.length(); ++j) {
+                            if (std::tolower(text[i + j]) != lower_tag[j]) {
+                                match = false;
+                                break;
+                            }
+                        }
+                        if (match) {
+                            token_str  = text.substr(i, lower_tag.length());
+                            search_pos = i + token_str.length();
+                            break;
+                        }
+                    }
+                }
+                consumed_len = leading_spaces + token_str.length();
+                is_embed_tag = true;
+            } else {
+                // Not a tag. Could be a plain trigger word.
+                size_t first_delim = trimmed_str.find_first_of(" ,");
+                token_str          = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
+                consumed_len       = leading_spaces + token_str.length();
+            }
+
+            std::string embd_name = trim(token_str);
+            if (is_embed_tag) {
+                embd_name = embd_name.substr(strlen("<embd:"), embd_name.length() - strlen("<embd:") - 1);
+            }
+
+            std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
+            if (embd_path.size() == 0) {
+                embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
             }
             if (embd_path.size() == 0) {
                 embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
             }
             if (embd_path.size() > 0) {
                 if (load_embedding(embd_name, embd_path, bpe_tokens)) {
-                    if (word_end != std::string::npos) {
-                        str = str.substr(word_end);
-                    } else {
-                        str = "";
-                    }
+                    str = str.substr(consumed_len);
                     return true;
                 }
             }
+
+            if (is_embed_tag) {
+                LOG_WARN("could not load embedding '%s'", embd_name.c_str());
+                str = str.substr(consumed_len);
+                return true;  // Consume the failed tag so the tokenizer doesn't try to parse it as text.
+            }
+
+            // It was not a tag and we couldn't find a file for it as a trigger word.
             return false;
         };
         std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb);
@@ -245,30 +309,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
             LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
         }
 
-        auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
-            size_t word_end       = str.find(",");
-            std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
-            embd_name             = trim(embd_name);
-            std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
-            if (embd_path.size() == 0) {
-                embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
-            }
-            if (embd_path.size() == 0) {
-                embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
-            }
-            if (embd_path.size() > 0) {
-                if (load_embedding(embd_name, embd_path, bpe_tokens)) {
-                    if (word_end != std::string::npos) {
-                        str = str.substr(word_end);
-                    } else {
-                        str = "";
-                    }
-                    return true;
-                }
-            }
-            return false;
-        };
-
         std::vector<int> tokens;
         std::vector<float> weights;
         std::vector<bool> class_token_mask;
@@ -278,6 +318,93 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
             std::vector<int> clean_input_ids;
             const std::string& curr_text = item.first;
             float curr_weight            = item.second;
+            size_t search_pos = 0;
+            auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
+                std::string token_str;
+                size_t consumed_len = 0;
+                bool is_embed_tag   = false;
+
+                // The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
+                std::string trimmed_str = trim(str);
+                size_t leading_spaces   = str.length() - trimmed_str.length();
+
+                if (starts_with(trimmed_str, "<embd:")) {
+                    size_t tag_end = trimmed_str.find(">");
+                    if (tag_end == std::string::npos) {
+                        return false;  // Incomplete tag.
+                    }
+                    std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
+                    token_str             = lower_tag;  // Fallback to lowercased version
+
+                    if (curr_text.length() >= lower_tag.length()) {
+                        for (size_t i = search_pos; i <= curr_text.length() - lower_tag.length(); ++i) {
+                            bool match = true;
+                            for (size_t j = 0; j < lower_tag.length(); ++j) {
+                                if (std::tolower(curr_text[i + j]) != lower_tag[j]) {
+                                    match = false;
+                                    break;
+                                }
+                            }
+                            if (match) {
+                                token_str  = curr_text.substr(i, lower_tag.length());
+                                search_pos = i + token_str.length();
+                                break;
+                            }
+                        }
+                    }
+                    consumed_len = leading_spaces + token_str.length();
+                    is_embed_tag = true;
+                } else {
+                    // Not a tag. Could be a plain trigger word.
+                    size_t first_delim = trimmed_str.find_first_of(" ,");
+                    token_str          = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
+                    consumed_len       = leading_spaces + token_str.length();
+                }
+
+                std::string embd_name = trim(token_str);
+                if (is_embed_tag) {
+                    embd_name = embd_name.substr(strlen("<embd:"), embd_name.length() - strlen("<embd:") - 1);
+                }
+
+                std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
+                if (embd_path.size() == 0) {
+                    embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
+                }
+                if (embd_path.size() == 0) {
+                    embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
+                }
+
+                if (embd_path.size() > 0) {
+                    if (load_embedding(embd_name, embd_path, bpe_tokens)) {
+                        str = str.substr(consumed_len);
+                        return true;
+                    }
+                }
+
+                if (is_embed_tag) {
+                    LOG_WARN("could not load embedding '%s'", embd_name.c_str());
+                    str = str.substr(consumed_len);
+                    return true;  // Consume the failed tag so the tokenizer doesn't try to parse it as text.
+                }
+
+                // It was not a tag and we couldn't find a file for it as a trigger word.
+                return false;
+            };
             // printf(" %s: %f \n", curr_text.c_str(), curr_weight);
             std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
             int32_t clean_index = 0;
@@ -359,35 +486,98 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
             LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
         }
 
-        auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
-            size_t word_end       = str.find(",");
-            std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
-            embd_name             = trim(embd_name);
-            std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
-            if (embd_path.size() == 0) {
-                embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
-            }
-            if (embd_path.size() == 0) {
-                embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
-            }
-            if (embd_path.size() > 0) {
-                if (load_embedding(embd_name, embd_path, bpe_tokens)) {
-                    if (word_end != std::string::npos) {
-                        str = str.substr(word_end);
-                    } else {
-                        str = "";
-                    }
-                    return true;
-                }
-            }
-            return false;
-        };
-
         std::vector<int> tokens;
         std::vector<float> weights;
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight            = item.second;
+            size_t search_pos = 0;
+            auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
+                std::string token_str;
+                size_t consumed_len = 0;
+                bool is_embed_tag   = false;
+
+                // The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
+                std::string trimmed_str = trim(str);
+                size_t leading_spaces   = str.length() - trimmed_str.length();
+
+                if (starts_with(trimmed_str, "<embd:")) {
+                    size_t tag_end = trimmed_str.find(">");
+                    if (tag_end == std::string::npos) {
+                        return false;  // Incomplete tag.
+                    }
+                    std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
+                    token_str             = lower_tag;  // Fallback to lowercased version
+
+                    if (curr_text.length() >= lower_tag.length()) {
+                        for (size_t i = search_pos; i <= curr_text.length() - lower_tag.length(); ++i) {
+                            bool match = true;
+                            for (size_t j = 0; j < lower_tag.length(); ++j) {
+                                if (std::tolower(curr_text[i + j]) != lower_tag[j]) {
+                                    match = false;
+                                    break;
+                                }
+                            }
+                            if (match) {
+                                token_str  = curr_text.substr(i, lower_tag.length());
+                                search_pos = i + token_str.length();
+                                break;
+                            }
+                        }
+                    }
+                    consumed_len = leading_spaces + token_str.length();
+                    is_embed_tag = true;
+                } else {
+                    // Not a tag. Could be a plain trigger word.
+                    size_t first_delim = trimmed_str.find_first_of(" ,");
+                    token_str          = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
+                    consumed_len       = leading_spaces + token_str.length();
+                }
+
+                std::string embd_name = trim(token_str);
+                if (is_embed_tag) {
+                    embd_name = embd_name.substr(strlen("<embd:"), embd_name.length() - strlen("<embd:") - 1);
+                }
+
+                std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
+                if (embd_path.size() == 0) {
+                    embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
+                }
+                if (embd_path.size() == 0) {
+                    embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
+                }
+
+                if (embd_path.size() > 0) {
+                    if (load_embedding(embd_name, embd_path, bpe_tokens)) {
+                        str = str.substr(consumed_len);
+                        return true;
+                    }
+                }
+
+                if (is_embed_tag) {
+                    LOG_WARN("could not load embedding '%s'", embd_name.c_str());
+                    str = str.substr(consumed_len);
+                    return true;  // Consume the failed tag so the tokenizer doesn't try to parse it as text.
+                }
+
+                // It was not a tag and we couldn't find a file for it as a trigger word.
+                return false;
+            };
             std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
             tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
             weights.insert(weights.end(), curr_tokens.size(), curr_weight);
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index db4e07cb0..4af3d6ab3 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -810,17 +810,34 @@ class StableDiffusionGGML {
                 is_high_noise = true;
                 LOG_DEBUG("high noise lora: %s", lora_name.c_str());
             }
-            std::string st_file_path   = path_join(lora_model_dir, lora_name + ".safetensors");
-            std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
+            std::string st_file_path;
+            std::string ckpt_file_path;
             std::string file_path;
-            if (file_exists(st_file_path)) {
+            bool is_path = contains(lora_name, "/") || contains(lora_name, "\\");
+
+            if (is_path) {
+                st_file_path   = lora_name + ".safetensors";
+                ckpt_file_path = lora_name + ".ckpt";
+            } else {
+                st_file_path   = path_join(lora_model_dir, lora_name + ".safetensors");
+                ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
+            }
+
+            if (is_path && file_exists(lora_name)) {
+                file_path = lora_name;
+            } else if (file_exists(st_file_path)) {
                 file_path = st_file_path;
             } else if (file_exists(ckpt_file_path)) {
                 file_path = ckpt_file_path;
             } else {
-                LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
+                if (is_path) {
+                    LOG_WARN("can not find lora file %s, %s or %s", lora_name.c_str(), st_file_path.c_str(), ckpt_file_path.c_str());
+                } else {
+                    LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
+                }
                 return;
             }
+
             LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "");
             if (!lora.load_from_file()) {
                 LOG_WARN("load lora tensors from %s failed", file_path.c_str());