File size: 2,373 Bytes
86f89a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import {
    AutoModelForImageClassification,
    AutoProcessor,
    AutoTokenizer,
	env,
    RawImage,
} from '@huggingface/transformers';

// Configure environment: models are loaded only from local files.
env.localModelPath = './'; // Base directory that model ids (e.g. 'saved-model/') are resolved against
// In browser / web-worker environments Transformers.js v3 defaults
// `allowLocalModels` to false; with remote models also disabled below, no
// model could be loaded at all, so local loading must be enabled explicitly.
env.allowLocalModels = true;
env.allowRemoteModels = false; // Disable remote models

/**
 * Checks whether the WebGPU adapter advertises 16-bit float shader support.
 *
 * Best-effort: resolves to false when WebGPU is unavailable, when no adapter
 * is returned, or when any step of the probe throws.
 *
 * @returns {Promise<boolean>} true if the adapter's features include 'shader-f16'.
 */
async function hasFp16() {
    const FP16_FEATURE = 'shader-f16';
    try {
        const gpuAdapter = await navigator.gpu.requestAdapter();
        // Destructuring a null adapter throws, which lands in the catch below.
        const { features } = gpuAdapter;
        return features.has(FP16_FEATURE);
    } catch {
        return false;
    }
}

/**
 * Lazily loads and caches the processor, tokenizer and model for the local
 * ONNX image-classification checkpoint at `model_id`.
 *
 * Each component is cached as a *promise* rather than a resolved value, so
 * concurrent callers (e.g. a `run` message arriving while `load` is still in
 * flight) share one in-progress load instead of each starting their own.
 * A cached promise that rejects is evicted so a later call can retry it.
 */
class CustomModelSingleton {
    static model_id = 'saved-model/'; // Path to your custom ONNX model (relative to env.localModelPath)

    /**
     * @param {Function|null} [progress_callback] - Forwarded to the model's
     *   `from_pretrained` call. Only the call that actually starts the model
     *   load has its callback used.
     * @returns {Promise<Array>} Resolves to `[model, tokenizer, processor]`.
     * @throws Rejects with the first loading error; the failed component is
     *   not kept in the cache.
     */
    static async getInstance(progress_callback = null) {
        // Cache factory()'s promise under `key`; drop it again if it rejects.
        const memoize = (key, factory) => {
            if (this[key] == null) {
                const pending = factory().catch((err) => {
                    if (this[key] === pending) this[key] = undefined;
                    throw err;
                });
                this[key] = pending;
            }
            return this[key];
        };

        // The three loads are independent, so start them all before awaiting.
        const processor = memoize('processor', () => AutoProcessor.from_pretrained(this.model_id));
        const tokenizer = memoize('tokenizer', () => AutoTokenizer.from_pretrained(this.model_id));
        const model = memoize('model', async () => {
            this.supports_fp16 ??= await hasFp16();
            return AutoModelForImageClassification.from_pretrained(this.model_id, {
                dtype: this.supports_fp16 ? 'fp16' : 'fp32',
                device: 'webgpu', // Change as per your hardware
                progress_callback,
            });
        });

        return Promise.all([model, tokenizer, processor]);
    }
}

/**
 * Handles the `load` message: warms up the model singleton and reports
 * status back to the main thread.
 *
 * Posts `{ status: 'loading' }` first, forwards every progress object the
 * loader emits, then posts `{ status: 'ready' }` on success. A loading
 * failure is reported as `{ status: 'error', data: <message> }` instead of
 * being left as an unhandled rejection.
 */
async function load() {
    self.postMessage({
        status: 'loading',
        data: 'Loading custom model...',
    });

    try {
        // Only the side effect (populating the singleton cache) matters here.
        await CustomModelSingleton.getInstance((x) => {
            self.postMessage(x);
        });
    } catch (err) {
        self.postMessage({
            status: 'error',
            data: err instanceof Error ? err.message : String(err),
        });
        return;
    }

    self.postMessage({
        status: 'ready',
        data: 'Model loaded successfully.',
    });
}

/**
 * Handles the `run` message: classifies one image and posts the logits back
 * to the main thread.
 *
 * @param {Object} params
 * @param {string} params.imagePath - Image URL passed to `RawImage.fromURL`.
 *   Any other fields on the message (e.g. `task`) are ignored.
 *
 * Posts `{ status: 'complete', result: { logits } }` on success, where
 * `logits` is the model output tensor converted to plain nested arrays, or
 * `{ status: 'error', data: <message> }` on failure.
 */
async function run({ imagePath }) {
    try {
        const [model, , processor] = await CustomModelSingleton.getInstance();

        // Read and preprocess image
        const image = await RawImage.fromURL(imagePath); // Or use fromBlob if required
        const vision_inputs = await processor(image);

        // Transformers.js models are invoked directly (they have no `predict`
        // method); classification models return an object with `logits`.
        const { logits } = await model(vision_inputs);

        // Post plain arrays rather than the Tensor itself: structured clone
        // drops class prototypes and the Tensor may wrap backend state.
        self.postMessage({ status: 'complete', result: { logits: logits.tolist() } });
    } catch (err) {
        self.postMessage({
            status: 'error',
            data: err instanceof Error ? err.message : String(err),
        });
    }
}

// Dispatch messages from the main thread; each message is expected to look
// like `{ type, data }`. Any failure is reported back as an `error` message
// instead of surfacing as an unhandled promise rejection.
self.addEventListener('message', async (e) => {
    try {
        const { type, data } = e.data;

        switch (type) {
            case 'load':
                await load();
                break;

            case 'run':
                await run(data);
                break;

            case 'reset':
                // Nothing to clear: `run` keeps its inputs in local variables,
                // so no per-request state survives between messages. Assigning
                // an undeclared name here would throw in a strict-mode module.
                break;

            default:
                // Unknown message types are ignored.
                break;
        }
    } catch (err) {
        self.postMessage({
            status: 'error',
            data: err instanceof Error ? err.message : String(err),
        });
    }
});