TomSmail commited on
Commit
efd4869
·
1 Parent(s): e03e522

feat: add mobile net pretrained model.

Browse files
Files changed (1) hide show
  1. cnn.ipynb +173 -0
cnn.ipynb ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import concrete.ml\n",
10
+ "import torch\n"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {},
16
+ "source": [
17
+ "Training: \n",
18
+ " 1. Gather dataset of pictures\n",
19
+ " 2. Preprocess the data\n",
20
+ " 3. Find pretrained model \n",
21
+ " 4. Segment Pretrained model into client-model and encrypted-server-model \n",
22
+ " 5. Retrain the server-side model on 8 bits\n",
23
+ " 6. Take output of the client model and truncate the floats to 8 bits\n",
24
+ "\n",
25
+ "Production\n",
26
+ " 1. Take a picture :)\n",
27
+ " 2. Evaluate client model on photo (clear)\n",
28
+ " 3. Truncate to 8 bits\n",
29
+ " 4. Encrypt \n",
30
+ " 5. Send encrypted data to server\n",
31
+ " 6. Send back encrypted result\n",
32
+ " 7. decrypt result\n"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "metadata": {},
38
+ "source": [
39
+ "Step 1: Load Pretrained MobileNet"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "import torch\n",
49
+ "import torch.nn as nn\n",
50
+ "from torchvision import models\n",
51
+ "\n",
52
+ "# Load the pretrained MobileNet model\n",
53
+ "mobilenet = models.mobilenet_v2(pretrained=True)\n",
54
+ "\n",
55
+ "# Set model to evaluation mode\n",
56
+ "mobilenet.eval()\n"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {},
62
+ "source": [
63
+ "Step 2: Segment the Pretrained Model into Client and Server Parts"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "# Client model - extracting up to the 10th layer (or any other cutoff)\n",
73
+ "client_model = nn.Sequential(*list(mobilenet.features.children())[:10])\n",
74
+ "\n",
75
+ "# Server model - the remaining layers\n",
76
+ "server_model = nn.Sequential(*list(mobilenet.features.children())[10:], mobilenet.classifier)\n",
77
+ "\n",
78
+ "# Freeze client model parameters (no need to retrain)\n",
79
+ "for param in client_model.parameters():\n",
80
+ " param.requires_grad = False"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "metadata": {},
86
+ "source": [
87
+ "Step 3: Quantize the Server-Side Model to 8 Bits\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "from torch.quantization import quantize_dynamic\n",
97
+ "\n",
98
+ "# Quantize the server model\n",
99
+ "server_model_quantized = quantize_dynamic(\n",
100
+ " server_model, # Model to be quantized\n",
101
+ " {nn.Linear}, # Layers to quantize (we quantize fully connected layers here)\n",
102
+ " dtype=torch.qint8 # Quantize to 8-bit\n",
103
+ ")\n",
104
+ "\n",
105
+ "server_model_quantized.eval()"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "Step 4: Truncate the Client Model Output to 8 Bits"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "import numpy as np\n",
122
+ "\n",
123
+ "def truncate_to_8_bits(tensor):\n",
124
+ " # Scale the tensor to the range [0, 255]\n",
125
+ " tensor = torch.clamp(tensor, min=0, max=1)\n",
126
+ " tensor = tensor * 255.0\n",
127
+ " tensor = tensor.to(torch.uint8) # Convert to 8-bit integers\n",
128
+ " return tensor\n",
129
+ "\n",
130
+ "# Example input\n",
131
+ "input_image = torch.randn(1, 3, 224, 224) # A random image input\n",
132
+ "\n",
133
+ "# Client-side computation\n",
134
+ "client_output = client_model(input_image)\n",
135
+ "\n",
136
+ "# Truncate the output to 8 bits\n",
137
+ "client_output_8bit = truncate_to_8_bits(client_output)\n",
138
+ "\n",
139
+ "# The truncated output is now ready to be passed to the server\n"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "metadata": {},
145
+ "source": [
146
+ "Step 5: Server Model Inference on Quantized Data\n"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "# Ensure client output is in float format before feeding into server\n",
156
+ "client_output_8bit = client_output_8bit.float() / 255.0 # Rescale to [0, 1]\n",
157
+ "\n",
158
+ "# Run inference on the server-side model\n",
159
+ "server_output = server_model_quantized(client_output_8bit)\n",
160
+ "\n",
161
+ "# Output from the server model (class probabilities, etc.)\n",
162
+ "print(server_output)\n"
163
+ ]
164
+ }
165
+ ],
166
+ "metadata": {
167
+ "language_info": {
168
+ "name": "python"
169
+ }
170
+ },
171
+ "nbformat": 4,
172
+ "nbformat_minor": 2
173
+ }