romanbredehoft-zama commited on
Commit
68c9ed6
·
1 Parent(s): fee1bf4

Update the requirements, fix the notebook and improve the readme

Browse files
.gitignore CHANGED
@@ -1,6 +1,6 @@
1
- tmp_encrypted_prediction.npy
2
- tmp_encrypted_quantized_encoding.npy
3
- tmp_evaluation_key.npy
4
  .venv
5
  .fhe_keys
6
  *.pyc
 
 
 
1
+ tmp/
 
 
2
  .venv
3
  .fhe_keys
4
  *.pyc
5
+ local_datasets/
6
+ .vscode/
README.md CHANGED
@@ -13,11 +13,7 @@ python_version: 3.9
13
 
14
  # Sentiment Analysis With FHE
15
 
16
- ## Running the application on your machine
17
-
18
- In this directory, ie `sentiment-analysis-with-transformer`, you can do the following steps.
19
-
20
- ### Do once
21
 
22
  - First, create a virtual env and activate it:
23
 
@@ -34,43 +30,36 @@ pip3 install -U pip wheel setuptools --ignore-installed
34
  pip3 install -r requirements.txt --ignore-installed
35
  ```
36
 
37
- - If not on Linux, or if you want to compile the FHE algorithms by yourself:
38
 
39
  ```bash
40
  python3 compile.py
41
  ```
42
 
43
- Check it finish well (with a "Done!").
44
-
45
- ### Do each time you relaunch the application
46
-
47
- - Then, in a terminal Tab 1:
48
-
49
- ```bash
50
- source .venv/bin/activate
51
- uvicorn server:app
52
- ```
53
 
54
- Tab 1 will be for the Server side.
55
 
56
- - And, in another terminal Tab 2:
57
 
58
  ```bash
59
  source .venv/bin/activate
60
  python3 app.py
61
  ```
62
 
63
- Tab 2 will be for the Client side.
64
-
65
- ## Interacting with the application
66
-
67
- Open the given URL link (search for a line like `Running on local URL: http://127.0.0.1:8888/` in your Terminal 2).
68
 
69
- ## Training a new model
 
70
 
71
- The notebook SentimentClassification.ipynb provides a way to train a new model.
72
 
73
- Before running the notebook, you need to download the data.
 
 
 
 
74
 
75
  ```bash
76
  bash download_data.sh
 
13
 
14
  # Sentiment Analysis With FHE
15
 
16
+ ## Set up the app locally
 
 
 
 
17
 
18
  - First, create a virtual env and activate it:
19
 
 
30
  pip3 install -r requirements.txt --ignore-installed
31
  ```
32
 
33
+ - (optional) Compile the FHE algorithm:
34
 
35
  ```bash
36
  python3 compile.py
37
  ```
38
 
39
+ Check it finish well (with a "Done!"). Please note that the actual model initialization and training
40
+ can be found in the [SentimentClassification notebook](SentimentClassification.ipynb) (see below).
 
 
 
 
 
 
 
 
41
 
42
+ ### Launch the app locally
43
 
44
+ - In a terminal:
45
 
46
  ```bash
47
  source .venv/bin/activate
48
  python3 app.py
49
  ```
50
 
51
+ ## Interact with the application
 
 
 
 
52
 
53
+ Open the given URL link (search for a line like `Running on local URL: http://127.0.0.1:8888/` in the
54
+ terminal).
55
 
56
+ ## Train a new model
57
 
58
+ The notebook [SentimentClassification notebook](SentimentClassification.ipynb) provides a way to
59
+ train a new model. Be aware that the data needs to be downloaded beforehand using the
60
+ [download_data.sh](download_data.sh) file (which requires Kaggle CLI).
61
+ Alternatively, the dataset can be downloaded manually at
62
+ https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment
63
 
