Text Generation
Transformers
Safetensors
PyTorch
nvidia
conversational
suhara ameyasunilm committed on
Commit
b90f131
·
verified ·
1 Parent(s): 579351d

Updating streaming tool-call parser to return ChoiceDeltaToolCall (#31)

Browse files

- Updating streaming tool-call parser to return ChoiceDeltaToolCall (ee35ca89f5df3773dc1e1e131000b8a4a138a3c9)


Co-authored-by: Ameya Sunil Mahabaleshwarkar <ameyasunilm@users.noreply.huggingface.co>

Files changed (1) hide show
  1. nemotron_toolcall_parser_streaming.py +436 -186
nemotron_toolcall_parser_streaming.py CHANGED
@@ -1,123 +1,166 @@
1
  import json
2
- import re
3
  from collections.abc import Sequence
4
- from typing import Union, Optional
 
 
5
 
6
  import partial_json_parser
 
 
 
7
 
8
- from vllm.entrypoints.openai.protocol import (
9
- ChatCompletionRequest,
10
- DeltaFunctionCall,
11
- DeltaMessage,
12
- DeltaToolCall,
13
- ExtractedToolCallInformation,
14
- FunctionCall,
15
- ToolCall,
16
- )
17
  from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18
- ToolParser,
19
- ToolParserManager,
20
- )
21
  from vllm.logger import init_logger
22
- from vllm.transformers_utils.tokenizer import AnyTokenizer
23
- from vllm.utils import random_uuid
24
 
25
  logger = init_logger(__name__)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  @ToolParserManager.register_module("nemotron_json")
29
- class NemotronJSONToolParser(ToolParser):
 
 
 
 
 
30
 
31
  def __init__(self, tokenizer: AnyTokenizer):
32
  super().__init__(tokenizer)
33
-
34
- # Streaming state tracking
35
- self.current_tool_name_sent: bool = False
36
  self.prev_tool_call_arr: list[dict] = []
37
  self.current_tool_id: int = -1
38
- self.streamed_args_for_tool: list[str] = []
39
- self.tool_call_ids: list[str] = [] # Track IDs for each tool call
40
-
41
- # Track what we've sent so far in streaming
42
- self.sent_tool_calls_count: int = 0
43
- self.sent_args_length: dict[int, int] = {} # tool_idx -> length of args sent
 
 
 
 
 
 
44
 
45
- self.tool_call_start_token: str = "<TOOLCALL>"
46
- self.tool_call_end_token: str = "</TOOLCALL>"
 
47
 
48
- self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def extract_tool_calls(
51
  self,
52
  model_output: str,
53
  request: ChatCompletionRequest,
54
  ) -> ExtractedToolCallInformation:
55
- """Extract tool calls from non-streaming (complete) output."""
56
-
57
- if self.tool_call_start_token not in model_output:
58
- return ExtractedToolCallInformation(
59
- tools_called=False,
60
- tool_calls=[],
61
- content=model_output,
62
- )
 
 
 
 
 
 
63
 
64
  try:
