Germano Cavalcante commited on
Commit
25dbca2
1 Parent(s): 1b8973e

Fix tool_calls calling wiki_search without groups

Browse files
routers/embedding/embeddings_issues.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7c3c012a8f86440dacedd6f1e4e9ea9f41f096031c0ac1ed5cdf64a9a8d46e42
3
- size 723452942
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c790700eeb373e6f80e8a52e3ac3eafd7591f2782a1fcb589a197769e25dfc7
3
+ size 723465376
routers/tool_calls.py CHANGED
@@ -5,10 +5,10 @@ from typing import List, Dict
5
  from pydantic import BaseModel
6
 
7
  try:
8
- from .tool_gpu_checker import gpu_checker_get_message
9
- from .tool_bpy_doc import bpy_doc_get_documentation
10
- from .tool_find_related import find_related
11
- from .tool_wiki_search import wiki_search
12
  except:
13
  from routers.tool_gpu_checker import gpu_checker_get_message
14
  from routers.tool_bpy_doc import bpy_doc_get_documentation
@@ -46,7 +46,7 @@ def process_tool_call(tool_call: ToolCallInput) -> Dict:
46
  output["output"] = find_related(
47
  function_args["repo"], function_args["number"])
48
  elif function_name == "wiki_search":
49
- output["output"] = wiki_search(function_args["query"])
50
  except json.JSONDecodeError as e:
51
  error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {tool_call.function.arguments}"
52
  output["output"] = error_message
@@ -95,6 +95,14 @@ if __name__ == "__main__":
95
  "name": "find_related",
96
  "arguments": "{\"repo\":\"blender\",\"number\":111434}"
97
  }
 
 
 
 
 
 
 
 
98
  }
99
  ]
100
 
 
5
  from pydantic import BaseModel
6
 
7
  try:
8
+ from tool_gpu_checker import gpu_checker_get_message
9
+ from tool_bpy_doc import bpy_doc_get_documentation
10
+ from tool_find_related import find_related
11
+ from tool_wiki_search import wiki_search
12
  except:
13
  from routers.tool_gpu_checker import gpu_checker_get_message
14
  from routers.tool_bpy_doc import bpy_doc_get_documentation
 
46
  output["output"] = find_related(
47
  function_args["repo"], function_args["number"])
48
  elif function_name == "wiki_search":
49
+ output["output"] = wiki_search(**function_args)
50
  except json.JSONDecodeError as e:
51
  error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {tool_call.function.arguments}"
52
  output["output"] = error_message
 
95
  "name": "find_related",
96
  "arguments": "{\"repo\":\"blender\",\"number\":111434}"
97
  }
98
+ },
99
+ {
100
+ "id": "call_abc101112",
101
+ "type": "function",
102
+ "function": {
103
+ "name": "wiki_search",
104
+ "arguments": "{\"query\":\"Set Snap Base\",\"groups\":[\"manual\"]}"
105
+ }
106
  }
107
  ]
108
 
routers/tool_wiki_search.py CHANGED
@@ -6,7 +6,7 @@ import pickle
6
  import re
7
  import torch
8
  from enum import Enum
9
- from fastapi import APIRouter, Query
10
  from fastapi.responses import PlainTextResponse
11
  from heapq import nlargest
12
  from sentence_transformers import util
@@ -177,6 +177,9 @@ def wiki_search(
177
  query: str = "",
178
  groups: Set[Group] = Query(default={Group.dev_docs, Group.manual})
179
  ) -> str:
 
 
 
180
  texts = G_data._sort_similarity(query, groups)
181
  result: str = ''
182
  for text in texts:
@@ -187,5 +190,5 @@ def wiki_search(
187
  if __name__ == '__main__':
188
  tests = ["Set Snap Base", "Building the Manual",
189
  "Bisect Object", "Who are the Triagers", "4.3 Release Notes Motion Paths"]
190
- result = wiki_search(tests[0], {Group.dev_docs, Group.manual})
191
  print(result)
 
6
  import re
7
  import torch
8
  from enum import Enum
9
+ from fastapi import APIRouter, Query, params
10
  from fastapi.responses import PlainTextResponse
11
  from heapq import nlargest
12
  from sentence_transformers import util
 
177
  query: str = "",
178
  groups: Set[Group] = Query(default={Group.dev_docs, Group.manual})
179
  ) -> str:
180
+ if isinstance(groups, params.Query):
181
+ groups = groups.default
182
+
183
  texts = G_data._sort_similarity(query, groups)
184
  result: str = ''
185
  for text in texts:
 
190
  if __name__ == '__main__':
191
  tests = ["Set Snap Base", "Building the Manual",
192
  "Bisect Object", "Who are the Triagers", "4.3 Release Notes Motion Paths"]
193
+ result = wiki_search(tests[0])
194
  print(result)