// Classify a single image with an ONNX model and print the top label and its score.
//
// Usage: node <this-file> [imagePath]
// With no argument the original hard-coded image path is used.
const fs = require('fs');
const sharp = require('sharp');
const ort = require('onnxruntime-node');

const DEFAULT_IMAGE_PATH = './training_images/hat.jpg';
const MODEL_PATH = './saved-model/model.onnx';
const LABEL_MAP_PATH = './labelMap.json';
const IMAGE_SIZE = 128;
const CHANNELS = 3;

/**
 * Decode an image, resize it to IMAGE_SIZE x IMAGE_SIZE and build a float32
 * tensor of shape [1, IMAGE_SIZE, IMAGE_SIZE, CHANNELS] with values in [0, 1].
 *
 * sharp's raw output is interleaved per pixel (row, column, channel), which is
 * the layout the NHWC shape below expects.
 *
 * @param {string} imagePath - Path of the image file to load.
 * @returns {Promise<ort.Tensor>} The model input tensor.
 * @throws {Error} If the decoded image does not have exactly CHANNELS channels.
 */
async function loadImageTensor(imagePath) {
  // Without removeAlpha()/toColourspace('srgb') the raw buffer keeps whatever
  // channel count the source image has (e.g. 4 for RGBA PNGs), which would not
  // match the 3-channel tensor shape.
  // NOTE(review): resize() uses sharp's default fit ('cover', i.e. crop to fill);
  // confirm this matches the preprocessing used when the model was trained.
  const { data, info } = await sharp(imagePath)
    .resize(IMAGE_SIZE, IMAGE_SIZE)
    .removeAlpha()
    .toColourspace('srgb')
    .raw()
    .toBuffer({ resolveWithObject: true });

  if (info.channels !== CHANNELS) {
    throw new Error(
      `Expected ${CHANNELS} channels from ${imagePath}, got ${info.channels}`,
    );
  }

  // Convert bytes to floats and scale to [0, 1] in a single pass.
  const pixels = Float32Array.from(data, (value) => value / 255.0);
  return new ort.Tensor('float32', pixels, [1, IMAGE_SIZE, IMAGE_SIZE, CHANNELS]);
}

/**
 * Index of the largest value in an array-like of numbers.
 *
 * Single pass, with no argument spreading, so it is safe for very large outputs.
 * NaN entries are skipped. Ties resolve to the first index.
 *
 * @param {ArrayLike<number>} values - Scores to search.
 * @returns {number} Index of the maximum, or -1 if there is no non-NaN value.
 */
function argMax(values) {
  let best = -1;
  for (let i = 0; i < values.length; i += 1) {
    if (Number.isNaN(values[i])) continue;
    if (best === -1 || values[i] > values[best]) best = i;
  }
  return best;
}

/**
 * Reverse-lookup the label whose value in `labelMap` is `labelIndex`.
 *
 * NOTE(review): assumes labelMap.json is an object of { label: numericIndex }.
 * Confirm this against the script that wrote it; string-valued indices will
 * not match.
 *
 * @param {Record<string, number>} labelMap - Parsed label map.
 * @param {number} labelIndex - Class index to look up.
 * @returns {string | undefined} The matching label, if any.
 */
function findLabel(labelMap, labelIndex) {
  return Object.keys(labelMap).find((key) => labelMap[key] === labelIndex);
}

(async () => {
  try {
    const imagePath = process.argv[2] ?? DEFAULT_IMAGE_PATH;

    // Preprocessing the image and loading the model are independent of each
    // other, so run them concurrently.
    const [inputTensor, session] = await Promise.all([
      loadImageTensor(imagePath),
      ort.InferenceSession.create(MODEL_PATH),
    ]);

    const results = await session.run({ [session.inputNames[0]]: inputTensor });

    // NOTE(review): these scores are treated as probabilities, i.e. the model is
    // assumed to end in a softmax. If it emits raw logits, the "Confidence"
    // printed below is not a meaningful percentage.
    const probabilities = results[session.outputNames[0]].data;
    const labelIndex = argMax(probabilities);
    if (labelIndex === -1) {
      throw new Error('Model produced no usable output values');
    }

    const labelMap = JSON.parse(fs.readFileSync(LABEL_MAP_PATH, 'utf8'));
    const label = findLabel(labelMap, labelIndex);
    const labelText = label ?? `<no label for index ${labelIndex}>`;

    console.log(`Predicted label: ${labelText}`);
    console.log(`Confidence: ${(probabilities[labelIndex] * 100).toFixed(2)}%`);
  } catch (err) {
    console.error('Error:', err);
    // Report failure to the shell/caller instead of exiting with status 0.
    process.exitCode = 1;
  }
})();