65
- # Try to extract complete <TOOLCALL>...</TOOLCALL> blocks
66
- tool_call_matches = self.tool_call_regex.findall(model_output)
67
-
68
- if tool_call_matches:
69
- # Complete tool call block found
70
- str_tool_calls = tool_call_matches[0].strip()
71
- else:
72
- # Incomplete - extract everything after <TOOLCALL>
73
- start_idx = model_output.find(self.tool_call_start_token) + len(self.tool_call_start_token)
74
- str_tool_calls = model_output[start_idx:].strip()
75
-
76
- # Ensure array brackets
77
- if not str_tool_calls.startswith("["):
78
- str_tool_calls = "[" + str_tool_calls
79
- if not str_tool_calls.endswith("]"):
80
- str_tool_calls = str_tool_calls + "]"
81
-
82
- # Use partial JSON parser for incomplete JSON
83
- json_tool_calls = partial_json_parser.loads(str_tool_calls)
84
-
85
- if not isinstance(json_tool_calls, list):
86
- raise ValueError("Tool calls must be a list")
87
-
88
- tool_calls = []
89
-
90
- for tool_call in json_tool_calls:
91
- if not isinstance(tool_call, dict):
92
- continue
93
- try:
94
- tool_calls.append(ToolCall(
95
- type="function",
96
- function=FunctionCall(
97
- name=tool_call.get("name", ""),
98
- arguments=json.dumps(tool_call.get("arguments", {}), ensure_ascii=False) \
99
- if isinstance(tool_call.get("arguments"), dict) else str(tool_call.get("arguments", "")),
100
- ),
101
- ))
102
- except Exception as e:
103
- logger.warning(f"Failed to parse tool call: {e}")
104
- continue
105
-
106
- content = model_output[:model_output.find(self.tool_call_start_token)].strip()
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return ExtractedToolCallInformation(
109
- tools_called=True if tool_calls else False,
110
  tool_calls=tool_calls,
111
- content=content if content else None,
112
- )
113
 
114
- except Exception as e:
115
- logger.exception(f"Error extracting tool calls. Response: {model_output}")
116
- return ExtractedToolCallInformation(
117
- tools_called=False,
118
- tool_calls=[],
119
- content=model_output,
120
- )
121
 
122
  def extract_tool_calls_streaming(
123
  self,
@@ -129,108 +172,315 @@ class NemotronJSONToolParser(ToolParser):
129
  delta_token_ids: Sequence[int],
130
  request: ChatCompletionRequest,
131
  ) -> Union[DeltaMessage, None]:
132
- """Extract tool calls from streaming output.
133
-
134
- This incrementally parses the <TOOLCALL> JSON as it streams in,
135
- sending delta updates for each tool call and its arguments.
136
- """
137
-
138
- # Check if we just started tool calling
139
- if self.tool_call_start_token in delta_text and self.tool_call_start_token not in previous_text:
140
- # First time seeing <TOOLCALL>, return content before it
141
- content_before = delta_text.split(self.tool_call_start_token)[0]
142
- if content_before:
143
- return DeltaMessage(content=content_before)
144
- # Start of tool call section - no delta yet
145
- return None
146
-
147
- # Check if we're not in tool call mode yet
148
- if self.tool_call_start_token not in current_text:
149
- # Regular content, no tool calls
150
- return DeltaMessage(content=delta_text) if delta_text else None
151
-
152
- # We're inside <TOOLCALL>...</TOOLCALL>
153
- # For Nemotron, the entire TOOLCALL block is generated at once
154
- # So we should only parse when we have the complete </TOOLCALL>
155
-
156
- # Check if we have the complete tool call block yet
157
- if self.tool_call_end_token not in current_text:
158
- # Incomplete tool call, don't send deltas yet
159
- return None
160
-
161
- # We have the complete tool call block, parse it
162
- start_idx = current_text.find(self.tool_call_start_token) + len(self.tool_call_start_token)
163
- end_idx = current_text.find(self.tool_call_end_token)
164
- json_str = current_text[start_idx:end_idx].strip()
165
-
166
- # Parse the complete JSON
167
  try:
168
- # Ensure we have array brackets
169
- if not json_str.startswith("["):
170
- json_str = "[" + json_str
171
- if not json_str.endswith("]"):
172
- json_str = json_str + "]"
173
-
174
- # Parse complete JSON
175
- tool_calls_arr = json.loads(json_str)
176
-
177
- if not isinstance(tool_calls_arr, list):
178
  return None
179
 
180
- # Generate delta updates for new/updated tool calls
181
- delta_tool_calls = []
182
-
183
- for idx, tool_call in enumerate(tool_calls_arr):
184
- if not isinstance(tool_call, dict):
185
- continue
186
 
