{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import concrete.ml  # Concrete ML (FHE toolkit), used for the encrypted server side\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training:\n",
    " 1. Gather a dataset of pictures\n",
    " 2. Preprocess the data\n",
    " 3. Find a pretrained model\n",
    " 4. Segment the pretrained model into a client model and an encrypted server model\n",
    " 5. Retrain (or quantize) the server-side model at 8-bit precision\n",
    " 6. Take the output of the client model and truncate the floats to 8 bits\n",
    "\n",
    "Production:\n",
    " 1. Take a picture :)\n",
    " 2. Evaluate the client model on the photo (in the clear)\n",
    " 3. Truncate the result to 8 bits\n",
    " 4. Encrypt it\n",
    " 5. Send the encrypted data to the server\n",
    " 6. The server runs the encrypted model and sends back the encrypted result\n",
    " 7. Decrypt the result\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 1: Load the Pretrained MobileNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "\n",
    "# Load a MobileNetV2 pretrained on ImageNet\n",
    "# (the weights= API replaces the deprecated pretrained=True flag)\n",
    "mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)\n",
    "\n",
    "# Set the model to evaluation mode\n",
    "mobilenet.eval()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 2: Segment the Pretrained Model into Client and Server Parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Client model - the first 10 feature blocks (or any other cutoff)\n",
    "client_model = nn.Sequential(*list(mobilenet.features.children())[:10])\n",
    "\n",
    "# Server model - the remaining feature blocks, plus the global pooling and\n",
    "# flatten steps that MobileNetV2 applies before its classifier\n",
    "server_model = nn.Sequential(\n",
    "    *list(mobilenet.features.children())[10:],\n",
    "    nn.AdaptiveAvgPool2d(1),\n",
    "    nn.Flatten(),\n",
    "    mobilenet.classifier,\n",
    ")\n",
    "\n",
    "# Freeze the client model parameters (no need to retrain them)\n",
    "for param in client_model.parameters():\n",
    "    param.requires_grad = False\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 3: Quantize the Server-Side Model to 8 Bits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.quantization import quantize_dynamic\n",
    "\n",
    "# Dynamically quantize the server model.\n",
    "# Note: dynamic quantization only affects the layer types listed below\n",
    "# (here nn.Linear, i.e. the classifier); the convolutions stay in float.\n",
    "server_model_quantized = quantize_dynamic(\n",
    "    server_model,       # model to be quantized\n",
    "    {nn.Linear},        # layer types to quantize\n",
    "    dtype=torch.qint8   # quantize weights to 8-bit integers\n",
    ")\n",
    "\n",
    "server_model_quantized.eval()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 4: Truncate the Client Model Output to 8 Bits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_to_8_bits(tensor):\n",
    "    # Naive 8-bit truncation: clamp to [0, 1] and rescale to [0, 255].\n",
    "    # This assumes the activations roughly fit in [0, 1]; values outside\n",
    "    # that range are clipped. A real deployment would calibrate a scale\n",
    "    # and zero-point from representative data instead.\n",
    "    tensor = torch.clamp(tensor, min=0, max=1)\n",
    "    tensor = tensor * 255.0\n",
    "    tensor = tensor.to(torch.uint8)  # convert to 8-bit integers\n",
    "    return tensor\n",
    "\n",
    "# Example input: a random image-shaped tensor\n",
    "input_image = torch.randn(1, 3, 224, 224)\n",
    "\n",
    "# Client-side computation (in the clear)\n",
    "client_output = client_model(input_image)\n",
    "\n",
    "# Truncate the output to 8 bits\n",
    "client_output_8bit = truncate_to_8_bits(client_output)\n",
    "\n",
    "# The truncated output is now ready to be sent to the server\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 5: Server Model Inference on Quantized Data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the client output back to float before feeding it into the server model\n",
    "client_output_dequantized = client_output_8bit.float() / 255.0  # rescale to [0, 1]\n",
    "\n",
    "# Run inference on the server-side model\n",
    "with torch.no_grad():\n",
    "    server_output = server_model_quantized(client_output_dequantized)\n",
    "\n",
    "# Output of the server model: class logits (apply softmax for probabilities)\n",
    "print(server_output)\n"
   ]
  }
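  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional: compile the server model for encrypted inference with Concrete ML. The previous cell runs the server model in the clear; the cell below is a minimal, untested sketch of how the plain (float) `server_model` might instead be compiled into an 8-bit FHE circuit, so that the production steps (encrypt, run on the server, decrypt) happen under encryption. `compile_torch_model` and the `fhe=\"simulate\"` flag are assumptions about the installed Concrete ML version (check its documentation), and a network this deep may need a later split point or rounding options to fit FHE accumulator limits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch only: assumes Concrete ML's compile_torch_model API; not executed here.\n",
    "from concrete.ml.torch.compile import compile_torch_model\n",
    "\n",
    "# Calibration inputs: a small batch of representative client-side outputs (floats)\n",
    "with torch.no_grad():\n",
    "    calibration_set = client_model(torch.randn(8, 3, 224, 224))\n",
    "\n",
    "# Compile the float server model into an 8-bit quantized module with an FHE circuit\n",
    "quantized_server = compile_torch_model(\n",
    "    server_model,     # the plain PyTorch server model (not the dynamically quantized one)\n",
    "    calibration_set,  # inputset used to calibrate the quantization ranges\n",
    "    n_bits=8,         # 8-bit quantization, matching the plan above\n",
    ")\n",
    "\n",
    "# Simulated FHE inference on one client output\n",
    "# (fhe=\"execute\" would encrypt, run, and decrypt for real)\n",
    "fhe_output = quantized_server.forward(client_output.detach().numpy(), fhe=\"simulate\")\n",
    "print(fhe_output)\n"
   ]
  }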
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}