/*
 * issue1038 / nodejs / customVision.js
 * jrsimuix's picture
 * experiment and look at finetuning with ms-florence.
 * 86f89a6 verified
 * raw
 * history blame
 * 2.37 kB
 */
import {
AutoModelForImageClassification,
AutoProcessor,
AutoTokenizer,
env,
RawImage,
} from '@huggingface/transformers';
// Configure environment
Object.assign(env, {
  localModelPath: './', // Path to your ONNX model
  allowRemoteModels: false, // Disable remote models
});
/**
 * Reports whether the WebGPU adapter advertises the `shader-f16` feature.
 *
 * This is a best-effort probe: any failure along the way (missing
 * `navigator.gpu`, a rejected request, a null adapter, ...) is reported as
 * "not supported" rather than thrown.
 *
 * @returns {Promise<boolean>} Result of `features.has('shader-f16')`, or
 *   `false` when the adapter could not be queried.
 */
async function hasFp16() {
  try {
    // Destructuring a null/undefined adapter throws, which lands in the catch below.
    const { features } = await navigator.gpu.requestAdapter();
    return features.has('shader-f16');
  } catch {
    return false;
  }
}
/**
 * Lazily loads the processor, tokenizer and model for `model_id` and shares
 * them between all callers.
 *
 * The static `processor`, `tokenizer` and `model` slots hold the in-flight
 * (or settled) load promises rather than the resolved values, so overlapping
 * `getInstance()` calls reuse one load instead of each starting their own.
 */
class CustomModelSingleton {
  static model_id = 'saved-model/'; // Path to your custom ONNX model

  /**
   * Starts `load` at most once for the given cache slot.
   *
   * @param {string} key - Name of the static slot that caches the promise.
   * @param {() => any} load - Factory that performs the actual load.
   * @returns {Promise<any>} The cached promise for that slot.
   */
  static loadOnce(key, load) {
    if (this[key] == null) {
      const pending = Promise.resolve().then(load);
      // Forget a failed load so a later call can retry instead of replaying
      // the same rejection forever.
      pending.catch(() => {
        if (this[key] === pending) {
          this[key] = undefined;
        }
      });
      this[key] = pending;
    }
    return this[key];
  }

  /**
   * Resolves the shared model, tokenizer and processor, loading them on first use.
   *
   * @param {Function|null} [progress_callback=null] - Forwarded to the model
   *   load; only the callback of the call that actually starts the load is used.
   * @returns {Promise<Array>} Resolves to `[model, tokenizer, processor]`.
   */
  static async getInstance(progress_callback = null) {
    // The three loads are independent, so kick them all off before awaiting.
    const processor = this.loadOnce('processor', () =>
      AutoProcessor.from_pretrained(this.model_id));
    const tokenizer = this.loadOnce('tokenizer', () =>
      AutoTokenizer.from_pretrained(this.model_id));
    const model = this.loadOnce('model', async () => {
      // The dtype depends on the fp16 probe, so it runs before the model load.
      this.supports_fp16 ??= await hasFp16();
      return AutoModelForImageClassification.from_pretrained(this.model_id, {
        dtype: this.supports_fp16 ? 'fp16' : 'fp32',
        device: 'webgpu', // Change as per your hardware
        progress_callback,
      });
    });
    return Promise.all([model, tokenizer, processor]);
  }
}
/**
 * Loads the shared model instance and reports progress to the main thread.
 *
 * Posts a `loading` message first, forwards every progress event emitted by
 * `CustomModelSingleton.getInstance` unchanged, and posts `ready` once the
 * instance has resolved.
 *
 * @returns {Promise<void>}
 */
async function load() {
  const announce = (status, data) => {
    self.postMessage({ status, data });
  };

  announce('loading', 'Loading custom model...');

  // Only the side effect of loading matters here; the instances are fetched
  // again from the singleton when they are needed.
  await CustomModelSingleton.getInstance((progress) => {
    self.postMessage(progress);
  });

  announce('ready', 'Model loaded successfully.');
}
/**
 * Runs the shared model on one image and posts the outcome to the main thread.
 *
 * @param {object} request
 * @param {string} request.imagePath - URL passed to `RawImage.fromURL`.
 * @param {*} [request.task] - Accepted for caller compatibility; not used here.
 * @returns {Promise<void>} Settles after the `complete` message was posted.
 */
async function run({ imagePath, task }) {
  // The tokenizer (second entry) is not needed for image-only inference.
  const [model, , processor] = await CustomModelSingleton.getInstance();

  // Read and preprocess image
  const image = await RawImage.fromURL(imagePath); // Or use fromBlob if required
  const vision_inputs = await processor(image);

  // Run inference. Models loaded through Transformers.js are invoked by
  // calling the instance itself; a `predict` method is honoured when a custom
  // model object provides one, so existing callers keep working.
  const results = typeof model.predict === 'function'
    ? await model.predict(vision_inputs)
    : await model(vision_inputs);

  // NOTE(review): raw model outputs may not be structured-cloneable — confirm,
  // and convert them to plain data here if postMessage rejects them.
  self.postMessage({ status: 'complete', result: results });
}
// Dispatch messages from the main thread to the matching worker action.
// Expected message shape: { type: 'load' | 'run' | 'reset', data }.
self.addEventListener('message', async (e) => {
  const { type, data } = e.data;
  try {
    switch (type) {
      case 'load':
        await load();
        break;
      case 'run':
        await run(data);
        break;
      case 'reset':
        // Nothing to clear: run() keeps its inputs in local variables, so the
        // worker holds no per-request state. Assigning to an undeclared
        // `vision_inputs` here would throw a ReferenceError in module code.
        break;
    }
  } catch (err) {
    // Awaiting the actions above lets failures surface here, so the main
    // thread is told about them instead of waiting forever for a reply.
    self.postMessage({
      status: 'error',
      data: err instanceof Error ? err.message : String(err),
    });
  }
});