| | from transformers import AutoTokenizer |
| |
|
| |
|
class BaseStreamer:
    """
    Common interface for the streamers accepted by `.generate()`. Concrete streamers inherit from this class and
    override both hooks below.
    """

    def put(self, value):
        """Hook invoked by `.generate()` to hand over newly produced tokens."""
        raise NotImplementedError()

    def end(self):
        """Hook invoked by `.generate()` to announce that generation has finished."""
        raise NotImplementedError()
| |
|
| |
|
class ByteStreamer(BaseStreamer):
    """
    Simple text streamer that prints the decoded text to stdout as soon as new characters are available, without
    waiting for entire words to be formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = ByteStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    # U+FFFD is what a decoder using the "replace" error handler emits for a byte sequence that is not (yet) valid.
    _REPLACEMENT_CHAR = "\ufffd"

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # State used while streaming:
        # token ids received since the last flush (the whole cache is re-decoded on every `put`)
        self.token_cache = []
        # number of characters of the decoded cache that have already been emitted
        self.print_len = 0
        # True until the first `put` of a generation; that first call is treated as the prompt
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints every newly decoded character to stdout.

        Args:
            value: Token ids exposing `.shape` and `.tolist()`. A leading batch dimension of size 1 is accepted
                and dropped.

        Raises:
            ValueError: If `value` has a batch dimension larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("ByteStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first call of a generation is treated as the prompt and dropped when `skip_prompt` is set.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new tokens to the cache and decode the entire cache, so characters spanning several tokens
        # are decoded together.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        if text.endswith("\n"):
            # After a newline the cache is flushed, which keeps the re-decoded cache short.
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        elif text.endswith(self._REPLACEMENT_CHAR):
            # The cache ends in what looks like an incomplete multi-byte character. Emitting it now would print
            # the placeholder and skip the real character once the remaining bytes arrive, so hold everything
            # back until a later `put` (or `end`, which flushes unconditionally).
            printable_text = ""
        else:
            # Emit everything decoded since the last emission. A call may contribute more than one character
            # (e.g. the whole prompt), so the slice must not be limited to a single character.
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        if self.token_cache:
            # Emit whatever is still pending, including text that `put` was holding back.
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # Re-arm prompt detection so the same streamer can be reused for another generation.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """
        Prints the new text to stdout. If the stream is ending, also prints a newline.

        Args:
            text: The text to print; may be an empty string.
            stream_end: When `True`, `print` uses its default line ending; otherwise no line ending is added.
        """
        print(text, flush=True, end="" if not stream_end else None)
| |
|