|
|
|
import init, { Model } from "./build/m.js"; |
|
|
|
/**
 * Download a resource and return its bytes.
 *
 * When `cacheModel` is true the Cache API bucket "sam-candle-cache" is
 * consulted first and populated after a successful download.
 *
 * @param {string} url - Resource to download.
 * @param {boolean} [cacheModel=true] - Whether to read from / write to the Cache API.
 * @returns {Promise<Uint8Array>} The response body.
 * @throws {Error} If the server answers with a non-2xx status.
 */
async function fetchArrayBuffer(url, cacheModel = true) {
  // Reject HTTP error responses instead of handing their body to the caller
  // (and, worse, persisting an error page in the cache forever).
  const ensureOk = (response) => {
    if (!response.ok) {
      throw new Error(
        `Failed to fetch ${url}: ${response.status} ${response.statusText}`
      );
    }
    return response;
  };

  if (!cacheModel) {
    const response = ensureOk(await fetch(url));
    return new Uint8Array(await response.arrayBuffer());
  }

  const cacheName = "sam-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    return new Uint8Array(await cachedResponse.arrayBuffer());
  }

  const res = ensureOk(await fetch(url, { cache: "force-cache" }));
  // Clone before the body is consumed. The cache write is best-effort: a
  // failure (e.g. quota exceeded) must not fail the download itself, but it
  // must not become an unhandled rejection either.
  const cacheWrite = cache.put(url, res.clone()).catch((err) => {
    console.warn(`Could not cache ${url}:`, err);
  });
  const buffer = await res.arrayBuffer();
  await cacheWrite;
  return new Uint8Array(buffer);
}
|
/**
 * Static registry of loaded models, keyed by model ID, together with a
 * fingerprint of the image whose embeddings each model currently holds.
 */
class SAMModel {
  /** Loaded `Model` instances keyed by model ID. */
  static instance = {};

  /** Fingerprint of the last image successfully embedded, keyed by model ID. */
  static imageArrayHash = {};

  /** ID of the model most recently returned by `getInstance`. */
  static currentModelID = null;

  /**
   * Return the model for `modelID`, downloading its weights and constructing
   * it on first use. Also marks that model as the current one.
   *
   * @param {string} modelURL - Location of the model weights.
   * @param {string} modelID - Key under which the model is stored.
   * @returns {Promise<Model>} The loaded model.
   */
  static async getInstance(modelURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({
        status: "loading",
        message: `Loading Model ${modelID}`,
      });
      const weightsArrayU8 = await fetchArrayBuffer(modelURL);
      // NOTE(review): the boolean presumably selects the lightweight model
      // variant — confirm against the wasm `Model` constructor.
      this.instance[modelID] = new Model(
        weightsArrayU8,
        /tiny|mobile/.test(modelID)
      );
    } else {
      self.postMessage({ status: "loading", message: "Model Already Loaded" });
    }

    this.currentModelID = modelID;
    return this.instance[modelID];
  }

  /**
   * Compute embeddings for an image on the current model, skipping the work
   * when the same image bytes were already embedded on that model.
   *
   * @param {Uint8Array} imageArrayU8 - Encoded image bytes.
   * @throws {Error} If no model is loaded for the current model ID.
   */
  static setImageEmbeddings(imageArrayU8) {
    const modelID = this.currentModelID;
    const model = this.instance[modelID];
    if (!model) {
      throw new Error(`Model ${modelID} is not loaded`);
    }

    const imageArrayHash = this.getSimpleHash(imageArrayU8);
    if (this.imageArrayHash[modelID] === imageArrayHash) {
      self.postMessage({
        status: "embedding",
        message: "Embeddings Already Set",
      });
      return;
    }

    model.set_image_embeddings(imageArrayU8);
    // Record the fingerprint only after success so that a failed attempt is
    // retried rather than reported as "already set".
    this.imageArrayHash[modelID] = imageArrayHash;
    self.postMessage({ status: "embedding", message: "Embeddings Set" });
  }

  /**
   * Fingerprint a byte array with 32-bit FNV-1a over every byte, prefixed
   * with the length. Used only for equality comparison between images.
   *
   * @param {Uint8Array} imageArrayU8 - Bytes to fingerprint.
   * @returns {string} Hexadecimal fingerprint.
   */
  static getSimpleHash(imageArrayU8) {
    let hash = 0x811c9dc5; // FNV-1a 32-bit offset basis
    for (let i = 0; i < imageArrayU8.length; i += 1) {
      hash ^= imageArrayU8[i];
      hash = Math.imul(hash, 0x01000193); // FNV 32-bit prime
    }
    // `>>> 0` renders the 32-bit state as an unsigned value.
    return `${imageArrayU8.length.toString(16)}-${(hash >>> 0).toString(16)}`;
  }
}
|
|
|
/**
 * Render a mask as a black image whose alpha channel is the mask, scaled to
 * the original image size, and return an object URL for the encoded result.
 *
 * @param {{mask_shape: number[], mask_data: ArrayLike<number>}} mask -
 *   `mask_shape[2]` and `mask_shape[3]` size the mask canvas; each entry of
 *   `mask_data` is multiplied by 255 to give one pixel's alpha.
 * @param {{original_width: number, original_height: number}} image -
 *   Dimensions of the output canvas.
 * @returns {Promise<string>} Object URL of the rendered mask.
 */