187
- # Ensure we have a tool ID for this call
188
- while len(self.tool_call_ids) <= idx:
189
- self.tool_call_ids.append(random_uuid())
 
 
 
 
190
 
191
- tool_id = self.tool_call_ids[idx]
192
- tool_name = tool_call.get("name", "")
193
- tool_args = tool_call.get("arguments", {})
194
 
195
- # Convert arguments to JSON string
196
- if isinstance(tool_args, dict):
197
- args_str = json.dumps(tool_args, ensure_ascii=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  else:
199
- args_str = str(tool_args)
200
-
201
- # Check if this is a new tool call
202
- if idx >= self.sent_tool_calls_count:
203
- # New tool call - send ID, name, and complete arguments all at once
204
- # This matches how other models (Llama, etc.) send tool calls
205
- delta_tool_calls.append(DeltaToolCall(
206
- index=idx,
207
- id=tool_id,
208
- type="function",
209
- function=DeltaFunctionCall(
210
- name=tool_name,
211
- arguments=args_str # Send complete JSON string
212
- )
213
- ))
214
- self.sent_tool_calls_count = idx + 1
215
- self.sent_args_length[idx] = len(args_str)
216
-
217
- # NOTE: We don't send incremental updates for arguments
218
- # because Nemotron generates complete tool calls in one shot
219
- # Unlike thinking models that stream arguments token-by-token
220
-
221
- if delta_tool_calls:
222
- return DeltaMessage(tool_calls=delta_tool_calls)
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- except Exception as e:
225
- # JSON parsing failed (expected for incomplete JSON)
226
- logger.debug(f"Partial JSON parse failed (expected during streaming): {e}")
227
- pass
228
-
229
- # Check if we just completed the tool calls (end tag in this delta)
230
- if self.tool_call_end_token in delta_text and self.tool_call_end_token not in previous_text:
231
- # We just completed - reset state for next potential tool call
232
- self.sent_tool_calls_count = 0
233
- self.sent_args_length = {}
234
- self.tool_call_ids = []
235
-
236
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
 
2
  from collections.abc import Sequence
3
+ from random import choices
4
+ from string import ascii_letters, digits
5
+ from typing import Union
6
 
7
  import partial_json_parser
8
+ import regex as re
9
+ from partial_json_parser.core.options import Allow
10
+ from pydantic import Field
11
 
12
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
13
+ DeltaFunctionCall, DeltaMessage,
14
+ DeltaToolCall,
15
+ ExtractedToolCallInformation,
16
+ FunctionCall, ToolCall)
 
 
 
 
17
  from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18
+ ToolParser, ToolParserManager)
19
+ from vllm.entrypoints.openai.tool_parsers.utils import (
20
+ extract_intermediate_diff)
21
  from vllm.logger import init_logger
22
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 
23
 
24
  logger = init_logger(__name__)
25
 
26
+ ALPHANUMERIC = ascii_letters + digits
27
+
28
+
29
class NemotronToolCall(ToolCall):
    """A ``ToolCall`` whose ``id`` defaults to a freshly generated random id."""

    # The lambda postpones the lookup of ``NemotronToolCall`` until an
    # instance is actually created, i.e. after the class object exists.
    id: str = Field(
        default_factory=lambda: NemotronToolCall.generate_random_id())

    @staticmethod
    def generate_random_id():
        """Return 9 characters drawn with replacement from ``ALPHANUMERIC``."""
        picked = choices(ALPHANUMERIC, k=9)
        return "".join(picked)

    @staticmethod
    def is_valid_id(id: str) -> bool:
        """Tell whether ``id`` is 9 characters long and passes ``str.isalnum``."""
        if not id.isalnum():
            return False
        return len(id) == 9
40
+
41
+
42
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
    """Return True only for a ``MistralTokenizer`` whose version is >= 11."""
    # The version attribute is only read once the tokenizer type is confirmed.
    if not isinstance(model_tokenizer, MistralTokenizer):
        return False
    return model_tokenizer.version >= 11
