Add stream output support

#21
Files changed (1)
  1. modeling_minicpm.py +64 -8
modeling_minicpm.py CHANGED
@@ -22,12 +22,14 @@ import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict
 
+from threading import Thread
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from transformers import TextIteratorStreamer
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -1248,6 +1250,9 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
+        # List of terminator tokens used to indicate the end of a sequence or conversation.
+        self.terminators = ['</s>', '<|im_end|>']
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1426,11 +1431,52 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+    # Internal function to handle streaming of generated text using TextIteratorStreamer.
+    def _decode_stream(self, input_ids, tokenizer, **kwargs):
+        # Convert terminators to token IDs
+        terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+        # Initialize TextIteratorStreamer for handling streaming output
+        streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
+        # Set up generation parameters, including input IDs, eos token IDs, and streamer
+        generation_kwargs = {
+            'input_ids': input_ids,
+            'eos_token_id': terminators_ids,
+            'streamer': streamer
+        }
+        generation_kwargs.update(kwargs)
+        # Run the generation task in a separate thread to enable streaming output
+        thread = Thread(target=self.generate, kwargs=generation_kwargs)
+        thread.start()
+        # Return the streamer instance for later access to streamed text
+        return streamer
+
 
     @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-             max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
-             **kwargs):
+    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", max_length: int = 4096, num_beams=1,
+             do_sample=True, logits_processor=None, stream=False, top_p=0.8, temperature=0.3, **kwargs):
+        """
+        Main function for handling dialogue generation based on the input query and history.
+
+        Parameters:
+        - tokenizer: Tokenizer instance used for encoding and decoding.
+        - query: The user input query string.
+        - history: Dialogue history, a list of dictionaries where each dictionary contains role and content.
+        - role: The current role, default is "user".
+        - max_length: Maximum length of the generated text.
+        - num_beams: Number of beams for beam search.
+        - do_sample: Whether to use sampling for generation.
+        - logits_processor: Function for processing logits (if any).
+        - stream: Whether to use streaming output.
+        - top_p: Nucleus sampling parameter.
+        - temperature: Temperature parameter for generation.
+        - **kwargs: Additional arguments for generation.
+
+        Returns:
+        - If stream is True, returns a generator that yields the generated text incrementally.
+        - If stream is False, returns the complete generated response string.
+        """
+
         if history is None:
             history = []
         if logits_processor:
@@ -1443,12 +1489,22 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
         history.append({"role": role, "content": query})
         history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
         inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
-        outputs = self.generate(**inputs, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
-        response = tokenizer.decode(outputs)
-        history.append({"role": "assistant", "content": response})
-        return response, history
 
+        if stream:
+            res = self._decode_stream(inputs["input_ids"], tokenizer, **gen_kwargs)
+            def stream_gen():
+                for text in res:
+                    # Remove terminators from the text
+                    for term in self.terminators:
+                        text = text.replace(term, '')
+                    yield text
+            return stream_gen()
+
+        else:
+            outputs = self.generate(**inputs, **gen_kwargs)
+            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+            response = tokenizer.decode(outputs)
+            return response
 
 @add_start_docstrings(
     """