@@ -627,7 +627,7 @@ def filter(self, indices):
627627
628628class LogitBiasProcessor (LogitsProcessor ):
629629 """
630- `LogitsProcessor ` creates a bias tensor from a dictionary of token IDs and their
630+ `LogitBiasProcessor ` creates a bias tensor from a dictionary of token IDs and their
631631 corresponding bias values. Bias are applied to the logits during each forward pass.
632632
633633 Supports token IDs provided as strings (e.g., {"9707": -100}).
@@ -656,7 +656,7 @@ def __init__(
656656 def __call__ (self , input_ids : torch .Tensor , scores : torch .Tensor ) -> torch .Tensor :
657657 # Apply bias tensor as a broadcasted addition
658658 if self .bias_tensor .shape [0 ] != scores .shape [1 ]:
659- # Fix if the bias tensor is smaller than the scores
659+ # Pad the bias matrix to match the scores if it's smaller
660660 self .bias_tensor = torch .nn .functional .pad (
661661 self .bias_tensor , (0 , scores .shape [1 ] - self .bias_tensor .shape [0 ])
662662 )
@@ -699,7 +699,7 @@ def __init__(
699699 def __call__ (self , input_ids : torch .Tensor , scores : torch .Tensor ) -> torch .Tensor :
700700 # Apply bias matrix as a broadcasted addition
701701 if self .bias_matrix .shape [1 ] != scores .shape [1 ]:
702- # Fix if the bias matrix is smaller than the scores
702+ # Pad the bias matrix to match the scores if it's smaller
703703 self .bias_matrix = torch .nn .functional .pad (
704704 self .bias_matrix , (0 , scores .shape [1 ] - self .bias_matrix .shape [1 ])
705705 )
0 commit comments