bkb2135 committed on
Commit c22f824 · 1 Parent(s): c9cda2b

Clean up repo more

.gitignore CHANGED
@@ -159,4 +159,6 @@ cython_debug/
 testing/
 core
 app.config.js
-.vscode
+.vscode
+wandb/
+_saved_runs.csv
common/__init__.py DELETED
File without changes
common/middlewares.py DELETED
@@ -1,42 +0,0 @@
-import os
-import json
-import bittensor as bt
-from aiohttp.web import Request, Response, middleware
-
-EXPECTED_ACCESS_KEY = os.environ.get("EXPECTED_ACCESS_KEY")
-
-
-@middleware
-async def api_key_middleware(request: Request, handler):
-    if request.path.startswith("/docs") or request.path.startswith("/static/swagger"):
-        # Skip checks when accessing OpenAPI documentation.
-        return await handler(request)
-
-    # Logging the request
-    bt.logging.info(f"Handling {request.method} request to {request.path}")
-
-    # Check access key
-    access_key = request.headers.get("api_key")
-    if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
-        bt.logging.error(f"Invalid access key: {access_key}")
-        return Response(status=401, reason="Invalid access key")
-
-    # Continue to the next handler if the API key is valid
-    return await handler(request)
-
-
-@middleware
-async def json_parsing_middleware(request: Request, handler):
-    if request.path.startswith("/docs") or request.path.startswith("/static/swagger"):
-        # Skip checks when accessing OpenAPI documentation.
-        return await handler(request)
-
-    try:
-        # Parsing JSON data from the request
-        request["data"] = await request.json()
-    except json.JSONDecodeError as e:
-        bt.logging.error(f"Invalid JSON data: {str(e)}")
-        return Response(status=400, text="Invalid JSON")
-
-    # Continue to the next handler if JSON is successfully parsed
-    return await handler(request)
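For context, a minimal sketch of how the two deleted middlewares would typically have been wired into an aiohttp application. Everything except the middleware functions themselves (the echo handler, make_app, the route and port) is an illustrative assumption, not code from this repo; the import refers to the module as it existed before this commit.

# Hypothetical wiring of the middlewares above -- illustrative only.
from aiohttp import web
from common.middlewares import api_key_middleware, json_parsing_middleware

async def echo(request: web.Request) -> web.Response:
    # json_parsing_middleware stores the parsed JSON body under request["data"].
    return web.json_response(request["data"])

def make_app() -> web.Application:
    # Middlewares run in list order: the API-key check first, then JSON parsing.
    app = web.Application(middlewares=[api_key_middleware, json_parsing_middleware])
    app.router.add_post("/echo", echo)
    return app

if __name__ == "__main__":
    web.run_app(make_app(), port=8080)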
common/schemas.py DELETED
@@ -1,29 +0,0 @@
-from marshmallow import Schema, fields
-
-
-class QueryChatSchema(Schema):
-    k = fields.Int(description="The number of miners from which to request responses.")
-    exclude = fields.List(fields.Str(), description="A list of roles or agents to exclude from querying.")
-    roles = fields.List(fields.Str(), required=True, description="The roles of the agents to query.")
-    messages = fields.List(fields.Str(), required=True, description="The messages to be sent to the network.")
-    timeout = fields.Int(description="The time in seconds to wait for a response.")
-    prefer = fields.Str(description="The preferred response format, can be either 'longest' or 'shortest'.")
-    sampling_mode = fields.Str(
-        description="The mode of sampling to use, defaults to 'random'. Can be either 'random' or 'top_incentive'.")
-
-
-class StreamChunkSchema(Schema):
-    delta = fields.Str(required=True, description="The new chunk of response received.")
-    finish_reason = fields.Str(description="The reason for the response completion, if applicable.")
-    accumulated_chunks = fields.List(fields.Str(), description="All accumulated chunks of responses.")
-    accumulated_chunks_timings = fields.List(fields.Float(), description="Timing for each chunk received.")
-    timestamp = fields.Str(required=True, description="The timestamp at which the chunk was processed.")
-    sequence_number = fields.Int(required=True, description="A sequential identifier for the response part.")
-    selected_uid = fields.Int(required=True, description="The identifier for the selected response source.")
-
-
-class StreamErrorSchema(Schema):
-    error = fields.Str(required=True, description="Description of the error occurred.")
-    timestamp = fields.Str(required=True, description="The timestamp of the error.")
-    sequence_number = fields.Int(required=True, description="A sequential identifier for the error.")
-    finish_reason = fields.Str(default="error", description="Indicates an error completion.")
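As a reference for the deleted schemas, a small sketch of validating a query payload with marshmallow. The payload values are invented and the import refers to the module as it existed before this commit.

# Hypothetical validation of a chat query payload -- illustrative only.
from marshmallow import ValidationError
from common.schemas import QueryChatSchema

payload = {
    "roles": ["user"],
    "messages": ["Summarize the latest block header."],
    "k": 3,
    "timeout": 10,
}

try:
    # load() validates the payload and returns the deserialized dict;
    # omitting a required field (roles, messages) raises ValidationError.
    data = QueryChatSchema().load(payload)
    print(data)
except ValidationError as err:
    # err.messages maps each offending field to a list of error strings.
    print(err.messages)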
common/utils.py DELETED
@@ -1,159 +0,0 @@
-import re
-import asyncio
-import bittensor as bt
-from aiohttp import web
-from collections import Counter
-from prompting.rewards import DateRewardModel, FloatDiffModel
-from validators.streamer import AsyncResponseDataStreamer
-
-UNSUCCESSFUL_RESPONSE_PATTERNS = [
-    "I'm sorry",
-    "unable to",
-    "I cannot",
-    "I can't",
-    "I am unable",
-    "I am sorry",
-    "I can not",
-    "don't know",
-    "not sure",
-    "don't understand",
-    "not capable",
-]
-
-reward_models = {
-    "date_qa": DateRewardModel(),
-    "math": FloatDiffModel(),
-}
-
-
-def completion_is_valid(completion: str):
-    """
-    Get the completion statuses from the completions.
-    """
-    if not completion.strip():
-        return False
-
-    patt = re.compile(
-        r"\b(?:" + "|".join(UNSUCCESSFUL_RESPONSE_PATTERNS) + r")\b", re.IGNORECASE
-    )
-    if not len(re.findall(r"\w+", completion)) or patt.search(completion):
-        return False
-    return True
-
-
-def ensemble_result(completions: list, task_name: str, prefer: str = "longest"):
-    """
-    Ensemble completions from multiple models.
-    # TODO: Measure agreement
-    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
-    # TODO: Reward pipeline
-    """
-    if not completions:
-        return None
-
-    answer = None
-    if task_name in ("qa", "summarization"):
-        # No special handling for QA or summarization
-        supporting_completions = completions
-
-    elif task_name == "date_qa":
-        # filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
-        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
-        bt.logging.info(f"Unprocessed dates: {dates}")
-        valid_date_indices = [i for i, d in enumerate(dates) if d]
-        valid_completions = [completions[i] for i in valid_date_indices]
-        valid_dates = [dates[i] for i in valid_date_indices]
-        dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
-        if not dates:
-            return None
-
-        counter = Counter(dates)
-        most_common, count = counter.most_common()[0]
-        answer = most_common
-        if count == 1:
-            supporting_completions = valid_completions
-        else:
-            supporting_completions = [
-                c for i, c in enumerate(valid_completions) if dates[i] == most_common
-            ]
-
-    elif task_name == "math":
-        # filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
-        # TODO: use the median instead of the most common value
-        vals = list(map(reward_models[task_name].extract_number, completions))
-        vals = [val for val in vals if val]
-        if not vals:
-            return None
-
-        most_common, count = Counter(vals).most_common()[0]
-        bt.logging.info(f"Most common value: {most_common}, count: {count}")
-        answer = most_common
-        if count == 1:
-            supporting_completions = completions
-        else:
-            supporting_completions = [
-                c for i, c in enumerate(completions) if vals[i] == most_common
-            ]
-
-    bt.logging.info(f"Supporting completions: {supporting_completions}")
-    if prefer == "longest":
-        preferred_completion = sorted(supporting_completions, key=len)[-1]
-    elif prefer == "shortest":
-        preferred_completion = sorted(supporting_completions, key=len)[0]
-    elif prefer == "most_common":
-        preferred_completion = max(
-            set(supporting_completions), key=supporting_completions.count
-        )
-    else:
-        raise ValueError(f"Unknown ensemble preference: {prefer}")
-
-    return {
-        "completion": preferred_completion,
-        "accepted_answer": answer,
-        "support": len(supporting_completions),
-        "support_indices": [completions.index(c) for c in supporting_completions],
-        "method": f'Selected the {prefer.replace("_", " ")} completion',
-    }
-
-
-def guess_task_name(challenge: str):
-    # TODO: use a pre-trained classifier to guess the task name
-    categories = {
-        "summarization": re.compile("summar|quick rundown|overview"),
-        "date_qa": re.compile(
-            "exact date|tell me when|on what date|on what day|was born?|died?"
-        ),
-        "math": re.compile(
-            "math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial"
-        ),
-    }
-    for task_name, patt in categories.items():
-        if patt.search(challenge):
-            return task_name
-
-    return "qa"
-
-
-# Simulate the stream synapse for the echo endpoint
-class EchoAsyncIterator:
-    def __init__(self, message: str, k: int, delay: float):
-        self.message = message
-        self.k = k
-        self.delay = delay
-
-    async def __aiter__(self):
-        for _ in range(self.k):
-            for word in self.message.split():
-                yield [word]
-                await asyncio.sleep(self.delay)
-
-
-async def echo_stream(request: web.Request) -> web.StreamResponse:
-    request_data = request["data"]
-    k = request_data.get("k", 1)
-    message = "\n\n".join(request_data["messages"])
-
-    echo_iterator = EchoAsyncIterator(message, k, delay=0.3)
-    streamer = AsyncResponseDataStreamer(echo_iterator, selected_uid=0, delay=0.3)
-
-    return await streamer.stream(request)
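The heart of the deleted ensemble_result is a majority vote over the answers parsed from each completion, with ties broken by a length preference. A simplified, self-contained sketch of that selection step (illustrative names and data, not the original implementation):

# Simplified sketch of the majority-vote selection used by ensemble_result.
from collections import Counter

def select_completion(completions, answers, prefer="longest"):
    # The most common parsed answer wins; if every answer is unique,
    # all completions count as support (mirroring the count == 1 branch above).
    most_common, count = Counter(answers).most_common(1)[0]
    if count == 1:
        supporting = completions
    else:
        supporting = [c for c, a in zip(completions, answers) if a == most_common]
    pick = {"longest": max, "shortest": min}[prefer]
    return {
        "completion": pick(supporting, key=len),
        "accepted_answer": most_common,
        "support": len(supporting),
    }

print(select_completion(
    ["It was 14 July 1789.", "14 July 1789.", "Probably 1790."],
    ["14 July 1789", "14 July 1789", "1790"],
))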
requirements.txt CHANGED
@@ -1,7 +1,3 @@
-aiohttp
-deprecated
-aiohttp_apispec>=2.2.3
-aiofiles
 streamlit
 plotly
 wandb