import {
  AutoModelForImageClassification,
  AutoProcessor,
  AutoTokenizer,
  env,
  RawImage,
} from '@huggingface/transformers';

// Configure environment
env.localModelPath = './'; // Path to your ONNX model
env.allowRemoteModels = false; // Disable remote models
// Check whether the current WebGPU adapter supports fp16 shaders
async function hasFp16() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    return adapter.features.has('shader-f16');
  } catch (e) {
    // navigator.gpu is undefined or no adapter is available
    return false;
  }
}
class CustomModelSingleton {
  static model_id = 'saved-model/'; // Path to your custom ONNX model

  static async getInstance(progress_callback = null) {
    // Lazily create and cache each component on first use
    this.processor ??= await AutoProcessor.from_pretrained(this.model_id);
    this.tokenizer ??= await AutoTokenizer.from_pretrained(this.model_id);
    this.supports_fp16 ??= await hasFp16();
    this.model ??= await AutoModelForImageClassification.from_pretrained(this.model_id, {
      dtype: this.supports_fp16 ? 'fp16' : 'fp32',
      device: 'webgpu', // Change as per your hardware
      progress_callback,
    });
    return Promise.all([this.model, this.tokenizer, this.processor]);
  }
}
async function load() {
  self.postMessage({
    status: 'loading',
    data: 'Loading custom model...',
  });

  // Warm up the singleton so subsequent run() calls are fast,
  // forwarding download/initialization progress events to the main thread
  await CustomModelSingleton.getInstance((x) => {
    self.postMessage(x);
  });

  self.postMessage({
    status: 'ready',
    data: 'Model loaded successfully.',
  });
}
async function run({ imagePath, task }) {
  const [model, tokenizer, processor] = await CustomModelSingleton.getInstance();

  // Read and preprocess image
  const image = await RawImage.fromURL(imagePath); // Or use RawImage.fromBlob if required
  const vision_inputs = await processor(image);

  // Run inference (Transformers.js models are callable; there is no .predict() method)
  const { logits } = await model(vision_inputs);

  // Report the highest-scoring class label
  const scores = logits.data;
  const best = scores.indexOf(Math.max(...scores));
  self.postMessage({
    status: 'complete',
    result: { label: model.config.id2label?.[best], score: scores[best] },
  });
}
// Dispatch messages from the main thread
self.addEventListener('message', async (e) => {
  const { type, data } = e.data;
  switch (type) {
    case 'load':
      load();
      break;
    case 'run':
      run(data);
      break;
    case 'reset':
      // Drop the cached components so the next load() starts fresh
      CustomModelSingleton.model = null;
      CustomModelSingleton.tokenizer = null;
      CustomModelSingleton.processor = null;
      break;
  }
});
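
// A minimal sketch of the main-thread side of this worker's message protocol
// ('load' / 'run' / 'reset'). The file name './worker.js' and the image path
// are assumptions; adjust them to your project layout.
//
//   const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });
//   worker.addEventListener('message', (e) => {
//     const { status } = e.data;
//     if (status === 'ready') {
//       worker.postMessage({
//         type: 'run',
//         data: { imagePath: './example.png', task: 'image-classification' },
//       });
//     } else if (status === 'complete') {
//       console.log(e.data.result); // { label, score }
//     }
//   });
//   worker.postMessage({ type: 'load' });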