64
  ```bash
65
  bash download_data.sh
SentimentClassification.ipynb CHANGED
@@ -21,16 +21,16 @@
21
  },
22
  {
23
  "cell_type": "code",
24
- "execution_count": 3,
25
  "metadata": {},
26
  "outputs": [],
27
  "source": [
28
  "# Import the required packages\n",
29
  "import os\n",
30
  "import time\n",
 
31
  "\n",
32
  "import numpy\n",
33
- "import onnx\n",
34
  "import pandas as pd\n",
35
  "from sklearn.metrics import average_precision_score\n",
36
  "from sklearn.model_selection import GridSearchCV, train_test_split\n",
@@ -40,7 +40,7 @@
40
  },
41
  {
42
  "cell_type": "code",
43
- "execution_count": 4,
44
  "metadata": {},
45
  "outputs": [
46
  {
@@ -76,7 +76,7 @@
76
  },
77
  {
78
  "cell_type": "code",
79
- "execution_count": 5,
80
  "metadata": {},
81
  "outputs": [],
82
  "source": [
@@ -105,7 +105,7 @@
105
  },
106
  {
107
  "cell_type": "code",
108
- "execution_count": 19,
109
  "metadata": {},
110
  "outputs": [],
111
  "source": [
@@ -123,7 +123,7 @@
123
  },
124
  {
125
  "cell_type": "code",
126
- "execution_count": 20,
127
  "metadata": {},
128
  "outputs": [],
129
  "source": [
@@ -135,55 +135,55 @@
135
  " \"n_bits\": [2, 3],\n",
136
  " \"max_depth\": [1],\n",
137
  " \"n_estimators\": [10, 30, 50],\n",
138
- " \"n_jobs\": [-1],\n",
139
  "}"
140
  ]
141
  },
142
  {
143
  "cell_type": "code",
144
- "execution_count": 21,
145
  "metadata": {},
146
  "outputs": [
147
  {
148
  "data": {
149
  "text/html": [
150
- "<style>#sk-container-id-3 {color: black;background-color: white;}#sk-container-id-3 pre{padding: 0;}#sk-container-id-3 div.sk-toggleable {background-color: white;}#sk-container-id-3 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-3 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-3 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-3 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-3 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-3 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-3 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-3 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-3 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-3 div.sk-item {position: relative;z-index: 1;}#sk-container-id-3 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-3 div.sk-item::before, #sk-container-id-3 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-3 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-3 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-3 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-3 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-3 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-3 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-3 div.sk-label-container {text-align: center;}#sk-container-id-3 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-3 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(), n_jobs=1,\n",
151
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
152
- " &#x27;n_estimators&#x27;: [10, 30, 50], &#x27;n_jobs&#x27;: [-1]},\n",
153
- " scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" ><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(), n_jobs=1,\n",
154
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
155
- " &#x27;n_estimators&#x27;: [10, 30, 50], &#x27;n_jobs&#x27;: [-1]},\n",
156
- " scoring=&#x27;accuracy&#x27;)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" ><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier()</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-9\" type=\"checkbox\" ><label for=\"sk-estimator-id-9\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier()</pre></div></div></div></div></div></div></div></div></div></div>"
157
  ],
158
  "text/plain": [
159
- "GridSearchCV(cv=3, estimator=XGBClassifier(), n_jobs=1,\n",
160
  " param_grid={'max_depth': [1], 'n_bits': [2, 3],\n",
161
- " 'n_estimators': [10, 30, 50], 'n_jobs': [-1]},\n",
162
  " scoring='accuracy')"
163
  ]
164
  },
165
- "execution_count": 21,
166
  "metadata": {},
167
  "output_type": "execute_result"
168
  }
169
  ],
170
  "source": [
171
  "# Run the gridsearch\n",
172
- "grid_search = GridSearchCV(model, parameters, cv=3, n_jobs=1, scoring=\"accuracy\")\n",
173
  "grid_search.fit(X_train, y_train)"
174
  ]
175
  },
176
  {
177
  "cell_type": "code",
178
- "execution_count": 22,
179
  "metadata": {},
180
  "outputs": [
181
  {
182
  "name": "stdout",
183
  "output_type": "stream",
184
  "text": [
185
- "Best score: 0.6842744383727991\n",
186
- "Best parameters: {'max_depth': 1, 'n_bits': 3, 'n_estimators': 50, 'n_jobs': -1}\n"
187
  ]
188
  }
189
  ],
@@ -200,17 +200,17 @@
200
  },
201
  {
202
  "cell_type": "code",
203
- "execution_count": 24,
204
  "metadata": {},
205
  "outputs": [
206
  {
207
  "name": "stdout",
208
  "output_type": "stream",
209
  "text": [
210
- "Accuracy: 0.6810\n",
211
- "Average precision score for positive class: 0.5615\n",
212
- "Average precision score for negative class: 0.8349\n",
213
- "Average precision score for neutral class: 0.3820\n"
214
  ]
215
  }
216
  ],
@@ -238,7 +238,36 @@
238
  },
239
  {
240
  "cell_type": "code",
241
- "execution_count": 48,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  "metadata": {},
243
  "outputs": [
244
  {
@@ -246,18 +275,18 @@
246
  "output_type": "stream",
247
  "text": [
248
  "5 most positive tweets (class 2):\n",
249
- "@united sent a DM just now. Thanks I am incredibly happy the fast response I got via Twitter than via customer care. Thank you\n",
250
- "@JetBlue Great Thank you, lets hope so! Could you please notify me if flight 2302 leaves JFK? Thank you again\n",
251
- "@AmericanAir Great, thanks. Followed.\n",
252
- "@SouthwestAir I continue to be amazed by the amazing customer service. Thank you SWA!\n",
253
- "@JetBlue Awesome thanks! Thanks for the quick response. You guys ROCK! :)\n",
254
  "----------------------------------------------------------------------------------------------------\n",
255
  "5 most negative tweets (class 0):\n",
256
- "@USAirways been on hold 2 hours for a Cancelled Flighted flight. I understand the delay. I don't understand you auto-reFlight Booking Problems me on TUESDAY. HELP!\n",
257
- "@SouthwestAir 2 hours on hold for customer service never us SW again\n",
258
- "@SouthwestAir placed on hold for total of two hours today after flight was Cancelled Flightled. Online option not available. What to do?\n",
259
- "@southwestair I've been on hold for 2 hours to reschedule my Cancelled Flightled flight for the morning. What gives? I need help NOW\n",
260
- "@USAirways Customer service is dead. Last wk, flts delayed/Cancelled Flighted. Bags lost 4 days. Last nt, flt delayed/Cancelled Flighted. No meal voucher?\n"
261
  ]
262
  }
263
  ],
@@ -265,26 +294,26 @@
265
  "# Let's see what are the top predictions based on the probabilities in y_pred_test\n",
266
  "print(\"5 most positive tweets (class 2):\")\n",
267
  "for i in range(5):\n",
268
- " print(text_X_test.iloc[y_proba_test_tfidf[:, 2].argsort()[-1 - i]])\n",
269
  "\n",
270
  "print(\"-\" * 100)\n",
271
  "\n",
272
  "print(\"5 most negative tweets (class 0):\")\n",
273
  "for i in range(5):\n",
274
- " print(text_X_test.iloc[y_proba_test_tfidf[:, 0].argsort()[-1 - i]])"
275
  ]
276
  },
277
  {
278
  "cell_type": "code",
279
- "execution_count": 56,
280
  "metadata": {},
281
  "outputs": [
282
  {
283
  "name": "stdout",
284
  "output_type": "stream",
285
  "text": [
286
- "Compilation time: 11.5009 seconds\n",
287
- "FHE inference time: 48.6880 seconds\n"
288
  ]
289
  }
290
  ],
@@ -303,22 +332,22 @@
303
  "\n",
304
  "# Now let's predict with FHE over a single tweet and print the time it takes\n",
305
  "start = time.perf_counter()\n",
306
- "decrypted_proba = best_model.predict_proba(X_tested_tweet, execute_in_fhe=True)\n",
307
  "end = time.perf_counter()\n",
308
  "print(f\"FHE inference time: {end - start:.4f} seconds\")"
309
  ]
310
  },
311
  {
312
  "cell_type": "code",
313
- "execution_count": 57,
314
  "metadata": {},
315
  "outputs": [
316
  {
317
  "name": "stdout",
318
  "output_type": "stream",
319
  "text": [
320
- "Probabilities from the FHE inference: [[0.50224707 0.25647676 0.24127617]]\n",
321
- "Probabilities from the clear model: [[0.50224707 0.25647676 0.24127617]]\n"
322
  ]
323
  }
324
  ],
@@ -354,7 +383,7 @@
354
  },
355
  {
356
  "cell_type": "code",
357
- "execution_count": 8,
358
  "metadata": {},
359
  "outputs": [
360
  {
@@ -385,14 +414,19 @@
385
  },
386
  {
387
  "cell_type": "code",
388
- "execution_count": 11,
389
  "metadata": {},
390
  "outputs": [
391
  {
392
  "name": "stderr",
393
  "output_type": "stream",
394
  "text": [
395
- "100%|██████████| 30/30 [00:33<00:00, 1.10s/it]\n"
 
 
 
 
 
396
  ]
397
  }
398
  ],
@@ -421,7 +455,7 @@
421
  },
422
  {
423
  "cell_type": "code",
424
- "execution_count": 12,
425
  "metadata": {},
426
  "outputs": [
427
  {
@@ -429,9 +463,9 @@
429
  "output_type": "stream",
430
  "text": [
431
  "Predictions for the first 3 tweets:\n",
432
- " [[-2.3807464 -0.61802083 2.9900746 ]\n",
433
- " [ 2.0166504 0.4938078 -2.8006463 ]\n",
434
- " [ 2.3892698 0.1344364 -2.6873822 ]]\n"
435
  ]
436
  }
437
  ],
@@ -442,7 +476,7 @@
442
  },
443
  {
444
  "cell_type": "code",
445
- "execution_count": 13,
446
  "metadata": {},
447
  "outputs": [
448
  {
@@ -488,15 +522,15 @@
488
  },
489
  {
490
  "cell_type": "code",
491
- "execution_count": 14,
492
  "metadata": {},
493
  "outputs": [
494
  {
495
  "name": "stderr",
496
  "output_type": "stream",
497
  "text": [
498
- "100%|██████████| 13176/13176 [07:20<00:00, 29.91it/s]\n",
499
- "100%|██████████| 1464/1464 [00:47<00:00, 30.75it/s]\n"
500
  ]
501
  }
502
  ],
@@ -542,28 +576,28 @@
542
  },
543
  {
544
  "cell_type": "code",
545
- "execution_count": 15,
546
  "metadata": {},
547
  "outputs": [
548
  {
549
  "data": {
550
  "text/html": [
551
- "<style>#sk-container-id-2 {color: black;background-color: white;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(), n_jobs=1,\n",
552
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
553
- " &#x27;n_estimators&#x27;: [10, 30, 50], &#x27;n_jobs&#x27;: [-1]},\n",
554
- " scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(), n_jobs=1,\n",
555
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
556
- " &#x27;n_estimators&#x27;: [10, 30, 50], &#x27;n_jobs&#x27;: [-1]},\n",
557
- " scoring=&#x27;accuracy&#x27;)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier()</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier()</pre></div></div></div></div></div></div></div></div></div></div>"
558
  ],
559
  "text/plain": [
560
- "GridSearchCV(cv=3, estimator=XGBClassifier(), n_jobs=1,\n",
561
  " param_grid={'max_depth': [1], 'n_bits': [2, 3],\n",
562
- " 'n_estimators': [10, 30, 50], 'n_jobs': [-1]},\n",
563
  " scoring='accuracy')"
564
  ]
565
  },
566
- "execution_count": 15,
567
  "metadata": {},
568
  "output_type": "execute_result"
569
  }
@@ -576,15 +610,15 @@
576
  },
577
  {
578
  "cell_type": "code",
579
- "execution_count": 16,
580
  "metadata": {},
581
  "outputs": [
582
  {
583
  "name": "stdout",
584
  "output_type": "stream",
585
  "text": [
586
- "Best score: 0.8378111718275654\n",
587
- "Best parameters: {'max_depth': 1, 'n_bits': 3, 'n_estimators': 50, 'n_jobs': -1}\n"
588
  ]
589
  }
590
  ],
@@ -601,17 +635,17 @@
601
  },
602
  {
603
  "cell_type": "code",
604
- "execution_count": 17,
605
  "metadata": {},
606
  "outputs": [
607
  {
608
  "name": "stdout",
609
  "output_type": "stream",
610
  "text": [
611
- "Accuracy: 0.8504\n",
612
- "Average precision score for positive class: 0.8917\n",
613
- "Average precision score for negative class: 0.9597\n",
614
- "Average precision score for neutral class: 0.7341\n"
615
  ]
616
  }
617
  ],
@@ -648,7 +682,7 @@
648
  },
649
  {
650
  "cell_type": "code",
651
- "execution_count": 19,
652
  "metadata": {},
653
  "outputs": [
654
  {
@@ -656,18 +690,18 @@
656
  "output_type": "stream",
657
  "text": [
658
  "5 most positive tweets (class 2):\n",
 
 
 
659
  "@SouthwestAir love them! Always get the best deals!\n",
660
- "@AmericanAir THANK YOU FOR ALL THE HELP! :P You guys are the best. #americanairlines #americanair\n",
661
- "@SouthwestAir - Great flight from Phoenix to Dallas tonight!Great service and ON TIME! Makes @timieyancey very happy! http://t.co/TkVCMhbPim\n",
662
- "@AmericanAir AA2416 on time and awesome flight. Great job American!\n",
663
- "@SouthwestAir AMAZING c/s today by SW thank you SO very much. This is the reason we fly you #southwest\n",
664
  "----------------------------------------------------------------------------------------------------\n",
665
  "5 most negative tweets (class 0):\n",
666
- "@AmericanAir This entire process took sooooo long that no decent seats are left. #customerservice\n",
667
  "@USAirways Not only did u lose the flight plan! Now ur flight crew is FAA timed out! Thx for havin us sit on the tarmac for an hr! #Pathetic\n",
668
- "@United site errored out at last step of changing award. Now can't even pull up reservation. 60 minute wait time. Thanks @United!\n",
669
- "@united OKC ticket agent Roger McLarren(sp?) LESS than helpful with our Intl group travel problems Can't find a supervisor for help.\n",
670
- "@AmericanAir the dinner and called me \"hon\". Not the service I would expect from 1st class. #disappointed\n"
671
  ]
672
  }
673
  ],
@@ -689,7 +723,7 @@
689
  },
690
  {
691
  "cell_type": "code",
692
- "execution_count": 20,
693
  "metadata": {},
694
  "outputs": [
695
  {
@@ -697,16 +731,16 @@
697
  "output_type": "stream",
698
  "text": [
699
  "5 most positive (predicted) tweets that are actually negative (ground truth class 0):\n",
700
- "@USAirways as far as being delayed goes… Looks like tailwinds are going to make up for it. Good news!\n",
701
  "@united thanks for the link, now finally arrived in Brussels, 9 h after schedule...\n",
 
 
702
  "@USAirways your saving grace was our flight attendant Dallas who was amazing. wish he would transfer to Delta where I would see him again\n",
703
  "@AmericanAir that luggage you forgot...#mia.....he just won an oscar😄💝💝💝\n",
704
- "@united thanks for having changed me. Managed to arrive with only 8 hours of delay and exhausted\n",
705
  "----------------------------------------------------------------------------------------------------\n",
706
  "5 most negative (predicted) tweets that are actually positive (ground truth class 2):\n",
707
  "@united thanks for updating me about the 1+ hour delay the exact second I got to ATL. 🙅🙅🙅\n",
708
- "@JetBlue you don't remember our date Monday night back to NYC? #heartbroken\n",
709
  "@SouthwestAir save mile to visit family in 2015 and this will impact how many times I can see my mother. I planned and you change the rules\n",
 
710
  "@SouthwestAir hot stewardess flipped me off\n",
711
  "@SouthwestAir - We left iPad in a seat pocket. Filed lost item report. Received it exactly 1 week Late Flightr. Is that a record? #unbelievable\n"
712
  ]
@@ -750,28 +784,35 @@
750
  },
751
  {
752
  "cell_type": "code",
753
- "execution_count": 26,
754
  "metadata": {},
755
  "outputs": [
756
  {
757
  "name": "stdout",
758
  "output_type": "stream",
759
  "text": [
760
- "Compilation time: 12.6855 seconds\n"
761
  ]
762
  },
763
  {
764
  "name": "stderr",
765
  "output_type": "stream",
766
  "text": [
767
- "100%|██████████| 1/1 [00:00<00:00, 36.43it/s]\n"
768
  ]
769
  },
770
  {
771
  "name": "stdout",
772
  "output_type": "stream",
773
  "text": [
774
- "FHE inference time: 53.0192 seconds\n"
 
 
 
 
 
 
 
775
  ]
776
  }
777
  ],
@@ -791,7 +832,7 @@
791
  "\n",
792
  "# Now let's predict with FHE over a single tweet and print the time it takes\n",
793
  "start = time.perf_counter()\n",
794
- "decrypted_proba = best_model.predict_proba(X_tested_tweet, execute_in_fhe=True)\n",
795
  "end = time.perf_counter()\n",
796
  "fhe_exec_time = end - start\n",
797
  "print(f\"FHE inference time: {fhe_exec_time:.4f} seconds\")"
@@ -799,15 +840,15 @@
799
  },
800
  {
801
  "cell_type": "code",
802
- "execution_count": 40,
803
  "metadata": {},
804
  "outputs": [
805
  {
806
  "name": "stdout",
807
  "output_type": "stream",
808
  "text": [
809
- "Probabilities from the FHE inference: [[0.08434131 0.05571389 0.8599448 ]]\n",
810
- "Probabilities from the clear model: [[0.08434131 0.05571389 0.8599448 ]]\n"
811
  ]
812
  }
813
  ],
@@ -818,34 +859,38 @@
818
  },
819
  {
820
  "cell_type": "code",
821
- "execution_count": null,
822
  "metadata": {},
823
  "outputs": [],
824
  "source": [
 
 
 
825
  "# Let's export the final model such that we can reuse it in a client/server environment\n",
826
  "\n",
827
- "# Export the model to ONNX\n",
828
- "onnx.save(best_model._onnx_model_, \"server_model.onnx\") # pylint: disable=protected-access\n",
 
829
  "\n",
830
- "# Export some data to be used for compilation\n",
831
  "X_train_numpy = X_train_transformer[:100]\n",
832
  "\n",
833
  "# Merge the two arrays in a pandas dataframe\n",
834
  "X_test_numpy_df = pd.DataFrame(X_train_numpy)\n",
835
  "\n",
836
  "# to csv\n",
837
- "X_test_numpy_df.to_csv(\"samples_for_compilation.csv\")\n",
838
  "\n",
839
  "# Let's save the model to be pushed to a server later\n",
840
  "from concrete.ml.deployment import FHEModelDev\n",
841
  "\n",
842
- "fhe_api = FHEModelDev(\"sentiment_fhe_model\", best_model)\n",
843
- "fhe_api.save()"
844
  ]
845
  },
846
  {
847
  "cell_type": "code",
848
- "execution_count": 26,
849
  "metadata": {},
850
  "outputs": [
851
  {
@@ -885,24 +930,24 @@
885
  " <tbody>\n",
886
  " <tr>\n",
887
  " <th>TF-IDF + XGBoost</th>\n",
888
- " <td>0.681011</td>\n",
889
- " <td>0.561521</td>\n",
890
- " <td>0.834914</td>\n",
891
- " <td>0.382002</td>\n",
892
  " </tr>\n",
893
  " <tr>\n",
894
  " <th>Transformer Only</th>\n",
895
  " <td>0.805328</td>\n",
896
  " <td>0.854827</td>\n",
897
  " <td>0.954804</td>\n",
898
- " <td>0.680110</td>\n",
899
  " </tr>\n",
900
  " <tr>\n",
901
  " <th>Transformer + XGBoost</th>\n",
902
- " <td>0.850410</td>\n",
903
- " <td>0.891691</td>\n",
904
- " <td>0.959747</td>\n",
905
- " <td>0.734144</td>\n",
906
  " </tr>\n",
907
  " </tbody>\n",
908
  "</table>\n",
@@ -911,24 +956,24 @@
911
  "text/plain": [
912
  " Accuracy Average Precision (positive) \\\n",
913
  "Model \n",
914
- "TF-IDF + XGBoost 0.681011 0.561521 \n",
915
  "Transformer Only 0.805328 0.854827 \n",
916
- "Transformer + XGBoost 0.850410 0.891691 \n",
917
  "\n",
918
  " Average Precision (negative) \\\n",
919
  "Model \n",
920
- "TF-IDF + XGBoost 0.834914 \n",
921
  "Transformer Only 0.954804 \n",
922
- "Transformer + XGBoost 0.959747 \n",
923
  "\n",
924
  " Average Precision (neutral) \n",
925
  "Model \n",
926
- "TF-IDF + XGBoost 0.382002 \n",
927
- "Transformer Only 0.680110 \n",
928
- "Transformer + XGBoost 0.734144 "
929
  ]
930
  },
931
- "execution_count": 26,
932
  "metadata": {},
933
  "output_type": "execute_result"
934
  }
@@ -991,7 +1036,15 @@
991
  "name": "python3"
992
  },
993
  "language_info": {
 
 
 
 
 
 
994
  "name": "python",
 
 
995
  "version": "3.10.11"
996
  }
997
  },
 
21
  },
22
  {
23
  "cell_type": "code",
24
+ "execution_count": 31,
25
  "metadata": {},
26
  "outputs": [],
27
  "source": [
28
  "# Import the required packages\n",
29
  "import os\n",
30
  "import time\n",
31
+ "from pathlib import Path\n",
32
  "\n",
33
  "import numpy\n",
 
34
  "import pandas as pd\n",
35
  "from sklearn.metrics import average_precision_score\n",
36
  "from sklearn.model_selection import GridSearchCV, train_test_split\n",
 
40
  },
41
  {
42
  "cell_type": "code",
43
+ "execution_count": 2,
44
  "metadata": {},
45
  "outputs": [
46
  {
 
76
  },
77
  {
78
  "cell_type": "code",
79
+ "execution_count": 3,
80
  "metadata": {},
81
  "outputs": [],
82
  "source": [
 
105
  },
106
  {
107
  "cell_type": "code",
108
+ "execution_count": 4,
109
  "metadata": {},
110
  "outputs": [],
111
  "source": [
 
123
  },
124
  {
125
  "cell_type": "code",
126
+ "execution_count": 5,
127
  "metadata": {},
128
  "outputs": [],
129
  "source": [
 
135
  " \"n_bits\": [2, 3],\n",
136
  " \"max_depth\": [1],\n",
137
  " \"n_estimators\": [10, 30, 50],\n",
138
+ " # \"n_jobs\": [-1],\n",
139
  "}"
140
  ]
141
  },
142
  {
143
  "cell_type": "code",
144
+ "execution_count": 6,
145
  "metadata": {},
146
  "outputs": [
147
  {
148
  "data": {
149
  "text/html": [
150
+ "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1),\n",
151
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
152
+ " &#x27;n_estimators&#x27;: [10, 30, 50]},\n",
153
+ " scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1),\n",
154
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
155
+ " &#x27;n_estimators&#x27;: [10, 30, 50]},\n",
156
+ " scoring=&#x27;accuracy&#x27;)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(n_jobs=1)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(n_jobs=1)</pre></div></div></div></div></div></div></div></div></div></div>"
157
  ],
158
  "text/plain": [
159
+ "GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1),\n",
160
  " param_grid={'max_depth': [1], 'n_bits': [2, 3],\n",
161
+ " 'n_estimators': [10, 30, 50]},\n",
162
  " scoring='accuracy')"
163
  ]
164
  },
165
+ "execution_count": 6,
166
  "metadata": {},
167
  "output_type": "execute_result"
168
  }
169
  ],
170
  "source": [
171
  "# Run the gridsearch\n",
172
+ "grid_search = GridSearchCV(model, parameters, cv=3, scoring=\"accuracy\")\n",
173
  "grid_search.fit(X_train, y_train)"
174
  ]
175
  },
176
  {
177
  "cell_type": "code",
178
+ "execution_count": 7,
179
  "metadata": {},
180
  "outputs": [
181
  {
182
  "name": "stdout",
183
  "output_type": "stream",
184
  "text": [
185
+ "Best score: 0.705980570734669\n",
186
+ "Best parameters: {'max_depth': 1, 'n_bits': 3, 'n_estimators': 50}\n"
187
  ]
188
  }
189
  ],
 
200
  },
201
  {
202
  "cell_type": "code",
203
+ "execution_count": 8,
204
  "metadata": {},
205
  "outputs": [
206
  {
207
  "name": "stdout",
208
  "output_type": "stream",
209
  "text": [
210
+ "Accuracy: 0.7117\n",
211
+ "Average precision score for positive class: 0.6404\n",
212
+ "Average precision score for negative class: 0.8719\n",
213
+ "Average precision score for neutral class: 0.4349\n"
214
  ]
215
  }
216
  ],
 
238
  },
239
  {
240
  "cell_type": "code",
241
+ "execution_count": 9,
242
+ "metadata": {},
243
+ "outputs": [
244
+ {
245
+ "data": {
246
+ "text/plain": [
247
+ "array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
248
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
249
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
250
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
251
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
252
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
253
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
254
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
255
+ " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
256
+ " 2, 2, 2, 2, 2, 2])"
257
+ ]
258
+ },
259
+ "execution_count": 9,
260
+ "metadata": {},
261
+ "output_type": "execute_result"
262
+ }
263
+ ],
264
+ "source": [
265
+ "y_pred_test_tfidf[y_pred_test_tfidf == 2]"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": 10,
271
  "metadata": {},
272
  "outputs": [
273
  {
 
275
  "output_type": "stream",
276
  "text": [
277
  "5 most positive tweets (class 2):\n",
278
+ "@JetBlue do bags still fly free or have you started charging? thanks!\n",
279
+ "@SouthwestAir Is there a way to receive a refund on a trip that was Cancelled Flight online instead of calling? Your phone lines are super busy.\n",
280
+ "@JetBlue bag is supposedly here in Boston\n",
281
+ "@AmericanAir Cancelled Flights my flight, doesn't send an email, text or call. Now I'm stranded in Louisville.\n",
282
+ "@SouthwestAir I need to Cancelled Flight one leg of a flight, but can't seem to do this online. Been on hold on the phone for 10 minutes. Any help?\n",
283
  "----------------------------------------------------------------------------------------------------\n",
284
  "5 most negative tweets (class 0):\n",
285
+ "@AmericanAir - keeping AA up in the Air! My crew chief cousin Alex Espinosa in DFW! http://t.co/0HXLNvZknP\n",
286
+ "@JetBlue Called JB 3 times!Everytime, Auto Vmsg:\"your wait time should not be longer than 9 mins\" waited longer than 18 mins and no answer!\n",
287
+ "@SouthwestAir can you outline the policies for both scenarios?\n",
288
+ "@united is not a company that values it's customer &amp; after reading tweets to them I'm not the only one who feels that way #lostmybusiness\n",
289
+ "@JetBlue how about free wifi on flt 1254 out of PBI to make up for 2.5 hr delay? Treat us right.\n"
290
  ]
291
  }
292
  ],
 
294
  "# Let's see what are the top predictions based on the probabilities in y_pred_test\n",
295
  "print(\"5 most positive tweets (class 2):\")\n",
296
  "for i in range(5):\n",
297
+ " print(text_X_test.iloc[y_pred_test_tfidf[y_pred_test_tfidf==2].argsort()[-1 - i]])\n",
298
  "\n",
299
  "print(\"-\" * 100)\n",
300
  "\n",
301
  "print(\"5 most negative tweets (class 0):\")\n",
302
  "for i in range(5):\n",
303
+ " print(text_X_test.iloc[y_pred_test_tfidf[y_pred_test_tfidf==0].argsort()[-1 - i]])"
304
  ]
305
  },
306
  {
307
  "cell_type": "code",
308
+ "execution_count": 11,
309
  "metadata": {},
310
  "outputs": [
311
  {
312
  "name": "stdout",
313
  "output_type": "stream",
314
  "text": [
315
+ "Compilation time: 5.4779 seconds\n",
316
+ "FHE inference time: 1.1039 seconds\n"
317
  ]
318
  }
319
  ],
 
332
  "\n",
333
  "# Now let's predict with FHE over a single tweet and print the time it takes\n",
334
  "start = time.perf_counter()\n",
335
+ "decrypted_proba = best_model.predict_proba(X_tested_tweet, fhe=\"execute\")\n",
336
  "end = time.perf_counter()\n",
337
  "print(f\"FHE inference time: {end - start:.4f} seconds\")"
338
  ]
339
  },
340
  {
341
  "cell_type": "code",
342
+ "execution_count": 12,
343
  "metadata": {},
344
  "outputs": [
345
  {
346
  "name": "stdout",
347
  "output_type": "stream",
348
  "text": [
349
+ "Probabilities from the FHE inference: [[0.30244059 0.17506451 0.5224949 ]]\n",
350
+ "Probabilities from the clear model: [[0.30244059 0.17506451 0.5224949 ]]\n"
351
  ]
352
  }
353
  ],
 
383
  },
384
  {
385
  "cell_type": "code",
386
+ "execution_count": 13,
387
  "metadata": {},
388
  "outputs": [
389
  {
 
414
  },
415
  {
416
  "cell_type": "code",
417
+ "execution_count": 14,
418
  "metadata": {},
419
  "outputs": [
420
  {
421
  "name": "stderr",
422
  "output_type": "stream",
423
  "text": [
424
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
425
+ "To disable this warning, you can either:\n",
426
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
427
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
428
+ " 0%| | 0/30 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.\n",
429
+ "100%|██████████| 30/30 [00:20<00:00, 1.46it/s]\n"
430
  ]
431
  }
432
  ],
 
455
  },
456
  {
457
  "cell_type": "code",
458
+ "execution_count": 15,
459
  "metadata": {},
460
  "outputs": [
461
  {
 
463
  "output_type": "stream",
464
  "text": [
465
  "Predictions for the first 3 tweets:\n",
466
+ " [[-2.3807454 -0.61802197 2.9900734 ]\n",
467
+ " [ 2.0166504 0.49380752 -2.8006463 ]\n",
468
+ " [ 2.3892734 0.13443531 -2.6873832 ]]\n"
469
  ]
470
  }
471
  ],
 
476
  },
477
  {
478
  "cell_type": "code",
479
+ "execution_count": 16,
480
  "metadata": {},
481
  "outputs": [
482
  {
 
522
  },
523
  {
524
  "cell_type": "code",
525
+ "execution_count": 17,
526
  "metadata": {},
527
  "outputs": [
528
  {
529
  "name": "stderr",
530
  "output_type": "stream",
531
  "text": [
532
+ "100%|██████████| 13176/13176 [08:10<00:00, 26.88it/s]\n",
533
+ "100%|██████████| 1464/1464 [00:54<00:00, 26.90it/s]\n"
534
  ]
535
  }
536
  ],
 
576
  },
577
  {
578
  "cell_type": "code",
579
+ "execution_count": 18,
580
  "metadata": {},
581
  "outputs": [
582
  {
583
  "data": {
584
  "text/html": [
585
+ "<style>#sk-container-id-2 {color: black;background-color: white;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1), n_jobs=1,\n",
586
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
587
+ " &#x27;n_estimators&#x27;: [10, 30, 50]},\n",
588
+ " scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1), n_jobs=1,\n",
589
  " param_grid={&#x27;max_depth&#x27;: [1], &#x27;n_bits&#x27;: [2, 3],\n",
590
+ " &#x27;n_estimators&#x27;: [10, 30, 50]},\n",
591
+ " scoring=&#x27;accuracy&#x27;)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(n_jobs=1)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(n_jobs=1)</pre></div></div></div></div></div></div></div></div></div></div>"
592
  ],
593
  "text/plain": [
594
+ "GridSearchCV(cv=3, estimator=XGBClassifier(n_jobs=1), n_jobs=1,\n",
595
  " param_grid={'max_depth': [1], 'n_bits': [2, 3],\n",
596
+ " 'n_estimators': [10, 30, 50]},\n",
597
  " scoring='accuracy')"
598
  ]
599
  },
600
+ "execution_count": 18,
601
  "metadata": {},
602
  "output_type": "execute_result"
603
  }
 
610
  },
611
  {
612
  "cell_type": "code",
613
+ "execution_count": 19,
614
  "metadata": {},
615
  "outputs": [
616
  {
617
  "name": "stdout",
618
  "output_type": "stream",
619
  "text": [
620
+ "Best score: 0.8381147540983607\n",
621
+ "Best parameters: {'max_depth': 1, 'n_bits': 3, 'n_estimators': 50}\n"
622
  ]
623
  }
624
  ],
 
635
  },
636
  {
637
  "cell_type": "code",
638
+ "execution_count": 20,
639
  "metadata": {},
640
  "outputs": [
641
  {
642
  "name": "stdout",
643
  "output_type": "stream",
644
  "text": [
645
+ "Accuracy: 0.8463\n",
646
+ "Average precision score for positive class: 0.8959\n",
647
+ "Average precision score for negative class: 0.9647\n",
648
+ "Average precision score for neutral class: 0.7449\n"
649
  ]
650
  }
651
  ],
 
682
  },
683
  {
684
  "cell_type": "code",
685
+ "execution_count": 21,
686
  "metadata": {},
687
  "outputs": [
688
  {
 
690
  "output_type": "stream",
691
  "text": [
692
  "5 most positive tweets (class 2):\n",
693
+ "@united I think this is the best first class I have ever gotten!! Denver to LAX and it's wonderful!!!\n",
694
+ "@AmericanAir Flight 236 was great. Fantastic cabin crew. A+ landing. #thankyou #JFK http://t.co/dRW08djHAI\n",
695
+ "@SouthwestAir Jason (108639) at Gate #3 in SAN made my afternoon!!! #southwestairlines #stellarservice #thanks!\n",
696
  "@SouthwestAir love them! Always get the best deals!\n",
697
+ "@AmericanAir simply amazing. Smiles for miles.Thank u for my upgrade tomorrow for ORD.We are spending a lot of time together next few weeks!\n",
 
 
 
698
  "----------------------------------------------------------------------------------------------------\n",
699
  "5 most negative tweets (class 0):\n",
700
+ "@united first you lost all my bags, now you Cancelled Flight my flight home. 30 min wait to talk to somebody #poorservice #notgoodenough\n",
701
  "@USAirways Not only did u lose the flight plan! Now ur flight crew is FAA timed out! Thx for havin us sit on the tarmac for an hr! #Pathetic\n",
702
+ "@AmericanAir Phone just disconnects if you stay on the line. Need to checkout of hotel in 2 hrs &amp; have no place to go. Can't keep calling.\n",
703
+ "@VirginAmerica I have lots of flights to book and your site it not working!!!! I've been on the phone waiting for over 10 minutes..........\n",
704
+ "@united 3 hour delay plus a jetway that won't move. This biz traveler is never flying u again!\n"
705
  ]
706
  }
707
  ],
 
723
  },
724
  {
725
  "cell_type": "code",
726
+ "execution_count": 22,
727
  "metadata": {},
728
  "outputs": [
729
  {
 
731
  "output_type": "stream",
732
  "text": [
733
  "5 most positive (predicted) tweets that are actually negative (ground truth class 0):\n",
 
734
  "@united thanks for the link, now finally arrived in Brussels, 9 h after schedule...\n",
735
+ "@USAirways as far as being delayed goes… Looks like tailwinds are going to make up for it. Good news!\n",
736
+ "@united thanks for having changed me. Managed to arrive with only 8 hours of delay and exhausted\n",
737
  "@USAirways your saving grace was our flight attendant Dallas who was amazing. wish he would transfer to Delta where I would see him again\n",
738
  "@AmericanAir that luggage you forgot...#mia.....he just won an oscar😄💝💝💝\n",
 
739
  "----------------------------------------------------------------------------------------------------\n",
740
  "5 most negative (predicted) tweets that are actually positive (ground truth class 2):\n",
741
  "@united thanks for updating me about the 1+ hour delay the exact second I got to ATL. 🙅🙅🙅\n",
 
742
  "@SouthwestAir save mile to visit family in 2015 and this will impact how many times I can see my mother. I planned and you change the rules\n",
743
+ "@JetBlue you don't remember our date Monday night back to NYC? #heartbroken\n",
744
  "@SouthwestAir hot stewardess flipped me off\n",
745
  "@SouthwestAir - We left iPad in a seat pocket. Filed lost item report. Received it exactly 1 week Late Flightr. Is that a record? #unbelievable\n"
746
  ]
 
784
  },
785
  {
786
  "cell_type": "code",
787
+ "execution_count": 23,
788
  "metadata": {},
789
  "outputs": [
790
  {
791
  "name": "stdout",
792
  "output_type": "stream",
793
  "text": [
794
+ "Compilation time: 5.9232 seconds\n"
795
  ]
796
  },
797
  {
798
  "name": "stderr",
799
  "output_type": "stream",
800
  "text": [
801
+ "100%|██████████| 1/1 [00:00<00:00, 17.83it/s]"
802
  ]
803
  },
804
  {
805
  "name": "stdout",
806
  "output_type": "stream",
807
  "text": [
808
+ "FHE inference time: 0.8374 seconds\n"
809
+ ]
810
+ },
811
+ {
812
+ "name": "stderr",
813
+ "output_type": "stream",
814
+ "text": [
815
+ "\n"
816
  ]
817
  }
818
  ],
 
832
  "\n",
833
  "# Now let's predict with FHE over a single tweet and print the time it takes\n",
834
  "start = time.perf_counter()\n",
835
+ "decrypted_proba = best_model.predict_proba(X_tested_tweet, fhe=\"execute\")\n",
836
  "end = time.perf_counter()\n",
837
  "fhe_exec_time = end - start\n",
838
  "print(f\"FHE inference time: {fhe_exec_time:.4f} seconds\")"
 
840
  },
841
  {
842
  "cell_type": "code",
843
+ "execution_count": 24,
844
  "metadata": {},
845
  "outputs": [
846
  {
847
  "name": "stdout",
848
  "output_type": "stream",
849
  "text": [
850
+ "Probabilities from the FHE inference: [[0.05162184 0.04558276 0.90279541]]\n",
851
+ "Probabilities from the clear model: [[0.05162184 0.04558276 0.90279541]]\n"
852
  ]
853
  }
854
  ],
 
859
  },
860
  {
861
  "cell_type": "code",
862
+ "execution_count": 40,
863
  "metadata": {},
864
  "outputs": [],
865
  "source": [
866
+ "DEPLOYMENT_DIR = Path(\"deployment\")\n",
867
+ "DEPLOYMENT_DIR.mkdir(exist_ok=True)\n",
868
+ "\n",
869
  "# Let's export the final model such that we can reuse it in a client/server environment\n",
870
  "\n",
871
+ "# Serialize the model (for development only)\n",
872
+ "with (DEPLOYMENT_DIR / \"serialized_model\").open(\"w\") as file:\n",
873
+ " best_model.dump(file)\n",
874
  "\n",
875
+ "# Export some data to be used for compilation \n",
876
  "X_train_numpy = X_train_transformer[:100]\n",
877
  "\n",
878
  "# Merge the two arrays in a pandas dataframe\n",
879
  "X_test_numpy_df = pd.DataFrame(X_train_numpy)\n",
880
  "\n",
881
  "# to csv\n",
882
+ "X_test_numpy_df.to_csv(DEPLOYMENT_DIR / \"samples_for_compilation.csv\")\n",
883
  "\n",
884
  "# Let's save the model to be pushed to a server later\n",
885
  "from concrete.ml.deployment import FHEModelDev\n",
886
  "\n",
887
+ "fhe_api = FHEModelDev(DEPLOYMENT_DIR / \"sentiment_fhe_model\", best_model)\n",
888
+ "fhe_api.save(via_mlir=True)"
889
  ]
890
  },
891
  {
892
  "cell_type": "code",
893
+ "execution_count": null,
894
  "metadata": {},
895
  "outputs": [
896
  {
 
930
  " <tbody>\n",
931
  " <tr>\n",
932
  " <th>TF-IDF + XGBoost</th>\n",
933
+ " <td>0.711749</td>\n",
934
+ " <td>0.640422</td>\n",
935
+ " <td>0.871891</td>\n",
936
+ " <td>0.43486</td>\n",
937
  " </tr>\n",
938
  " <tr>\n",
939
  " <th>Transformer Only</th>\n",
940
  " <td>0.805328</td>\n",
941
  " <td>0.854827</td>\n",
942
  " <td>0.954804</td>\n",
943
+ " <td>0.68011</td>\n",
944
  " </tr>\n",
945
  " <tr>\n",
946
  " <th>Transformer + XGBoost</th>\n",
947
+ " <td>0.846311</td>\n",
948
+ " <td>0.895930</td>\n",
949
+ " <td>0.964674</td>\n",
950
+ " <td>0.74489</td>\n",
951
  " </tr>\n",
952
  " </tbody>\n",
953
  "</table>\n",
 
956
  "text/plain": [
957
  " Accuracy Average Precision (positive) \\\n",
958
  "Model \n",
959
+ "TF-IDF + XGBoost 0.711749 0.640422 \n",
960
  "Transformer Only 0.805328 0.854827 \n",
961
+ "Transformer + XGBoost 0.846311 0.895930 \n",
962
  "\n",
963
  " Average Precision (negative) \\\n",
964
  "Model \n",
965
+ "TF-IDF + XGBoost 0.871891 \n",
966
  "Transformer Only 0.954804 \n",
967
+ "Transformer + XGBoost 0.964674 \n",
968
  "\n",
969
  " Average Precision (neutral) \n",
970
  "Model \n",
971
+ "TF-IDF + XGBoost 0.43486 \n",
972
+ "Transformer Only 0.68011 \n",
973
+ "Transformer + XGBoost 0.74489 "
974
  ]
975
  },
976
+ "execution_count": 33,
977
  "metadata": {},
978
  "output_type": "execute_result"
979
  }
 
1036
  "name": "python3"
1037
  },
1038
  "language_info": {
1039
+ "codemirror_mode": {
1040
+ "name": "ipython",
1041
+ "version": 3
1042
+ },
1043
+ "file_extension": ".py",
1044
+ "mimetype": "text/x-python",
1045
  "name": "python",
1046
+ "nbconvert_exporter": "python",
1047
+ "pygments_lexer": "ipython3",
1048
  "version": "3.10.11"
1049
  }
1050
  },
app.py CHANGED
@@ -26,6 +26,7 @@ time.sleep(5)
26
  # (encrypted data is too large to display in the browser)
27
  ENCRYPTED_DATA_BROWSER_LIMIT = 500
28
  N_USER_KEY_STORED = 20
 
29
 
30
  print("Loading the transformer model...")
31
 
@@ -60,7 +61,7 @@ def keygen():
60
 
61
  # Let's create a user_id
62
  user_id = numpy.random.randint(0, 2**32)
63
- fhe_api = FHEModelClient("sentiment_fhe_model/deployment", f".fhe_keys/{user_id}")
64
  fhe_api.load()
65
 
66
 
@@ -79,7 +80,7 @@ def encode_quantize_encrypt(text, user_id):
79
  if not user_id:
80
  raise gr.Error("You need to generate FHE keys first.")
81
 
82
- fhe_api = FHEModelClient("sentiment_fhe_model/deployment", f".fhe_keys/{user_id}")
83
  fhe_api.load()
84
  encodings = transformer_vectorizer.transform([text])
85
  quantized_encodings = fhe_api.model.quantize_input(encodings).astype(numpy.uint8)
@@ -143,7 +144,7 @@ def decrypt_prediction(user_id):
143
  # Read encrypted_prediction from the file
144
  encrypted_prediction = numpy.load(encoded_data_path).tobytes()
145
 
146
- fhe_api = FHEModelClient("sentiment_fhe_model/deployment", f".fhe_keys/{user_id}")
147
  fhe_api.load()
148
 
149
  # We need to retrieve the private key that matches the client specs (see issue #18)
 
26
  # (encrypted data is too large to display in the browser)
27
  ENCRYPTED_DATA_BROWSER_LIMIT = 500
28
  N_USER_KEY_STORED = 20
29
+ FHE_MODEL_PATH = "deployment/sentiment_fhe_model"
30
 
31
  print("Loading the transformer model...")
32
 
 
61
 
62
  # Let's create a user_id
63
  user_id = numpy.random.randint(0, 2**32)
64
+ fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
65
  fhe_api.load()
66
 
67
 
 
80
  if not user_id:
81
  raise gr.Error("You need to generate FHE keys first.")
82
 
83
+ fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
84
  fhe_api.load()
85
  encodings = transformer_vectorizer.transform([text])
86
  quantized_encodings = fhe_api.model.quantize_input(encodings).astype(numpy.uint8)
 
144
  # Read encrypted_prediction from the file
145
  encrypted_prediction = numpy.load(encoded_data_path).tobytes()
146
 
147
+ fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
148
  fhe_api.load()
149
 
150
  # We need to retrieve the private key that matches the client specs (see issue #18)
compile.py CHANGED
@@ -1,7 +1,8 @@
1
  import onnx
2
  import pandas as pd
3
  from concrete.ml.deployment import FHEModelDev, FHEModelClient
4
- from concrete.ml.onnx.convert import get_equivalent_numpy_forward
 
5
  import json
6
  import os
7
  import shutil
@@ -10,48 +11,29 @@ from pathlib import Path
10
 
11
  script_dir = Path(__file__).parent
12
 
 
 
13
  print("Compiling the model...")
14
 
15
- # Load the onnx model
16
- model_onnx = onnx.load(Path.joinpath(script_dir, "sentiment_fhe_model/server_model.onnx"))
17
 
18
  # Load the data from the csv file to be used for compilation
19
- data = pd.read_csv(
20
- Path.joinpath(script_dir, "sentiment_fhe_model/samples_for_compilation.csv"), index_col=0
21
- ).values
22
-
23
- # Convert the onnx model to a numpy model
24
- _tensor_tree_predict = get_equivalent_numpy_forward(model_onnx)
25
-
26
- model = FHEModelClient(
27
- Path.joinpath(script_dir, "sentiment_fhe_model/deployment"), ".fhe_keys"
28
- ).model
29
-
30
- # Assign the numpy model and compile the model
31
- model._tensor_tree_predict = _tensor_tree_predict
32
 
33
  # Compile the model
34
  model.compile(data)
35
 
36
- # Load the serialized_processing.json file
37
- with open(
38
- Path.joinpath(script_dir, "sentiment_fhe_model/deployment/serialized_processing.json"), "r"
39
- ) as f:
40
- serialized_processing = json.load(f)
41
 
42
  # Delete the deployment folder if it exist
43
- if Path.joinpath(script_dir, "sentiment_fhe_model/deployment").exists():
44
- shutil.rmtree(Path.joinpath(script_dir, "sentiment_fhe_model/deployment"))
45
 
46
  fhe_api = FHEModelDev(
47
- model=model, path_dir=Path.joinpath(script_dir, "sentiment_fhe_model/deployment")
48
  )
49
  fhe_api.save(via_mlir=True)
50
 
51
- # Write the serialized_processing.json file to the deployment folder
52
- with open(
53
- Path.joinpath(script_dir, "sentiment_fhe_model/deployment/serialized_processing.json"), "w"
54
- ) as f:
55
- json.dump(serialized_processing, f)
56
 
57
  print("Done!")
 
1
  import onnx
2
  import pandas as pd
3
  from concrete.ml.deployment import FHEModelDev, FHEModelClient
4
+ from concrete.ml.common.serialization.loaders import load
5
+ from concrete.ml.onnx.convert import get_equivalent_numpy_forward_from_onnx_tree
6
  import json
7
  import os
8
  import shutil
 
11
 
12
  script_dir = Path(__file__).parent
13
 
14
+ DEPLOYMENT_DIR = script_dir / "deployment"
15
+
16
  print("Compiling the model...")
17
 
18
+ with (DEPLOYMENT_DIR / "serialized_model").open("r") as file:
19
+ model = load(file)
20
 
21
  # Load the data from the csv file to be used for compilation
22
+ data = pd.read_csv(DEPLOYMENT_DIR / "samples_for_compilation.csv", index_col=0).values
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Compile the model
25
  model.compile(data)
26
 
27
+ dev_model_path = DEPLOYMENT_DIR / "sentiment_fhe_model"
 
 
 
 
28
 
29
  # Delete the deployment folder if it exist
30
+ if dev_model_path.is_dir():
31
+ shutil.rmtree(dev_model_path)
32
 
33
  fhe_api = FHEModelDev(
34
+ model=model, path_dir=dev_model_path
35
  )
36
  fhe_api.save(via_mlir=True)
37
 
 
 
 
 
 
38
 
39
  print("Done!")
deployment/samples_for_compilation.csv ADDED
The diff for this file is too large to render. See raw diff
 
deployment/sentiment_fhe_model/client.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:972f0c7d83f12e3a43e8f923fc422cdb443b9f64bb6f74c1abf912836ba27e60
3
+ size 3887326
deployment/sentiment_fhe_model/server.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:216d2a78d7ec47ec2a478d5f32ed34cee8a9c45700325e5d8de4e087b7ed8dfc
3
+ size 3004
deployment/sentiment_fhe_model/versions.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"concrete-python": "2.5", "concrete-ml": "1.4.0", "python": "3.10.11"}
deployment/serialized_model ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- concrete-ml==1.1.0
2
  gradio==3.40.1
3
  pandas==1.4.3
4
- transformers==4.32.0
5
  jupyter==1.0.0
 
1
+ concrete-ml==1.4.0
2
  gradio==3.40.1
3
  pandas==1.4.3
4
+ transformers==4.36.0
5
  jupyter==1.0.0
sentiment_fhe_model/samples_for_compilation.csv DELETED
The diff for this file is too large to render. See raw diff
 
server.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
  current_dir = Path(__file__).parent
10
 
11
  # Load the model
12
- fhe_model = FHEModelServer(Path.joinpath(current_dir, "sentiment_fhe_model/deployment"))
13
 
14
  class PredictRequest(BaseModel):
15
  evaluation_key: str
 
9
  current_dir = Path(__file__).parent
10
 
11
  # Load the model
12
+ fhe_model = FHEModelServer("deployment/sentiment_fhe_model")
13
 
14
  class PredictRequest(BaseModel):
15
  evaluation_key: str