@@ -62,6 +62,7 @@ def __init__(self, config: Dict[str, Any], num_categories_per_col: np.ndarray, n
6262 # or 0 for numerical data
6363 self .num_categories_per_col = num_categories_per_col
6464 self .embed_features = self .num_categories_per_col > 0
65+ self .num_features_excl_embed = num_features_excl_embed
6566
6667 self .num_embed_features = self .num_categories_per_col [self .embed_features ]
6768
@@ -84,8 +85,8 @@ def get_partial_models(self, subset_features: List[int]) -> "_LearnedEntityEmbed
8485 partial_model (_LearnedEntityEmbedding)
8586 a new partial model
8687 """
87- num_input_features = self .num_input_features [subset_features ]
88- num_numerical_features = sum ([sf < self .num_numerical for sf in subset_features ])
88+ num_input_features = self .num_categories_per_col [subset_features ]
89+ num_features_excl_embed = sum ([sf < self .num_features_excl_embed for sf in subset_features ])
8990
9091 num_output_dimensions = [self .num_output_dimensions [sf ] for sf in subset_features ]
9192 embed_features = [self .embed_features [sf ] for sf in subset_features ]
@@ -98,7 +99,7 @@ def get_partial_models(self, subset_features: List[int]) -> "_LearnedEntityEmbed
9899 ee_layer_tracker += 1
99100 ee_layers = nn .ModuleList (ee_layers )
100101
101- return PartialLearnedEntityEmbedding (num_input_features , num_numerical_features , embed_features ,
102+ return PartialLearnedEntityEmbedding (num_input_features , num_features_excl_embed , embed_features ,
102103 num_output_dimensions , ee_layers )
103104
104105 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -136,28 +137,27 @@ class PartialLearnedEntityEmbedding(_LearnedEntityEmbedding):
136137 of the input features. This is applied to forecasting tasks where not all the features might be known beforehand
137138 """
138139 def __init__ (self ,
139- num_input_features : np .ndarray ,
140- num_numerical_features : int ,
140+ num_categories_per_col : np .ndarray ,
141+ num_features_excl_embed : int ,
141142 embed_features : List [bool ],
142143 num_output_dimensions : List [int ],
143144 ee_layers : nn .Module
144145 ):
145146 super (_LearnedEntityEmbedding , self ).__init__ ()
146- self .num_numerical = num_numerical_features
147+ self .num_features_excl_embed = num_features_excl_embed
147148 # list of number of categories of categorical data
148149 # or 0 for numerical data
149- self .num_input_features = num_input_features
150- categorical_features : np .ndarray = self .num_input_features > 0
151-
152- self .num_categorical_features = self .num_input_features [categorical_features ]
150+ self .num_categories_per_col = num_categories_per_col
153151
154152 self .embed_features = embed_features
155153
156154 self .num_output_dimensions = num_output_dimensions
157- self .num_out_feats = self .num_numerical + sum (self .num_output_dimensions )
155+ self .num_out_feats = self .num_features_excl_embed + sum (self .num_output_dimensions )
158156
159157 self .ee_layers = ee_layers
160158
159+ self .num_embed_features = self .num_categories_per_col [self .embed_features ]
160+
161161
162162class LearnedEntityEmbedding (NetworkEmbeddingComponent ):
163163 """
0 commit comments