@@ -92,6 +92,8 @@ class TransformersModelConfig(ModelConfig):
            Additional keyword arguments passed to `from_pretrained`. Defaults to empty dict.
        add_special_tokens (bool):
            Whether to add special tokens during tokenization. Defaults to True.
+        skip_special_tokens (bool):
+            Whether to skip (strip) special tokens when decoding generated tokens. Controls
+            whether special tokens appear in the generated text, which matters for reasoning
+            models. Defaults to True.
        model_parallel (bool | None):
            Whether to use model parallelism across multiple GPUs. If None, automatically
            determined based on available GPUs and model size.
@@ -139,6 +141,7 @@ class TransformersModelConfig(ModelConfig):
    max_length: PositiveInt | None = None
    model_loading_kwargs: dict = Field(default_factory=dict)
    add_special_tokens: bool = True
+    skip_special_tokens: bool = True
    model_parallel: bool | None = None
    dtype: str | None = None
    device: Union[int, str] = "cuda"
@@ -187,6 +190,7 @@ def __init__(
        self._device = self.accelerator.device
        self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
        self._add_special_tokens = config.add_special_tokens or False
+        self.skip_special_tokens = config.skip_special_tokens
        self.pairwise_tokenization = config.pairwise_tokenization
        self.batch_size = config.batch_size
        self.continuous_batching = config.continuous_batching
@@ -244,6 +248,7 @@ def from_model(
        tokenizer_name: str = None,  # custom tokenizer
        trust_remote_code: bool = False,
        add_special_tokens: bool = True,
+        skip_special_tokens: bool = True,
        pairwise_tokenization: bool = False,
        multichoice_continuations_start_space: bool = None,
    ):
@@ -280,6 +285,7 @@ def from_model(

        self.use_chat_template = uses_chat_template(self._tokenizer)
        self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
+        self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True
        self.pairwise_tokenization = pairwise_tokenization
        self.multichoice_continuations_start_space = multichoice_continuations_start_space

@@ -396,6 +402,7 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
            revision=revision,
            max_memory=max_memory,
            device_map=device_map,
+            # tp_plan="auto",
            torch_dtype=torch_dtype,
            trust_remote_code=self.config.trust_remote_code,
            **kwargs,
@@ -595,7 +602,9 @@ def _continuous_greedy_until(
            # for output in _output.outputs:
            output_token_ids.append(_output.generated_tokens)
            # logprobs_raw.append(output.logprobs)
-            result.append(self.tokenizer.decode(_output.generated_tokens))
+            result.append(
+                self.tokenizer.decode(_output.generated_tokens, skip_special_tokens=self.skip_special_tokens)
+            )

            if logprobs_raw and output_token_ids and False:
                logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
@@ -646,7 +655,9 @@ def _padded_greedy_until(
            tokenized_context = self.tokenizer(context)

            # Longest context in the current split is the first item (since we sort reversed)
-            longest_context_continuation_size_in_split = len(tokenized_context) + split[0].generation_size
+            longest_context_continuation_size_in_split = (
+                len(tokenized_context["input_ids"]) + split[0].generation_size
+            )
            max_context_continuation_size_allowed = min(
                longest_context_continuation_size_in_split, self.max_length
            )
@@ -669,12 +680,12 @@ def _padded_greedy_until(

            # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
            if self.use_chat_template:
-                stop_tokens = []
+                stop_tokens = [self.tokenizer.eos_token]
            else:
                # NOTE: we are assuming all items in a batch behave similarly (same
                # stop_tokens and max_tokens generated) which is not necessarily
                # the case! Because of that we only use batch size of 1
-                stop_tokens = batch[0].stop_sequences
+                stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences

            max_new_tokens = batch[0].generation_size
            num_samples = batch[0].num_samples
@@ -1189,6 +1200,9 @@ def pad_and_gather(
        output_tensor = self.accelerator.gather(output_tensor)
        return output_tensor, length_tensor

+    def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
+        return self.tokenizer.batch_decode(tokens, skip_special_tokens=self.skip_special_tokens)
+

class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""
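The decoding behaviour this diff wires through `skip_special_tokens` can be seen directly on any Hugging Face tokenizer. The sketch below is purely illustrative: only the `skip_special_tokens` argument itself comes from this change, while the `gpt2` checkpoint and the hand-built token list are placeholders standing in for real generation output.

# Illustrative sketch (not part of the diff): what skip_special_tokens changes at decode time.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # small public checkpoint, placeholder only

# Pretend these token ids came back from generate(): some text followed by the EOS marker.
token_ids = tokenizer("The answer is 42").input_ids + [tokenizer.eos_token_id]

# Default in this diff (skip_special_tokens=True): special tokens are stripped from the text.
print(tokenizer.decode(token_ids, skip_special_tokens=True))   # -> "The answer is 42"

# With skip_special_tokens=False the markers are kept, which is the behaviour a reasoning
# model's special/thinking tokens would rely on.
print(tokenizer.decode(token_ids, skip_special_tokens=False))  # -> "The answer is 42<|endoftext|>"

On the model side, the same toggle is set through the new config field (`skip_special_tokens=False` on `TransformersModelConfig`) and reaches decoding via `self.skip_special_tokens` in `tok_decode()` and the `_continuous_greedy_until` path shown above.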