File size: 2,180 Bytes
2830133 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
const path = require('path');
// Extensions tf.node.decodeImage can decode (BMP, GIF, JPEG, PNG).
const IMAGE_EXTENSIONS = new Set(['.bmp', '.gif', '.jpeg', '.jpg', '.png']);

/**
 * Loads every supported image file in `folder` as a normalized tensor.
 *
 * Each image is decoded to 3 channels, resized to 128x128 with
 * nearest-neighbor sampling, cast to float and scaled from [0, 255] to [0, 1].
 * Every image is labeled with the folder's base name.
 *
 * @param {string} folder - Directory containing the images of one class.
 * @returns {Promise<{images: tf.Tensor4D, labels: string[]}>} Images stacked
 *   into shape [numImages, 128, 128, 3] and one label per image, in the same
 *   (file-name sorted) order.
 * @throws {Error} If `folder` contains no supported image files.
 */
async function loadImages(folder) {
  const label = path.basename(folder);
  // Skip subdirectories and stray files (.DS_Store, Thumbs.db, ...) that
  // readFileSync/decodeImage would reject. Sort so the order is deterministic.
  const files = fs
    .readdirSync(folder, { withFileTypes: true })
    .filter(
      (entry) =>
        entry.isFile() && IMAGE_EXTENSIONS.has(path.extname(entry.name).toLowerCase()),
    )
    .map((entry) => entry.name)
    .sort();

  if (files.length === 0) {
    throw new Error(
      `No supported images (${[...IMAGE_EXTENSIONS].join(', ')}) found in ${folder}`,
    );
  }

  const images = [];
  try {
    for (const file of files) {
      const imageBuffer = fs.readFileSync(path.join(folder, file));
      // tidy() frees the intermediate decode/resize/cast tensors; only the
      // returned normalized tensor survives.
      images.push(
        tf.tidy(() =>
          tf.node
            // expandAnimations=false: GIFs yield their first frame as a 3D
            // tensor instead of a 4D stack that tf.stack could not combine.
            .decodeImage(imageBuffer, 3, 'int32', false)
            .resizeNearestNeighbor([128, 128])
            .toFloat()
            .div(255),
        ),
      );
    }
    return { images: tf.stack(images), labels: new Array(files.length).fill(label) };
  } finally {
    // tf.stack copies the data, so the per-image tensors are no longer needed.
    // This also releases them if decoding fails partway through.
    tf.dispose(images);
  }
}
/**
 * Loads a labeled image dataset laid out as one subdirectory per class.
 *
 * @param {string} basePath - Directory whose (non-hidden) subdirectories are
 *   class names, each containing that class's images.
 * @returns {Promise<{images: tf.Tensor4D, labels: tf.Tensor2D, labelMap: Object<string, number>}>}
 *   `images` concatenates every class's images. `labels` is the matching
 *   one-hot matrix of shape [numImages, numClasses]. `labelMap` maps each
 *   class name to its one-hot/output index.
 * @throws {Error} If `basePath` has no class subdirectories.
 */
async function loadDataset(basePath) {
  // Only visible directories are classes. Plain files and hidden directories
  // (.git, ...) would otherwise inflate the class count or crash readdirSync.
  // Sorting keeps class indices stable across platforms, because readdir
  // order is unspecified.
  const classNames = fs
    .readdirSync(basePath, { withFileTypes: true })
    .filter((entry) => entry.isDirectory() && !entry.name.startsWith('.'))
    .map((entry) => entry.name)
    .sort();

  if (classNames.length === 0) {
    throw new Error(`No class subdirectories found in ${basePath}`);
  }

  const labelMap = Object.fromEntries(classNames.map((name, index) => [name, index]));
  const imageBatches = [];
  const labelIndices = [];
  try {
    for (const className of classNames) {
      const { images, labels } = await loadImages(path.join(basePath, className));
      imageBatches.push(images);
      for (const label of labels) {
        labelIndices.push(labelMap[label]);
      }
    }
    return {
      images: tf.concat(imageBatches),
      // tidy() releases the temporary int32 index tensor; only the one-hot matrix is kept.
      labels: tf.tidy(() => tf.oneHot(tf.tensor1d(labelIndices, 'int32'), classNames.length)),
      labelMap,
    };
  } finally {
    // tf.concat copies the data; free the per-class batches (also on failure).
    tf.dispose(imageBatches);
  }
}
/**
 * Trains a small convolutional classifier on a folder-per-class image dataset
 * and saves it in TensorFlow.js format.
 *
 * @param {string} [basePath='./training_images'] - Dataset root. Each
 *   subdirectory holds the images of one labeled class.
 * @param {object} [options]
 * @param {number} [options.epochs=5] - Number of training epochs.
 * @param {string} [options.savePath='file://./saved-model'] - tf.io URL the
 *   trained model is saved to.
 * @returns {Promise<{labelMap: Object<string, number>}>} Mapping from class
 *   name to softmax output index. Predictions from the saved model cannot be
 *   interpreted without it.
 */
async function trainModel(
  basePath = './training_images',
  { epochs = 5, savePath = 'file://./saved-model' } = {},
) {
  const { images, labels, labelMap } = await loadDataset(basePath);
  const numClasses = Object.keys(labelMap).length;
  const model = tf.sequential();
  try {
    model.add(tf.layers.conv2d({
      inputShape: [128, 128, 3],
      filters: 32,
      kernelSize: 3,
      activation: 'relu',
    }));
    // NOTE(review): flattening the full 126x126x32 conv output feeds ~508k
    // features into the dense layer. A pooling layer would shrink it greatly.
    model.add(tf.layers.flatten());
    model.add(tf.layers.dense({ units: numClasses, activation: 'softmax' }));
    model.compile({ optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy'] });
    await model.fit(images, labels, { epochs });
    await model.save(savePath);
    console.log('Model saved as TensorFlow.js format');
    // TODO: persist labelMap next to the saved model so inference code can
    // map output indices back to class names.
    return { labelMap };
  } finally {
    // Release native memory held by the weights and the dataset tensors so
    // repeated calls do not accumulate it.
    model.dispose();
    tf.dispose([images, labels]);
  }
}
// Script entry point. Handle the promise explicitly so a failure is reported
// and the process exits non-zero instead of relying on unhandled-rejection
// behavior.
trainModel().catch((err) => {
  console.error('Training failed:', err);
  process.exitCode = 1;
});
|