Commit 8f6c5c2

refactor: simplify the model loading logic (#933)
* remove String2GGMLType
* remove preprocess_tensor
* fix clip init
* simplify the logic for reading weights
1 parent 6103d86 · commit 8f6c5c2

21 files changed: +533 / -621 lines
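The core of the refactor: blocks no longer receive a bare name → ggml type map (`String2GGMLType`); they receive the loader's name → `TensorStorage` map, so a model can look at what the checkpoint actually contains before allocating its params. A minimal sketch of that lookup pattern, with hypothetical, simplified member names (the real `TensorStorage` lives in the project's model loader and carries more fields):

```cpp
// Sketch only: simplified stand-ins for the project's types.
#include <map>
#include <string>
#include "ggml.h"

struct TensorStorage {
    enum ggml_type type = GGML_TYPE_F32;  // type the tensor is stored as in the checkpoint
    // ... name, shape, file offset, etc. in the real struct
};

using String2TensorStorage = std::map<std::string, TensorStorage>;

// Role played by get_type() in the diffs below: resolve a weight's storage
// type by name, falling back to a default when the checkpoint lacks it.
static enum ggml_type get_type(const std::string& name,
                               const String2TensorStorage& tensor_storage_map,
                               enum ggml_type default_type) {
    auto it = tensor_storage_map.find(name);
    return it == tensor_storage_map.end() ? default_type : it->second.type;
}
```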

clip.hpp

Lines changed: 35 additions & 19 deletions
@@ -476,11 +476,12 @@ struct CLIPLayer : public GGMLBlock {
 public:
     CLIPLayer(int64_t d_model,
               int64_t n_head,
-              int64_t intermediate_size)
+              int64_t intermediate_size,
+              bool proj_in = false)
         : d_model(d_model),
           n_head(n_head),
           intermediate_size(intermediate_size) {
-        blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true));
+        blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true, proj_in));

         blocks["layer_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
         blocks["layer_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
@@ -509,11 +510,12 @@ struct CLIPEncoder : public GGMLBlock {
     CLIPEncoder(int64_t n_layer,
                 int64_t d_model,
                 int64_t n_head,
-                int64_t intermediate_size)
+                int64_t intermediate_size,
+                bool proj_in = false)
         : n_layer(n_layer) {
         for (int i = 0; i < n_layer; i++) {
             std::string name = "layers." + std::to_string(i);
-            blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size));
+            blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size, proj_in));
         }
     }

@@ -549,10 +551,10 @@ class CLIPEmbeddings : public GGMLBlock {
     int64_t num_positions;
     bool force_clip_f32;

-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         enum ggml_type token_wtype = GGML_TYPE_F32;
         if (!force_clip_f32) {
-            token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
+            token_wtype = get_type(prefix + "token_embedding.weight", tensor_storage_map, GGML_TYPE_F32);
             if (!support_get_rows(token_wtype)) {
                 token_wtype = GGML_TYPE_F32;
             }
@@ -605,7 +607,8 @@ class CLIPVisionEmbeddings : public GGMLBlock {
     int64_t image_size;
     int64_t num_patches;
     int64_t num_positions;
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         enum ggml_type patch_wtype = GGML_TYPE_F16;
         enum ggml_type class_wtype = GGML_TYPE_F32;
         enum ggml_type position_wtype = GGML_TYPE_F32;
@@ -668,7 +671,7 @@ enum CLIPVersion {

 class CLIPTextModel : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
             enum ggml_type wtype = GGML_TYPE_F32;
             params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
@@ -689,7 +692,8 @@ class CLIPTextModel : public GGMLBlock {

     CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                   bool with_final_ln = true,
-                  bool force_clip_f32 = false)
+                  bool force_clip_f32 = false,
+                  bool proj_in = false)
         : version(version), with_final_ln(with_final_ln) {
         if (version == OPEN_CLIP_VIT_H_14) {
             hidden_size = 1024;
@@ -704,7 +708,7 @@ class CLIPTextModel : public GGMLBlock {
         }

         blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token, force_clip_f32));
-        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
+        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
         blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
     }

@@ -758,7 +762,7 @@ class CLIPVisionModel : public GGMLBlock {
     int32_t n_layer = 24;

 public:
-    CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14) {
+    CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, bool proj_in = false) {
         if (version == OPEN_CLIP_VIT_H_14) {
             hidden_size = 1280;
             intermediate_size = 5120;
@@ -773,7 +777,7 @@ class CLIPVisionModel : public GGMLBlock {

         blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPVisionEmbeddings(hidden_size, num_channels, patch_size, image_size));
         blocks["pre_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
-        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
+        blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
         blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
     }

@@ -811,8 +815,8 @@ class CLIPProjection : public UnaryBlock {
     int64_t out_features;
     bool transpose_weight;

-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
-        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
         if (transpose_weight) {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
         } else {
@@ -845,15 +849,16 @@ class CLIPVisionModelProjection : public GGMLBlock {

 public:
     CLIPVisionModelProjection(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
-                              bool transpose_proj_w = false) {
+                              bool transpose_proj_w = false,
+                              bool proj_in = false) {
         if (version == OPEN_CLIP_VIT_H_14) {
             hidden_size = 1280;
             projection_dim = 1024;
         } else if (version == OPEN_CLIP_VIT_BIGG_14) {
             hidden_size = 1664;
         }

-        blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version));
+        blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version, proj_in));
         blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
     }

@@ -881,13 +886,24 @@ struct CLIPTextModelRunner : public GGMLRunner {

     CLIPTextModelRunner(ggml_backend_t backend,
                         bool offload_params_to_cpu,
-                        const String2GGMLType& tensor_types,
+                        const String2TensorStorage& tensor_storage_map,
                         const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                         bool with_final_ln = true,
                         bool force_clip_f32 = false)
-        : GGMLRunner(backend, offload_params_to_cpu), model(version, with_final_ln, force_clip_f32) {
-        model.init(params_ctx, tensor_types, prefix);
+        : GGMLRunner(backend, offload_params_to_cpu) {
+        bool proj_in = false;
+        for (const auto& [name, tensor_storage] : tensor_storage_map) {
+            if (!starts_with(name, prefix)) {
+                continue;
+            }
+            if (contains(name, "self_attn.in_proj")) {
+                proj_in = true;
+                break;
+            }
+        }
+        model = CLIPTextModel(version, with_final_ln, force_clip_f32, proj_in);
+        model.init(params_ctx, tensor_storage_map, prefix);
     }

     std::string get_desc() override {
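Note how the new `CLIPTextModelRunner` constructor doubles as variant detection: instead of a separate preprocessing pass over the weights, it scans the storage map for a `self_attn.in_proj` tensor under its prefix and only then builds the text model with fused QKV projections. The same scan, pulled out as a standalone sketch (`starts_with` and `contains` are the helpers used above; the wrapper function itself is illustrative, not part of the commit):

```cpp
// Illustrative helper: report whether a sub-model's checkpoint tensors
// include a fused attention input projection.
static bool has_fused_in_proj(const String2TensorStorage& tensor_storage_map,
                              const std::string& prefix) {
    for (const auto& [name, tensor_storage] : tensor_storage_map) {
        if (!starts_with(name, prefix)) {
            continue;  // tensor belongs to some other sub-model
        }
        if (contains(name, "self_attn.in_proj")) {
            return true;  // checkpoint stores fused QKV -> construct with proj_in = true
        }
    }
    return false;
}
```

A runner would then construct `CLIPTextModel(version, with_final_ln, force_clip_f32, has_fused_in_proj(tensor_storage_map, prefix))`, which is exactly the shape of the inlined loop above.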

common.hpp

Lines changed: 41 additions & 18 deletions
@@ -182,8 +182,8 @@ class GEGLU : public UnaryBlock {
     int64_t dim_in;
     int64_t dim_out;

-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
-        enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
+        enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
         enum ggml_type bias_wtype = GGML_TYPE_F32;
         params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
         params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
@@ -408,30 +408,40 @@ class SpatialTransformer : public GGMLBlock {
     int64_t d_head;
     int64_t depth = 1;          // 1
     int64_t context_dim = 768;  // hidden_size, 1024 for VERSION_SD2
+    bool use_linear = false;

 public:
     SpatialTransformer(int64_t in_channels,
                        int64_t n_head,
                        int64_t d_head,
                        int64_t depth,
-                       int64_t context_dim)
+                       int64_t context_dim,
+                       bool use_linear)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
           depth(depth),
-          context_dim(context_dim) {
-        // We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
+          context_dim(context_dim),
+          use_linear(use_linear) {
         // disable_self_attn is always False
         int64_t inner_dim = n_head * d_head;  // in_channels
         blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
-        blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
+        if (use_linear) {
+            blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, inner_dim));
+        } else {
+            blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
+        }

         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
             blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false));
         }

-        blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
+        if (use_linear) {
+            blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, in_channels));
+        } else {
+            blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
+        }
     }

     virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@@ -440,8 +450,8 @@ class SpatialTransformer : public GGMLBlock {
         // x: [N, in_channels, h, w]
         // context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
         auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
-        auto proj_in = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
-        auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
+        auto proj_in = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_in"]);
+        auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);

         auto x_in = x;
         int64_t n = x->ne[3];
@@ -450,10 +460,15 @@ class SpatialTransformer : public GGMLBlock {
         int64_t inner_dim = n_head * d_head;

         x = norm->forward(ctx, x);
-        x = proj_in->forward(ctx, x);  // [N, inner_dim, h, w]
-
-        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
-        x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                // [N, h * w, inner_dim]
+        if (use_linear) {
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
+            x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                // [N, h * w, inner_dim]
+            x = proj_in->forward(ctx, x);
+        } else {
+            x = proj_in->forward(ctx, x);  // [N, inner_dim, h, w]
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));  // [N, h, w, inner_dim]
+            x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n);                // [N, h * w, inner_dim]
+        }

         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
@@ -462,11 +477,19 @@ class SpatialTransformer : public GGMLBlock {
             x = transformer_block->forward(ctx, x, context);
         }

-        x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
-        x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n);                 // [N, inner_dim, h, w]
+        if (use_linear) {
+            // proj_out
+            x = proj_out->forward(ctx, x);

-        // proj_out
-        x = proj_out->forward(ctx, x);  // [N, in_channels, h, w]
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
+            x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n);                 // [N, inner_dim, h, w]
+        } else {
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
+            x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n);                 // [N, inner_dim, h, w]
+
+            // proj_out
+            x = proj_out->forward(ctx, x);  // [N, in_channels, h, w]
+        }

         x = ggml_add(ctx->ggml_ctx, x, x_in);
         return x;
@@ -475,7 +498,7 @@ class SpatialTransformer : public GGMLBlock {

 class AlphaBlender : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
         // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
         enum ggml_type wtype = GGML_TYPE_F32;
         params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