async function createImageCanvas(
  { mask_shape, mask_data },
  { original_width, original_height }
) {
  // The first two entries of the shape are not needed here.
  const [, , shape_width, shape_height] = mask_shape;
  const maskCanvas = new OffscreenCanvas(shape_width, shape_height);
  const maskCtx = maskCanvas.getContext("2d");
  const canvas = new OffscreenCanvas(original_width, original_height);
  const ctx = canvas.getContext("2d");

  const imageData = maskCtx.createImageData(
    maskCanvas.width,
    maskCanvas.height
  );
  const data = imageData.data;
  // createImageData yields transparent black, so only alpha has to be written.
  const pixelCount = data.length / 4;
  for (let i = 0; i < pixelCount; i += 1) {
    data[i * 4 + 3] = mask_data[i] * 255;
  }
  maskCtx.putImageData(imageData, 0, 0);

  // Take only the top-left part of the mask canvas, shrinking the shorter
  // image side by the aspect ratio, and stretch it over the output canvas.
  // NOTE(review): presumably the mask is computed on a square, padded copy of
  // the image — confirm against the model output.
  const isLandscape = original_height < original_width;
  const sx = isLandscape ? 1 : original_width / original_height;
  const sy = isLandscape ? original_height / original_width : 1;
  ctx.drawImage(
    maskCanvas,
    0,
    0,
    maskCanvas.width * sx,
    maskCanvas.height * sy,
    0,
    0,
    original_width,
    original_height
  );

  const blob = await canvas.convertToBlob();
  return URL.createObjectURL(blob);
}
|
|
|
/**
 * Worker entry point. Each message carries `{ modelURL, modelID, imageURL,
 * points }`. The handler loads the model, embeds the image and — when
 * `points` is provided — segments it, posting `{ status, message }` progress
 * updates along the way. The final message is "complete-embedding" (no
 * points), "complete" with `output.maskURL`, or `{ error }` on failure.
 */
self.addEventListener("message", async (event) => {
  const { modelURL, modelID, imageURL, points } = event.data;
  try {
    self.postMessage({ status: "loading", message: "Starting SAM" });
    const sam = await SAMModel.getInstance(modelURL, modelID);

    self.postMessage({ status: "loading", message: "Loading Image" });
    // Images are fetched without going through the model cache.
    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);

    self.postMessage({ status: "embedding", message: "Creating Embeddings" });
    SAMModel.setImageEmbeddings(imageArrayU8);
    if (!points) {
      // No prompt points: the caller only wanted the embeddings prepared.
      self.postMessage({
        status: "complete-embedding",
        message: "Embeddings Complete",
      });
      return;
    }

    self.postMessage({ status: "segmenting", message: "Segmenting" });
    const { mask, image } = sam.mask_for_point({ points });
    const maskDataURL = await createImageCanvas(mask, image);

    self.postMessage({
      status: "complete",
      message: "Segmentation Complete",
      output: { maskURL: maskDataURL },
    });
  } catch (e) {
    try {
      self.postMessage({ error: e });
    } catch (cloneError) {
      // The thrown value could not be structured-cloned; send a textual
      // description so the main thread is still notified of the failure.
      const description =
        e instanceof Error ? `${e.name}: ${e.message}` : String(e);
      self.postMessage({ error: description });
    }
  }
});
|
|