|
const tf = require('@tensorflow/tfjs-node');
|
|
const fs = require('fs');
|
|
const path = require('path');
|
|
|
|
/**
 * Image extensions this loader will decode. tf.node.decodeImage understands
 * BMP, GIF, JPEG and PNG. GIF is deliberately left out because decodeImage may
 * return a 4-D (frame-stacked) tensor for it, which would not stack with the
 * 3-D tensors produced for still images — NOTE(review): re-add if GIFs are needed.
 */
const SUPPORTED_IMAGE_EXTENSIONS = new Set(['.bmp', '.jpeg', '.jpg', '.png']);

/**
 * Load every supported image file in `folder` as a normalized float tensor.
 *
 * Each image is decoded with 3 channels, resized to 128x128 with
 * nearest-neighbor interpolation, cast to float and divided by 255.
 * Every image is labelled with the folder's base name.
 *
 * Entries that are not regular files (sub-directories) or that do not have a
 * supported image extension (e.g. .DS_Store, README) are skipped instead of
 * crashing the decoder.
 *
 * @param {string} folder - Directory holding the images of a single class.
 * @returns {Promise<{images: tf.Tensor4D, labels: string[]}>} `images` has
 *   shape [count, 128, 128, 3]; `labels` holds one (identical) entry per image.
 *   If no supported image is found, `images` is an empty [0, 128, 128, 3]
 *   tensor and `labels` is empty.
 */
async function loadImages(folder) {
  // All images in a class folder share one label, so compute it once.
  const label = path.basename(folder);

  const files = fs.readdirSync(folder).filter((name) => {
    if (!SUPPORTED_IMAGE_EXTENSIONS.has(path.extname(name).toLowerCase())) {
      return false;
    }
    // statSync follows symlinks, so symlinked image files are still accepted.
    return fs.statSync(path.join(folder, name)).isFile();
  });

  if (files.length === 0) {
    // tf.stack([]) throws; an empty batch keeps tf.concat in callers working.
    return { images: tf.zeros([0, 128, 128, 3]), labels: [] };
  }

  const images = [];
  const labels = [];
  try {
    for (const file of files) {
      const imageBuffer = fs.readFileSync(path.join(folder, file));
      // tf.tidy frees the decoded/resized/cast intermediates and keeps only
      // the returned, normalized tensor.
      const imageTensor = tf.tidy(() => tf.node.decodeImage(imageBuffer, 3)
        .resizeNearestNeighbor([128, 128])
        .toFloat()
        .div(255));
      images.push(imageTensor);
      labels.push(label);
    }
    return { images: tf.stack(images), labels };
  } finally {
    // tf.stack produces a new tensor, so the per-image tensors can be released
    // (this also cleans up if decoding fails part-way through).
    tf.dispose(images);
  }
}
|
|
|
|
/**
 * Load a classification dataset laid out as one sub-directory per class.
 *
 * Only directories inside `basePath` are treated as classes; stray files
 * (e.g. .DS_Store) are ignored. Class indices are assigned in sorted order of
 * the directory names, so the mapping does not depend on readdir ordering.
 *
 * @param {string} basePath - Directory whose sub-directories are the classes.
 * @returns {Promise<{images: tf.Tensor, labels: tf.Tensor2D, labelMap: Object<string, number>}>}
 *   `images` concatenates each class's images along axis 0. `labels` is the
 *   matching one-hot matrix with one column per class. `labelMap` maps class
 *   folder name to column index.
 * @throws {Error} If `basePath` contains no sub-directories.
 */
async function loadDataset(basePath) {
  const folders = fs
    .readdirSync(basePath)
    // statSync follows symlinks, so symlinked class folders are still accepted.
    .filter((name) => fs.statSync(path.join(basePath, name)).isDirectory())
    .sort();

  if (folders.length === 0) {
    throw new Error(`No class sub-directories found in ${basePath}`);
  }

  // Object.fromEntries defines own properties, so an oddly named folder such as
  // "__proto__" becomes a normal key instead of touching the object's prototype.
  const labelMap = Object.fromEntries(folders.map((folder, index) => [folder, index]));
  const numClasses = folders.length;

  const imageBatches = [];
  const labelIndices = [];
  try {
    for (const folder of folders) {
      const { images, labels } = await loadImages(path.join(basePath, folder));
      imageBatches.push(images);
      for (const label of labels) {
        labelIndices.push(labelMap[label]);
      }
    }

    return {
      images: tf.concat(imageBatches),
      // tidy releases the intermediate index tensor and keeps only the one-hot result.
      labels: tf.tidy(() => tf.oneHot(tf.tensor1d(labelIndices, 'int32'), numClasses)),
      labelMap,
    };
  } finally {
    // tf.concat returns a new tensor, so the per-class batches can be freed.
    tf.dispose(imageBatches);
  }
}
|
|
|
|
/**
 * Train a small convolutional image classifier and save it with its label map.
 *
 * The model is built as conv2d(32 filters, 3x3, relu), then flatten, then a
 * dense softmax layer with one unit per class. It is compiled with adam and
 * categorical cross-entropy, then trained on the dataset loaded from
 * `basePath`.
 *
 * All options are optional and default to the values this script always used,
 * so `trainModel()` behaves as before.
 *
 * @param {object} [options]
 * @param {string} [options.basePath='./training_images'] - Dataset root with one sub-directory per class.
 * @param {number} [options.epochs=5] - Number of training epochs.
 * @param {string} [options.modelUrl='file://./saved-model'] - Destination passed to `model.save`.
 * @param {string} [options.labelMapPath='./labelMap.json'] - File the folder-name → class-index map is written to.
 * @returns {Promise<void>}
 */
async function trainModel({
  basePath = './training_images',
  epochs = 5,
  modelUrl = 'file://./saved-model',
  labelMapPath = './labelMap.json',
} = {}) {
  const { images, labels, labelMap } = await loadDataset(basePath);
  const numClasses = Object.keys(labelMap).length;

  try {
    const model = tf.sequential();
    model.add(tf.layers.conv2d({
      // [height, width, channels] taken from the data itself, so the model
      // input can never drift from the image preprocessing.
      inputShape: images.shape.slice(1),
      filters: 32,
      kernelSize: 3,
      activation: 'relu',
    }));
    // NOTE(review): with 128x128 inputs and default 'valid' padding the conv
    // output is 126x126x32 (~508k features) fed straight into the dense layer;
    // consider adding pooling to shrink the model.
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({ units: numClasses, activation: 'softmax' }));
    model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy'] });

    await model.fit(images, labels, { epochs });
    await model.save(modelUrl);
    // Persist the folder-name → class-index mapping next to the model output.
    fs.writeFileSync(labelMapPath, JSON.stringify(labelMap));
    console.log('Model saved as TensorFlow.js format');
  } finally {
    // Release the dataset tensors whether or not training succeeded.
    tf.dispose([images, labels]);
  }
}
|
|
|
|
// Script entry point: train and save the model. A failure is logged and turned
// into a non-zero exit code instead of surfacing as an unhandled promise rejection.
(async () => {
  try {
    await trainModel();
  } catch (err) {
    console.error('Training failed:', err);
    process.exitCode = 1;
  }
})();
|
|
|