@@ -151,20 +151,16 @@ impl Infer {
151151 panic ! ( "unexpected enum variant" )
152152 } ;
153153
154- // Timings
155154 let total_time = start_time. elapsed ( ) ;
156155
157- // Metrics
158- let counter = metrics:: counter!( "te_embed_success" ) ;
159- counter. increment ( 1 ) ;
160- let histogram = metrics:: histogram!( "te_embed_duration" ) ;
161- histogram. record ( total_time. as_secs_f64 ( ) ) ;
162- let histogram = metrics:: histogram!( "te_embed_tokenization_duration" ) ;
163- histogram. record ( response. metadata . tokenization . as_secs_f64 ( ) ) ;
164- let histogram = metrics:: histogram!( "te_embed_queue_duration" ) ;
165- histogram. record ( response. metadata . queue . as_secs_f64 ( ) ) ;
166- let histogram = metrics:: histogram!( "te_embed_inference_duration" ) ;
167- histogram. record ( response. metadata . inference . as_secs_f64 ( ) ) ;
156+ metrics:: counter!( "te_embed_success" ) . increment ( 1 ) ;
157+ metrics:: histogram!( "te_embed_duration" ) . record ( total_time. as_secs_f64 ( ) ) ;
158+ metrics:: histogram!( "te_embed_tokenization_duration" )
159+ . record ( response. metadata . tokenization . as_secs_f64 ( ) ) ;
160+ metrics:: histogram!( "te_embed_queue_duration" )
161+ . record ( response. metadata . queue . as_secs_f64 ( ) ) ;
162+ metrics:: histogram!( "te_embed_inference_duration" )
163+ . record ( response. metadata . inference . as_secs_f64 ( ) ) ;
168164
169165 Ok ( response)
170166 }
@@ -224,6 +220,7 @@ impl Infer {
224220 Ok ( response)
225221 }
226222
223+ #[ allow( clippy:: too_many_arguments) ]
227224 #[ instrument( skip( self , inputs, permit) ) ]
228225 pub async fn embed_pooled < I : Into < EncodingInput > + std:: fmt:: Debug > (
229226 & self ,
@@ -232,20 +229,31 @@ impl Infer {
232229 truncation_direction : TruncationDirection ,
233230 prompt_name : Option < String > ,
234231 normalize : bool ,
232+ dimensions : Option < usize > ,
235233 permit : OwnedSemaphorePermit ,
236234 ) -> Result < PooledEmbeddingsInferResponse , TextEmbeddingsError > {
237235 let start_time = Instant :: now ( ) ;
238236
239237 if self . is_splade ( ) && normalize {
240238 let counter = metrics:: counter!( "te_request_failure" , "err" => "model_type" ) ;
241239 counter. increment ( 1 ) ;
240+
242241 let message = "`normalize` is not available for SPLADE models" . to_string ( ) ;
243242 tracing:: error!( "{message}" ) ;
244243 return Err ( TextEmbeddingsError :: Backend ( BackendError :: Inference (
245244 message,
246245 ) ) ) ;
247246 }
248247
248+ if let Some ( dimensions) = dimensions {
249+ if dimensions == 0 {
250+ metrics:: counter!( "te_request_failure" , "err" => "validation" ) . increment ( 1 ) ;
251+ let message = "`dimensions` should be positive" . to_string ( ) ;
252+ tracing:: error!( "{message}" ) ;
253+ return Err ( TextEmbeddingsError :: Validation ( message) ) ;
254+ }
255+ }
256+
249257 let results = self
250258 . embed (
251259 inputs,
@@ -262,6 +270,21 @@ impl Infer {
262270 panic ! ( "unexpected enum variant" )
263271 } ;
264272
273+ if let Some ( mrl_dimensions) = dimensions {
274+ if mrl_dimensions > response. results . len ( ) {
275+ metrics:: counter!( "te_request_failure" , "err" => "validation" ) . increment ( 1 ) ;
276+
277+ let message =
278+ "`dimensions` should be smaller than the maximum embedding dimension."
279+ . to_string ( ) ;
280+ tracing:: error!( "{message}" ) ;
281+
282+ return Err ( TextEmbeddingsError :: Validation ( message) ) ;
283+ }
284+
285+ response. results . truncate ( mrl_dimensions) ;
286+ }
287+
265288 if normalize {
266289 // Normalize embedding
267290 let scale = ( 1.0
0 commit comments