Spaces:
Running
Running
// DOM elements the app interacts with, each looked up once by its id attribute.
const byId = (id) => document.getElementById(id);
const statusLabel = byId('status');
const fileUpload = byId('upload');
const imageContainer = byId('container');
const example = byId('example');
const maskCanvas = byId('mask-output');
const uploadButton = byId('upload-button');
const resetButton = byId('reset-image');
const clearButton = byId('clear-points');
const cutButton = byId('cut-mask');

// Mutable UI state shared by the event handlers in this file.
let lastPoints = null; // prompt points to send with the next 'decode' request
let isEncoded = false; // set once the worker reports the image embedding is done
let isDecoding = false; // set while a 'decode' request awaits its result
let isMultiMaskMode = false; // set after the first click places a prompt point
let modelReady = false; // set when the worker posts 'ready'
let imageDataURI = null; // source of the displayed image (data URI or URL)
// Remote assets: the example image and the prompt-marker icons.
const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
const EXAMPLE_URL = `${BASE_URL}corgi.jpg`;

// Inference runs in a module worker so the UI thread is never blocked.
const worker = new Worker('worker.js', { type: 'module' });

/**
 * Create a marker <img> and start downloading it immediately, so the first
 * click does not have to wait for the icon to load.
 * @param {string} fileName - icon file name, relative to BASE_URL
 * @returns {HTMLImageElement} image element carrying the 'icon' class
 */
function preloadIcon(fileName) {
  const icon = new Image();
  icon.src = `${BASE_URL}${fileName}`;
  icon.className = 'icon';
  return icon;
}

// Templates cloned for positive (star) and negative (cross) prompt points.
const star = preloadIcon('star-icon.png');
const cross = preloadIcon('cross-icon.png');
// Dispatch messages posted back by the worker:
//   'ready'          - the model is available;
//   'segment_result' - image-encoding progress ('start', then a completion message);
//   'decode_result'  - predicted masks plus their scores for the current prompt points.
worker.addEventListener('message', (e) => {
  const { type, data } = e.data;
  if (type === 'ready') {
    modelReady = true;
    statusLabel.textContent = 'Ready';
  } else if (type === 'decode_result') {
    handleDecodeResult(data);
  } else if (type === 'segment_result') {
    handleSegmentResult(data);
  }
});

/**
 * Track image-encoding progress reported by the worker.
 * @param {*} data - 'start' when encoding begins; anything else means it finished
 */
function handleSegmentResult(data) {
  if (imageDataURI === null) {
    // The image was reset while it was being encoded: this result belongs to an
    // image that is no longer shown, so it must not enable decoding.
    return;
  }
  if (data === 'start') {
    statusLabel.textContent = 'Extracting image embedding...';
  } else {
    statusLabel.textContent = 'Embedding extracted!';
    isEncoded = true;
  }
}

/**
 * Handle a decode result: re-send any pending hover point, then draw the mask.
 * @param {{mask: object, scores: ArrayLike<number>}} data - as posted by the worker
 */
function handleDecodeResult(data) {
  isDecoding = false;
  if (!isEncoded) {
    return; // No usable embedding (not encoded yet, or reset); nothing to draw.
  }
  if (!isMultiMaskMode && lastPoints) {
    // Hover mode: mousemove does not decode while a request is in flight, it
    // only records the latest point. Send that point now so the mask keeps up.
    decode();
    lastPoints = null;
  }
  const { mask, scores } = data;
  renderMask(mask, scores);
}

/**
 * Return the index of the highest score (the first one wins ties).
 * @param {ArrayLike<number>} scores - one score per candidate mask
 * @returns {number}
 */