45
+
46
 
47
  @ToolParserManager.register_module("nemotron_json")
48
+ class NemotronToolParser(ToolParser):
49
+ """
50
+ Tool call parser for Nemotron-Nano-V2
51
+
52
+ Used when --enable-auto-tool-choice --tool-call-parser nemotron_json are all set
53
+ """
54
 
55
  def __init__(self, tokenizer: AnyTokenizer):
56
  super().__init__(tokenizer)
57
+ # initialize properties used for state when parsing tool calls in
58
+ # streaming mode
 
59
  self.prev_tool_call_arr: list[dict] = []
60
  self.current_tool_id: int = -1
61
+ self.current_tool_name_sent: bool = False
62
+ self.streamed_args_for_tool: list[str] = [
63
+ ] # map what has been streamed for each tool so far to a list
64
+ self.bot_token = "<TOOLCALL>"
65
+ self.bot_token_id = self.vocab.get(self.bot_token)
66
+ logger.info(f"Nemotron Tool Parser: bot_token: {self.bot_token}, bot_token_id: {self.bot_token_id}")
67
+ self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
68
+ if _is_fn_name_regex_support(self.model_tokenizer):
69
+ self.fn_name_regex = re.compile(
70
+ r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
71
+ else:
72
+ self.fn_name_regex = None
73
 
74
+ # Buffer for partial tag sequences to disambiguate between normal content and
75
+ # a forthcoming <TOOLCALL> or </TOOLCALL> tag in streaming.
76
+ self._pending_tag_buffer: str = ""
77
 
78
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Disable special-token skipping on requests that may produce tool calls.

    The request is modified in place and returned. It is left untouched when
    the model tokenizer is a ``MistralTokenizer``, when the request carries no
    tools, or when ``tool_choice`` is ``'none'``.
    """
    # Per the original note, skip_special_tokens=False is incompatible with
    # MistralTokenizer, so such requests are passed through unchanged.
    uses_mistral_tokenizer = isinstance(self.model_tokenizer,
                                        MistralTokenizer)
    if uses_mistral_tokenizer:
        return request

    # Per the original note, the tool-call token is needed for tool
    # detection, so special tokens must not be skipped in the output.
    if request.tools and request.tool_choice != 'none':
        request.skip_special_tokens = False
    return request
90
 
91
  def extract_tool_calls(
92
  self,
93
  model_output: str,
94
  request: ChatCompletionRequest,
95
  ) -> ExtractedToolCallInformation:
96
+ """
97
+ Extract the tool calls from a complete model response. Requires
98
+ find-and-replacing single quotes with double quotes for JSON parsing,
99
+ make sure your tool call arguments don't ever include quotes!
100
+ """
101
+
102
+ # case -- if a tool call token is not present, return a text response
103
+ if self.bot_token not in model_output:
104
+ return ExtractedToolCallInformation(tools_called=False,
105
+ tool_calls=[],
106
+ content=model_output)
107
+
108
+ # first remove the BOT token
109
+ tool_content = model_output.replace(self.bot_token, "").strip()
110
 
111
  try:
112
+ # we first try to directly load the json as parsing very nested
113
+ # jsons is difficult
114
+ try:
115
+ if self.fn_name_regex:
116
+ matches = self.fn_name_regex.findall(tool_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ function_call_arr = []
119
+ for match in matches:
120
+ fn_name = match[0]
121
+ args = match[1]
122
+
123
+ # fn_name is encoded outside serialized json dump
124
+ # only arguments are serialized
125
+ function_call_arr.append({
126
+ "name": fn_name,
127
+ "arguments": json.loads(args)
128
+ })
129
+ else:
130
+ function_call_arr = json.loads(tool_content)
131
+ except json.JSONDecodeError:
132
+ # use a regex to find the part corresponding to the tool call.
133
+ # NOTE: This use case should not happen if the model is trained
134
+ # correctly. It's a easy possible fix so it's included, but
135
+ # can be brittle for very complex / highly nested tool calls
136
+ raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
137
+ function_call_arr = json.loads(raw_tool_call)
138
+
139
+ # Tool Call
140
+ tool_calls: list[NemotronToolCall] = [
141
+ NemotronToolCall(
142
+ type="function",
143
+ function=FunctionCall(
144
+ name=raw_function_call["name"],
145
+ # function call args are JSON but as a string
146
+ arguments=json.dumps(raw_function_call["arguments"],
147
+ ensure_ascii=False)))
148
+ for raw_function_call in function_call_arr
149
+ ]
150
+
151
+ # get any content before the tool call
152
+ content = model_output.split(self.bot_token)[0]
153
  return ExtractedToolCallInformation(
154
+ tools_called=True,
155
  tool_calls=tool_calls,
156
+ content=content if len(content) > 0 else None)
 
157
 
158
+ except Exception:
159
+ logger.exception("Error in extracting tool call from response.")
160
+ # return information to just treat the tool call as regular JSON
161
+ return ExtractedToolCallInformation(tools_called=False,
162
+ tool_calls=[],
163
+ content=tool_content)
 
164
 
165
  def extract_tool_calls_streaming(
166
  self,
 
172
  delta_token_ids: Sequence[int],
173
  request: ChatCompletionRequest,
174
  ) -> Union[DeltaMessage, None]:
175
+ # if candidates tool call tokens are in the tokens generated so far, that
176
+ # means we're parsing as tool calls now. Suppress streaming if we are
177
+ # currently generating any prefix of the start or end tag.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  try:
179
+ start_token = self.bot_token
180
+ end_token = f"</{self.bot_token[1:]}" if self.bot_token.startswith('<') else None
181
+
182
+ # Handle potential start of tool call tags by buffering partial sequences
183
+ if delta_text == '<' and not self._pending_tag_buffer:
184
+ # Start buffering a potential tag
185
+ self._pending_tag_buffer = '<'
 
 
 
186
  return None
187
 
188
+ # If we have a pending tag buffer, accumulate and decide
189
+ if self._pending_tag_buffer:
190
+ # Accumulate the current token into the buffer
191
+ self._pending_tag_buffer += delta_text
 
 
192
 
193
+ # Extract just the alphabetic part after '<'
194
+ alpha_part = ""
195
+ for i in range(1, len(self._pending_tag_buffer)):
196
+ if self._pending_tag_buffer[i].isalpha():
197
+ alpha_part += self._pending_tag_buffer[i].upper()
198
+ else:
199
+ break
200
 
 
 
 
201
 
202
+ # Check if we have a complete opening tag '<TOOLCALL>'
203
+ if '<TOOLCALL>' in self._pending_tag_buffer:
204
+ # We have the complete opening tag - stop buffering and let normal processing take over
205
+ buffered_content = self._pending_tag_buffer
206
+ self._pending_tag_buffer = ""
207
+
208
+ # Update the text variables to include the buffered content
209
+ updated_current_text = previous_text + buffered_content
210
+ updated_delta_text = buffered_content # The entire buffered content is the delta
211
+
212
+ # Continue processing with the complete tool call content
213
+ current_text = updated_current_text
214
+ delta_text = updated_delta_text
215
+ # Fall through to normal processing
216
+ elif self._pending_tag_buffer.startswith('</'):
217
+ # End tag pattern - keep buffering until we see if it's a valid end tag
218
+ return None
219
+ elif alpha_part and "TOOLCALL".startswith(alpha_part) and len(alpha_part) < 8:
220
+ # Could be building to TOOLCALL and haven't completed it yet - keep buffering
221
+ return None
222
+ elif len(alpha_part) > 0 and not "TOOLCALL".startswith(alpha_part):
223
+ # Alphabetic content that definitely won't become TOOLCALL - flush as content
224
+ content_to_flush = self._pending_tag_buffer
225
+ self._pending_tag_buffer = ""
226
+ return DeltaMessage(content=content_to_flush)
227
  else:
228
+ # Keep buffering - not enough info yet
229
+ return None
230
+
231
+ # Suppress ANY partial prefix of the start/end tag to avoid leaking tag characters.
232
+ if any(current_text.endswith(start_token[:k]) for k in range(1, len(start_token))):
233
+ return None
234
+ if end_token and any(current_text.endswith(end_token[:k]) for k in range(1, len(end_token))):
235
+ return None
236
+ except Exception:
237
+ # Fallback to conservative checks in case of any issues
238
+ if current_text.endswith('<') or current_text.endswith('<T') or current_text.endswith('<TO') or current_text.endswith('<TOOL') or current_text.endswith('<TOOLCALL'):
239
+ return None
240
+
241
+ # if the tool call token is not in the tokens generated so far, append
242
+ # output to contents since it's not a tool
243
+ if self.bot_token not in current_text:
244
+ # If we were buffering a partial tag and reached here, flush it first.
245
+ if self._pending_tag_buffer:
246
+ content_to_flush = self._pending_tag_buffer + delta_text
247
+ self._pending_tag_buffer = ""
248
+ return DeltaMessage(content=content_to_flush)
249
+ return DeltaMessage(content=delta_text)
250
+
251
+ # bit mask flags for partial JSON parsing. If the name hasn't been
252
+ # sent yet, don't allow sending
253
+ # an incomplete string since OpenAI only ever (as far as I have
254
+ # seen) allows sending the entire tool/ function name at once.
255
+ flags = Allow.ALL if self.current_tool_name_sent \
256
+ else Allow.ALL & ~Allow.STR
257
+ end_of_call: bool = False
258
+ try:
259
+
260
+ # replace BOT token with empty string, and convert single quotes
261
+ # to double to allow parsing as JSON since mistral uses single
262
+ # quotes instead of double for tool calls
263
+ parsable_arr = current_text.split(self.bot_token)[-1]
264
 
265
+ # Check if we're at the end of the tool call
266
+ if '</TOOLCALL>' in parsable_arr:
267
+ end_of_call = True
268
+ parsable_arr = parsable_arr.split('</TOOLCALL>')[0]
269
+
270
+ # tool calls are generated in an array, so do partial JSON
271
+ # parsing on the entire array
272
+ try:
273
+ tool_call_arr: list[dict] = partial_json_parser.loads(
274
+ parsable_arr, flags)
275
+ except partial_json_parser.core.exceptions.MalformedJSON:
276
+ return None
277
+
278
+ current_tool_call: dict = tool_call_arr[self.current_tool_id] \
279
+ if len(tool_call_arr) > 0 else {}
280
+
281
+ # case -- if no tokens have been streamed for the tool, e.g.
282
+ # only the array brackets, stream nothing
283
+ if len(tool_call_arr) == 0:
284
+ return None
285
+
286
+ # case: we are starting a new tool in the array
287
+ # -> array has > 0 length AND length has moved past cursor
288
+ elif (len(tool_call_arr) > 0
289
+ and len(tool_call_arr) > self.current_tool_id + 1):
290
+
291
+ # if we're moving on to a new call, first make sure we
292
+ # haven't missed anything in the previous one that was
293
+ # auto-generated due to JSON completions, but wasn't
294
+ # streamed to the client yet.
295
+ if self.current_tool_id >= 0:
296
+ diff: Union[str, None] = current_tool_call.get("arguments")
297
+
298
+ if diff:
299
+ diff = json.dumps(diff, ensure_ascii=False).replace(
300
+ self.streamed_args_for_tool[self.current_tool_id],
301
+ "")
302
+ delta = DeltaMessage(tool_calls=[
303
+ DeltaToolCall(index=self.current_tool_id,
304
+ function=DeltaFunctionCall(
305
+ arguments=diff).model_dump(
306
+ exclude_none=True))
307
+ ])
308
+ self.streamed_args_for_tool[
309
+ self.current_tool_id] += diff
310
+ else:
311
+ delta = None
312
+ else:
313
+ delta = None
314
+ # re-set stuff pertaining to progress in the current tool
315
+ self.current_tool_id = len(tool_call_arr) - 1
316
+ self.current_tool_name_sent = False
317
+ self.streamed_args_for_tool.append("")
318
+ return delta
319
+
320
+ # case: update an existing tool - this is handled below
321
+
322
+ # if the current tool name hasn't been sent, send if available
323
+ # - otherwise send nothing
324
+ if not self.current_tool_name_sent:
325
+ function_name = current_tool_call.get("name")
326
+ if function_name:
327
+
328
+ delta = DeltaMessage(tool_calls=[
329
+ DeltaToolCall(index=self.current_tool_id,
330
+ type="function",
331
+ id=NemotronToolCall.generate_random_id(),
332
+ function=DeltaFunctionCall(
333
+ name=function_name).model_dump(
334
+ exclude_none=True))
335
+ ])
336
+ self.current_tool_name_sent = True
337
+ else:
338
+ delta = None
339
+
340
+ # now we know we're on the same tool call and we're streaming
341
+ # arguments
342
+ else:
343
+
344
+ prev_arguments = self.prev_tool_call_arr[
345
+ self.current_tool_id].get("arguments")
346
+ cur_arguments = current_tool_call.get("arguments")
347
+
348
+ new_text = delta_text.replace("\'", "\"")
349
+ if ('"}' in new_text):
350
+ new_text = new_text[:new_text.rindex('"}')]
351
+
352
+ if not cur_arguments and not prev_arguments:
353
+
354
+ delta = None
355
+ elif not cur_arguments and prev_arguments:
356
+ logger.error(
357
+ "INVARIANT - impossible to have arguments reset "
358
+ "mid-arguments")
359
+ delta = None
360
+ elif cur_arguments and not prev_arguments:
361
+ cur_arguments_json = json.dumps(cur_arguments,
362
+ ensure_ascii=False)
363
+ streamed_prefix = self.streamed_args_for_tool[
364
+ self.current_tool_id]
365
+
366
+ # The issue: partial JSON parser auto-completes incomplete strings
367
+ # e.g., {"location": " becomes {"location": ""} in parsed result
368
+ # We need to handle this by detecting when the parsed result has auto-completed empty strings
369
+
370
+ # Check if this looks like an auto-completed partial string
371
+ if (cur_arguments_json.endswith('": ""}') and
372
+ not streamed_prefix and
373
+ '": ""' in cur_arguments_json):
374
+ # This is likely auto-completed - remove the auto-completed empty string
375
+ # e.g., {"location": ""} -> {"location": "
376
+ closing_pos = cur_arguments_json.rfind('": ""}')
377
+ if closing_pos != -1:
378
+ arguments_delta = cur_arguments_json[:closing_pos + 4] # Keep up to ": "
379
+ else:
380
+ arguments_delta = cur_arguments_json
381
+ else:
382
+ # Normal case - use diff calculation
383
+ if cur_arguments_json.startswith(streamed_prefix):
384
+ arguments_delta = cur_arguments_json[len(streamed_prefix):]
385
+ else:
386
+ # Fallback: compute diff when prefix does not match.
387
+ arguments_delta = extract_intermediate_diff(
388
+ cur_arguments_json, streamed_prefix)
389
+
390
+ # Do not include a trailing '}' in the very first
391
+ # arguments chunk; defer it to the end-of-call flush to
392
+ # avoid prematurely closing the JSON object.
393
+ if (not self.streamed_args_for_tool[self.current_tool_id]
394
+ and not end_of_call and arguments_delta
395
+ and arguments_delta.endswith('}')):
396
+ arguments_delta = arguments_delta[:-1]
397
+ # if there is an auto-completed closing quote '"' before the }, strip it too
398
+ # e.g., {"color_hex": "#"} -> {"color_hex": "#"} -> {"color_hex": "#"}
399
+ if arguments_delta.endswith('"'):
400
+ arguments_delta = arguments_delta[:-1]
401
+ if arguments_delta:
402
+ delta = DeltaMessage(tool_calls=[
403
+ DeltaToolCall(index=self.current_tool_id,
404
+ function=DeltaFunctionCall(
405
+ arguments=arguments_delta).
406
+ model_dump(exclude_none=True))
407
+ ])
408
+ self.streamed_args_for_tool[
409
+ self.current_tool_id] += arguments_delta
410
+ else:
411
+ delta = None
412
+
413
+ elif cur_arguments and prev_arguments:
414
+ cur_args_json = json.dumps(cur_arguments,
415
+ ensure_ascii=False)
416
+ prev_args_json = json.dumps(prev_arguments,
417
+ ensure_ascii=False)
418
+
419
+ argument_diff = extract_intermediate_diff(
420
+ cur_args_json, prev_args_json)
421
+ if argument_diff:
422
+ delta = DeltaMessage(tool_calls=[
423
+ DeltaToolCall(index=self.current_tool_id,
424
+ function=DeltaFunctionCall(
425
+ arguments=argument_diff).model_dump(
426
+ exclude_none=True))
427
+ ])
428
+ self.streamed_args_for_tool[
429
+ self.current_tool_id] += argument_diff
430
+ else:
431
+ # Do not flush final JSON here; let the serving layer
432
+ # compute a minimal remaining suffix on finish.
433
+ delta = None
434
+ else:
435
+ # End-of-call or equal state; do not force a final flush here.
436
+ delta = None
437
+
438
+ # check to see if the name is defined and has been sent. if so,
439
+ # stream the name - otherwise keep waiting
440
+ # finish by setting old and returning None as base case
441
+ self.prev_tool_call_arr = tool_call_arr
442
+ # If we've reached the end of a tool call, flush any remaining
443
+ # suffix (including a final '}') that hasn't been streamed yet.
444
+ if end_of_call and self.current_tool_id >= 0:
445
+ try:
446
+ cur_arguments = current_tool_call.get("arguments")
447
+ if cur_arguments is not None:
448
+ cur_args_json = json.dumps(cur_arguments,
449
+ ensure_ascii=False)
450
+ streamed_prefix = self.streamed_args_for_tool[
451
+ self.current_tool_id]
452
+
453
+ if cur_args_json.startswith(streamed_prefix):
454
+ remaining_suffix = cur_args_json[len(
455
+ streamed_prefix):]
456
+ else:
457
+ remaining_suffix = extract_intermediate_diff(
458
+ cur_args_json, streamed_prefix)
459
+
460
+ # Only send remaining suffix if it's non-empty and contains meaningful content
461
+ # (not just whitespace or single characters like closing braces)
462
+ if remaining_suffix and remaining_suffix.strip() and len(remaining_suffix.strip()) > 0:
463
+ extra = DeltaToolCall(
464
+ index=self.current_tool_id,
465
+ function=DeltaFunctionCall(
466
+ arguments=remaining_suffix).model_dump(
467
+ exclude_none=True))
468
+ if delta is None:
469
+ delta = DeltaMessage(tool_calls=[extra])
470
+ else:
471
+ if getattr(delta, "tool_calls", None):
472
+ delta.tool_calls.append(extra)
473
+ else:
474
+ delta.tool_calls = [extra]
475
+ self.streamed_args_for_tool[
476
+ self.current_tool_id] += remaining_suffix
477
+ else:
478
+ pass
479
+ except Exception:
480
+ pass
481
+
482
+ return delta
483
+
484
+ except Exception:
485
+ logger.exception("Error trying to handle streaming tool call.")
486
+ return None