In [None]:
import concrete.ml
import torch


Training:
 1. Gather a dataset of pictures
 2. Preprocess the data
 3. Find a pretrained model
 4. Split the pretrained model into a client model (runs in the clear) and a server model (runs on encrypted data)
 5. Retrain the server-side model with 8-bit quantization
 6. Take the output of the client model and truncate the floats to 8 bits

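Production:
 1. Take a picture :)
 2. Evaluate the client model on the photo (in the clear)
 3. Truncate the client output to 8 bits
 4. Encrypt the truncated output
 5. Send the encrypted data to the server
 6. The server evaluates its model on the encrypted data and sends back the encrypted result
 7. Decrypt the result on the client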


Step 1: Load Pretrained MobileNet

In [None]:
import torch
import torch.nn as nn
from torchvision import models

# Load MobileNetV2 with its ImageNet-pretrained weights.
# The `weights=` enum replaces the `pretrained=True` flag, which torchvision
# deprecated in 0.13; IMAGENET1K_V1 is the checkpoint that flag used to select.
mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

# Switch to evaluation mode (affects layers such as dropout and batch norm
# that behave differently between training and inference).
mobilenet.eval()


Step 2: Segment the Pretrained Model into Client and Server Parts

In [None]:
# Index into `mobilenet.features` at which the network is cut: blocks before
# it form the client model, blocks from it onward form the server model.
SPLIT_INDEX = 10
feature_blocks = list(mobilenet.features.children())

# Client model - the first SPLIT_INDEX feature blocks
client_model = nn.Sequential(*feature_blocks[:SPLIT_INDEX])

# Server model - the remaining feature blocks followed by the classifier head.
# MobileNetV2.forward applies a global average pool and a flatten between
# `features` and `classifier` as functional calls, so they are not part of
# either child module and must be added explicitly here; without them the
# classifier's Linear layer receives a 4-D feature map and fails on shape.
server_model = nn.Sequential(
    *feature_blocks[SPLIT_INDEX:],
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(start_dim=1),
    mobilenet.classifier,
)

# Freeze client model parameters (no need to retrain).
# The client model reuses the module objects of `mobilenet`, so this also
# freezes the corresponding parameters of the original network.
for client_param in client_model.parameters():
    client_param.requires_grad = False

Step 3: Quantize the Server-Side Model to 8 Bits


In [None]:
from torch.quantization import quantize_dynamic

# Dynamic quantization of the server model.
# Only the module types listed here are converted to 8-bit (qint8); every
# other layer type in the server model is left as it is.
quantized_layer_types = {nn.Linear}

server_model_quantized = quantize_dynamic(
    server_model,
    qconfig_spec=quantized_layer_types,
    dtype=torch.qint8,
)

# Put the quantized model in evaluation mode before inference.
server_model_quantized.eval()

Step 4: Truncate the Client Model Output to 8 Bits

In [None]:
import numpy as np

def truncate_to_8_bits(tensor, min_val=0.0, max_val=1.0):
    """Quantize a tensor to unsigned 8-bit integers.

    Values are first clamped to ``[min_val, max_val]`` (anything outside that
    range saturates), then mapped linearly onto ``[0, 255]`` and cast to
    ``torch.uint8``. The cast drops the fractional part, i.e. it truncates
    rather than rounds.

    With the default range ``[0, 1]`` this is exactly
    ``clamp(tensor, 0, 1) * 255`` cast to ``uint8``.

    Args:
        tensor: Input ``torch.Tensor``.
        min_val: Lower bound of the representable range; maps to 0.
        max_val: Upper bound of the representable range; maps to 255.

    Returns:
        A ``torch.uint8`` tensor with the same shape as ``tensor``.

    Raises:
        ValueError: If ``max_val`` is not strictly greater than ``min_val``.
    """
    if max_val <= min_val:
        raise ValueError(
            f"max_val ({max_val}) must be greater than min_val ({min_val})"
        )

    # Saturate out-of-range values instead of letting them wrap around in uint8.
    clamped = torch.clamp(tensor, min=min_val, max=max_val)
    # Shift and rescale the clamped range onto [0, 255].
    scaled = (clamped - min_val) / (max_val - min_val) * 255.0
    return scaled.to(torch.uint8)  # Convert to 8-bit integers

# Example input: a random tensor standing in for one 3x224x224 image.
# A locally seeded generator keeps the example reproducible across
# "Restart & Run All" without touching the global RNG state.
example_rng = torch.Generator().manual_seed(0)
input_image = torch.randn(1, 3, 224, 224, generator=example_rng)

# Client-side computation (inference only, so no autograd graph is needed)
with torch.no_grad():
    client_output = client_model(input_image)

# Truncate the output to 8 bits
client_output_8bit = truncate_to_8_bits(client_output)

# The truncated output is now ready to be passed to the server


Step 5: Server Model Inference on Quantized Data


In [None]:
# Convert the 8-bit client output back to float before feeding it to the
# server model, rescaling by 255 so 0..255 maps back onto [0, 1].
# The result gets its own name: reassigning `client_output_8bit` would make
# this cell divide by 255 again on every re-run and would leave a float tensor
# behind a name that says "8bit".
server_input = client_output_8bit.float() / 255.0

# Run inference on the server-side model (no gradients needed)
with torch.no_grad():
    server_output = server_model_quantized(server_input)

# Output from the server model
print(server_output)