function selectBestMaskIndex(scores) {
  let bestIndex = 0;
  for (let i = 1; i < scores.length; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  return bestIndex;
}

/**
 * Paint the best-scoring candidate mask onto `maskCanvas` and show its score.
 * NOTE(review): the indexing assumes mask.data interleaves the candidates per
 * pixel (index = numMasks * pixel + maskIndex), with 1 marking foreground.
 * @param {{width: number, height: number, data: ArrayLike<number>}} mask
 * @param {ArrayLike<number>} scores
 */
function renderMask(mask, scores) {
  // Resize only when needed: assigning width/height clears the canvas.
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }
  const context = maskCanvas.getContext('2d');
  const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);

  const numMasks = scores.length;
  const bestIndex = selectBestMaskIndex(scores);
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  const maskColour = [0, 114, 189, 255]; // RGBA
  const pixelData = imageData.data;
  // pixelData holds 4 bytes (RGBA) per pixel, so iterate over pixels, not bytes.
  const numPixels = pixelData.length / 4;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      pixelData.set(maskColour, 4 * i);
    }
  }
  context.putImageData(imageData, 0, 0);
}
/**
 * Ask the worker to decode a mask for the current prompt points (`lastPoints`)
 * and flag that a decode request is now in flight.
 */
function decode() {
  const request = { type: 'decode', data: lastPoints };
  isDecoding = true;
  worker.postMessage(request);
}
/**
 * Leave multi-point mode: forget the prompt points, remove their marker icons,
 * disable the cut button and wipe the mask canvas.
 */
