PB Unity commited on
Commit
d43e3d1
1 Parent(s): 7c8fd17

Upload RunPhi15.cs

Browse files
Files changed (1) hide show
  1. RunPhi15.cs +276 -0
RunPhi15.cs ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ using System.Collections;
2
+ using System.Collections.Generic;
3
+ using UnityEngine;
4
+ using Unity.Sentis;
5
+ using System.IO;
6
+ using Newtonsoft.Json;
7
+ using System.Text;
8
+
9
+ /*
10
+ * Phi 1.5 Inference Code
11
+ * =======================
12
+ *
13
+ * Put this script on the Main Camera
14
+ *
15
+ * In Assets/StreamingAssets put:
16
+ *
17
+ * phi15.sentis
18
+ * vocab.json
19
+ * merges.txt
20
+ *
21
+ * Install package com.unity.nuget.newtonsoft-json from packagemanger
22
+ * Install package com.unity.sentis
23
+ *
24
+ */
25
+
26
+
27
+ public class RunPhi15 : MonoBehaviour
28
+ {
29
+ const BackendType backend = BackendType.GPUCompute;
30
+ //string outputString = "Question: \"What is the capital of France?\"\n Correct answer: \"";
31
+ //string outputString = "The human asked, \"What is your favourite animal?\" so the wise man answered correctly, \"";
32
+ string outputString = "Once upon a time, there were three";
33
+
34
+ // This is how many tokens you want. It can be adjusted.
35
+ const int maxTokens = 100;
36
+
37
+ //Make this smaller for more randomness
38
+ const float predictability = 5;
39
+
40
+ //Special tokens
41
+ const int END_OF_TEXT = 50256;
42
+
43
+ Ops ops;
44
+ ITensorAllocator allocator;
45
+
46
+ //Store the vocabulary
47
+ string[] tokens;
48
+
49
+ IWorker engine;
50
+
51
+ int currentToken = 0;
52
+ int[] outputTokens = new int[maxTokens];
53
+
54
+ // Used for special character decoding
55
+ int[] whiteSpaceCharacters = new int[256];
56
+ int[] encodedCharacters = new int[256];
57
+
58
+ bool runInference = false;
59
+
60
+
61
+ //stop after this many tokens
62
+ const int stopAfter = 200;
63
+
64
+ int totalTokens = 0;
65
+
66
+ string[] merges;
67
+ Dictionary<string, int> vocab;
68
+
69
+ void Start()
70
+ {
71
+ allocator = new TensorCachingAllocator();
72
+ ops = WorkerFactory.CreateOps(backend, allocator);
73
+
74
+ SetupWhiteSpaceShifts();
75
+
76
+ LoadVocabulary();
77
+
78
+ Model model = ModelLoader.Load(Application.streamingAssetsPath + "/phi15.sentis");
79
+ engine = WorkerFactory.CreateWorker(backend, model);
80
+
81
+ GO(outputString);
82
+ }
83
+
84
+ public void GO(string text)
85
+ {
86
+ outputString = text;
87
+ DecodePrompt(outputString);
88
+ runInference = true;
89
+ }
90
+
91
+ // Update is called once per frame
92
+ void Update()
93
+ {
94
+ if (runInference)
95
+ {
96
+ RunInference();
97
+ }
98
+ }
99
+
100
+ void RunInference()
101
+ {
102
+ using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
103
+
104
+ engine.Execute(tokensSoFar);
105
+
106
+ var tokensOut = engine.PeekOutput() as TensorFloat;
107
+
108
+ using var row = ops.Slice(tokensOut, new[] { currentToken }, new[] { currentToken + 1 }, new[] { 1 }, new[] { 1 });
109
+ using var rowB = ops.Mul(predictability, row);
110
+ using var probs = ops.Softmax(rowB, 2);
111
+ probs.MakeReadable();
112
+
113
+ int ID = SelectRandomToken(probs.ToReadOnlyArray());
114
+
115
+ if (currentToken >= maxTokens - 1)
116
+ {
117
+ for (int i = 0; i < maxTokens - 1; i++) outputTokens[i] = outputTokens[i + 1];
118
+ currentToken--;
119
+ }
120
+
121
+ outputTokens[++currentToken] = ID;
122
+ totalTokens++;
123
+
124
+ if (ID == END_OF_TEXT || totalTokens >= stopAfter)
125
+ {
126
+ runInference = false;
127
+ }
128
+ else outputString += GetUnicodeText(tokens[ID]);
129
+
130
+ Debug.Log(outputString);
131
+ }
132
+
133
+ void DecodePrompt(string text)
134
+ {
135
+ var inputTokens = GetTokens(text);
136
+
137
+ for(int i = 0; i < inputTokens.Count; i++)
138
+ {
139
+ outputTokens[i] = inputTokens[i];
140
+ }
141
+ currentToken = inputTokens.Count - 1;
142
+ }
143
+
144
+
145
+ void LoadVocabulary()
146
+ {
147
+ var jsonText = File.ReadAllText(Application.streamingAssetsPath + "/vocab.json");
148
+ vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
149
+ tokens = new string[vocab.Count];
150
+ foreach (var item in vocab)
151
+ {
152
+ tokens[item.Value] = item.Key;
153
+ }
154
+
155
+ merges = File.ReadAllLines(Application.streamingAssetsPath + "/merges.txt");
156
+ }
157
+
158
+
159
+ int SelectRandomToken(float[] probs)
160
+ {
161
+ float p = UnityEngine.Random.Range(0, 1f);
162
+ float t = 0;
163
+ for (int i = 0; i < probs.Length; i++)
164
+ {
165
+ t += probs[i];
166
+ if (p < t)
167
+ {
168
+ return i;
169
+ }
170
+ }
171
+ return probs.Length - 1;
172
+ }
173
+
174
+ // Translates encoded special characters to Unicode
175
+ string GetUnicodeText(string text)
176
+ {
177
+ var bytes = Encoding.GetEncoding("ISO-8859-1").GetBytes(ShiftCharacterDown(text));
178
+ return Encoding.UTF8.GetString(bytes);
179
+ }
180
+ string GetASCIIText(string newText)
181
+ {
182
+ var bytes = Encoding.UTF8.GetBytes(newText);
183
+ return ShiftCharacterUp(Encoding.GetEncoding("ISO-8859-1").GetString(bytes));
184
+ }
185
+
186
+ string ShiftCharacterDown(string text)
187
+ {
188
+ string outText = "";
189
+ foreach (char letter in text)
190
+ {
191
+ outText += ((int)letter <= 256) ? letter :
192
+ (char)whiteSpaceCharacters[(int)(letter - 256)];
193
+ }
194
+ return outText;
195
+ }
196
+
197
+ string ShiftCharacterUp(string text)
198
+ {
199
+ string outText = "";
200
+ foreach (char letter in text)
201
+ {
202
+ outText += (char)encodedCharacters[(int)letter];
203
+ }
204
+ return outText;
205
+ }
206
+
207
+ void SetupWhiteSpaceShifts()
208
+ {
209
+ for (int i = 0, n = 0; i < 256; i++)
210
+ {
211
+ encodedCharacters[i] = i;
212
+ if (IsWhiteSpace((char)i))
213
+ {
214
+ encodedCharacters[i] = n + 256;
215
+ whiteSpaceCharacters[n++] = i;
216
+ }
217
+ }
218
+ }
219
+
220
+ bool IsWhiteSpace(char c)
221
+ {
222
+ return !(('!' <= c && c <= '~') || ('�' <= c && c <= '�') || ('�' <= c && c <= '�'));
223
+ }
224
+
225
+ List<int> GetTokens(string text)
226
+ {
227
+ text = GetASCIIText(text);
228
+
229
+ // Start with a list of single characters
230
+ var inputTokens = new List<string>();
231
+ foreach(var letter in text)
232
+ {
233
+ inputTokens.Add(letter.ToString());
234
+ }
235
+
236
+ ApplyMerges(inputTokens);
237
+
238
+ //Find the ids of the words in the vocab
239
+ var ids = new List<int>();
240
+ foreach(var token in inputTokens)
241
+ {
242
+ if (vocab.TryGetValue(token, out int id))
243
+ {
244
+ ids.Add(id);
245
+ }
246
+ }
247
+
248
+ return ids;
249
+ }
250
+
251
+ void ApplyMerges(List<string> inputTokens)
252
+ {
253
+ foreach(var merge in merges)
254
+ {
255
+ string[] pair = merge.Split(' ');
256
+ int n = 0;
257
+ while (n >= 0)
258
+ {
259
+ n = inputTokens.IndexOf(pair[0], n);
260
+ if (n != -1 && n < inputTokens.Count - 1 && inputTokens[n + 1] == pair[1])
261
+ {
262
+ inputTokens[n] += inputTokens[n + 1];
263
+ inputTokens.RemoveAt(n + 1);
264
+ }
265
+ if (n != -1) n++;
266
+ }
267
+ }
268
+ }
269
+
270
+ private void OnDestroy()
271
+ {
272
+ engine?.Dispose();
273
+ ops?.Dispose();
274
+ allocator?.Dispose();
275
+ }
276
+ }