function clearPointsAndMask() {
  lastPoints = null;
  isMultiMaskMode = false;

  // querySelectorAll returns a static NodeList, so removing while iterating is safe.
  for (const marker of document.querySelectorAll('.icon')) {
    marker.remove();
  }

  cutButton.disabled = true;

  const maskContext = maskCanvas.getContext('2d');
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener('click', clearPointsAndMask);
// Reset button: discard the current image and return to the upload screen.
resetButton.addEventListener('click', () => {
  imageDataURI = null;
  isEncoded = false;

  // Notify the worker of the reset (its handling lives in worker.js).
  worker.postMessage({ type: 'reset' });

  clearPointsAndMask();

  cutButton.disabled = true;
  const containerStyle = imageContainer.style;
  containerStyle.backgroundImage = 'none';
  uploadButton.style.display = 'flex';
  statusLabel.textContent = 'Ready';
});
/**
 * Display an image and ask the worker to compute its embedding.
 * @param {string} data - a data URI (from a file upload) or a URL (the example image)
 */
function segment(data) {
  isEncoded = false;
  imageDataURI = data;
  if (!modelReady) {
    statusLabel.textContent = 'Loading model...';
  }

  // Show the image as the container background and hide the upload prompt.
  imageContainer.style.backgroundImage = `url(${data})`;
  uploadButton.style.display = 'none';
  cutButton.disabled = true;

  worker.postMessage({ type: 'segment', data });
}
// File picker: read the chosen image as a data URI and segment it.
fileUpload.addEventListener('change', (e) => {
  const [file] = e.target.files;
  if (!file) {
    return;
  }
  // Clear the input so that choosing the same file again (e.g. after a reset)
  // still fires 'change'. The `file` reference stays valid after this.
  e.target.value = '';

  const reader = new FileReader();
  reader.onload = () => segment(reader.result);
  reader.onerror = () => {
    // Previously a read failure was silently ignored.
    statusLabel.textContent = 'Could not read the selected file';
    console.error(reader.error);
  };
  reader.readAsDataURL(file);
});

// "Example" link: segment the bundled example image instead of following the link.
example.addEventListener('click', (e) => {
  e.preventDefault();
  segment(EXAMPLE_URL);
});
/**
 * Overlay a marker icon on the image at a prompt point.
 * @param {{point: number[], label: number}} prompt - point is [x, y] as fractions
 *   of the container size; label 1 shows a star, any other label a cross
 */
function addIcon({ point, label }) {
  const [x, y] = point;
  const template = label === 1 ? star : cross;
  const marker = template.cloneNode();
  marker.style.left = `${x * 100}%`;
  marker.style.top = `${y * 100}%`;
  imageContainer.appendChild(marker);
}
// Click on the image: add a prompt point (left button = positive, right
// button = negative) and decode a mask from every point placed so far.
imageContainer.addEventListener('mousedown', (e) => {
  const isPromptButton = e.button === 0 || e.button === 2;
  if (!isPromptButton || !isEncoded) {
    return; // other mouse buttons, or no embedding yet
  }

  if (!isMultiMaskMode) {
    // First click: discard the hover point and switch to multi-point mode.
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const prompt = getPoint(e);
  lastPoints.push(prompt);
  addIcon(prompt);
  decode();
});
/**
 * Limit `x` to the range [min, max] (by default [0, 1]).
 * NaN is passed through unchanged.
 * @param {number} x
 * @param {number} [min=0]
 * @param {number} [max=1]
 * @returns {number}
 */
function clamp(x, min = 0, max = 1) {
  const cappedAbove = Math.min(x, max);
  return Math.max(cappedAbove, min);
}
/**
 * Turn a mouse event on the image container into a prompt point.
 * @param {MouseEvent} e
 * @returns {{point: number[], label: number}} point is [x, y] as fractions of
 *   the container's width and height, clamped to [0, 1]; label is 0 (negative
 *   prompt) for the right button and 1 (positive prompt) otherwise
 */
function getPoint(e) {
  const { left, top, width, height } = imageContainer.getBoundingClientRect();
  const x = clamp((e.clientX - left) / width);
  const y = clamp((e.clientY - top) / height);
  const isRightButton = e.button === 2;
  return { point: [x, y], label: isRightButton ? 0 : 1 };
}
// Suppress the browser's context menu so a right click can place a negative point.
imageContainer.addEventListener('contextmenu', (e) => e.preventDefault());
// Hover over the image (before any click): decode a mask for the cursor.
// While a decode is in flight only the latest cursor point is remembered in
// `lastPoints`; the 'decode_result' handler sends it once the request finishes.
imageContainer.addEventListener('mousemove', (e) => {
  if (!isEncoded || isMultiMaskMode) {
    // Not encoded yet, or the user has switched to click-to-add-points mode.
    return;
  }
  lastPoints = [getPoint(e)];
  if (!isDecoding) {
    decode();
    // The point has been sent (postMessage copies it), so nothing is pending.
    // Leaving it set made the result handler decode the same point twice.
    lastPoints = null;
  }
});
// Cut button: download the part of the current image covered by the mask as a
// PNG with a transparent background.
cutButton.addEventListener('click', () => {
  const w = maskCanvas.width;
  const h = maskCanvas.height;
  // Snapshot the mask now, before a later decode can repaint the canvas.
  const maskPixels = maskCanvas.getContext('2d').getImageData(0, 0, w, h).data;

  const image = new Image();
  // Load with CORS so the canvas the image is drawn onto can be read back.
  image.crossOrigin = 'anonymous';
  image.onload = () => {
    downloadCutout(image, maskPixels, w, h).catch((err) => {
      console.error('Failed to create cut-out:', err);
    });
  };
  image.onerror = () => {
    console.error('Failed to load image for cut-out:', image.src);
  };
  image.src = imageDataURI;
});

/**
 * Draw `image` scaled to w×h, keep only the pixels whose mask alpha is
 * non-zero, and download the result as 'image.png'.
 * @param {HTMLImageElement} image - the loaded source image
 * @param {Uint8ClampedArray} maskPixels - RGBA bytes of the w×h mask canvas
 * @param {number} w - output width in pixels
 * @param {number} h - output height in pixels
 * @returns {Promise<void>}
 */
async function downloadCutout(image, maskPixels, w, h) {
  const imageContext = new OffscreenCanvas(w, h).getContext('2d');
  imageContext.drawImage(image, 0, 0, w, h);
  const imagePixels = imageContext.getImageData(0, 0, w, h).data;

  const cutCanvas = new OffscreenCanvas(w, h);
  const cutContext = cutCanvas.getContext('2d');
  const cutImageData = cutContext.createImageData(w, h); // starts fully transparent
  const cutPixels = cutImageData.data;
  for (let offset = 0; offset < maskPixels.length; offset += 4) {
    // A non-zero mask alpha (byte offset + 3) marks a pixel inside the mask.
    if (maskPixels[offset + 3] > 0) {
      cutPixels.set(imagePixels.subarray(offset, offset + 4), offset);
    }
  }
  cutContext.putImageData(cutImageData, 0, 0);

  const url = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement('a');
  link.download = 'image.png';
  link.href = url;
  link.click();
  // Release the blob (it was previously leaked). The delay gives the browser
  // time to start the download; revoking immediately is reported to break
  // downloads in some browsers.
  setTimeout(() => URL.revokeObjectURL(url), 60 * 1000);
}