minor bug + better defaults in test()
Browse files
cli.py
CHANGED
@@ -223,7 +223,7 @@ def train(*, dataset='mnist', folder='mnist', resume=False, model='convae', walk
|
|
223 |
nb_updates += 1
|
224 |
|
225 |
|
226 |
-
def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=
|
227 |
if not os.path.exists(folder):
|
228 |
os.makedirs(folder, exist_ok=True)
|
229 |
dataset = load_dataset(dataset, split='train')
|
@@ -235,6 +235,7 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
|
|
235 |
model_path = os.path.join(folder, "model.th")
|
236 |
ae = torch.load(model_path, map_location="cpu")
|
237 |
ae = ae.to(device)
|
|
|
238 |
def enc(X):
|
239 |
batch_size = 64
|
240 |
h_list = []
|
@@ -267,12 +268,12 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
|
|
267 |
np.savez('{}/generated.npz'.format(folder), X=g.numpy())
|
268 |
g_subset = g[:, 0:100]
|
269 |
gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
|
270 |
-
imsave('{}/gen_full_iters.png'.format(folder), gr)
|
271 |
|
272 |
g = g[-1] # last iter
|
273 |
print(g.shape)
|
274 |
gr = grid_of_images_default(g.numpy())
|
275 |
-
imsave('{}/gen_full.png'.format(folder), gr)
|
276 |
|
277 |
if tsne:
|
278 |
from sklearn.manifold import TSNE
|
@@ -300,13 +301,13 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
|
|
300 |
print('fit tsne...')
|
301 |
ah = sne.fit_transform(ah)
|
302 |
print('grid embedding...')
|
303 |
-
|
304 |
asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
|
305 |
ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
|
306 |
rows = grid_embedding(ahsmall)
|
307 |
asmall = asmall[rows]
|
308 |
gr = grid_of_images_default(asmall)
|
309 |
-
imsave('{}/sne_grid.png'.format(folder), gr)
|
310 |
|
311 |
fig = plt.figure(figsize=(10, 10))
|
312 |
plot_dataset(ah, labels)
|
|
|
223 |
nb_updates += 1
|
224 |
|
225 |
|
226 |
+
def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=25, nb_generate=100, nb_active=160, tsne=False):
|
227 |
if not os.path.exists(folder):
|
228 |
os.makedirs(folder, exist_ok=True)
|
229 |
dataset = load_dataset(dataset, split='train')
|
|
|
235 |
model_path = os.path.join(folder, "model.th")
|
236 |
ae = torch.load(model_path, map_location="cpu")
|
237 |
ae = ae.to(device)
|
238 |
+
ae.nb_active = nb_active # for fc_sparse.th only
|
239 |
def enc(X):
|
240 |
batch_size = 64
|
241 |
h_list = []
|
|
|
268 |
np.savez('{}/generated.npz'.format(folder), X=g.numpy())
|
269 |
g_subset = g[:, 0:100]
|
270 |
gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
|
271 |
+
imsave('{}/gen_full_iters.png'.format(folder), (gr*255).astype("uint8") )
|
272 |
|
273 |
g = g[-1] # last iter
|
274 |
print(g.shape)
|
275 |
gr = grid_of_images_default(g.numpy())
|
276 |
+
imsave('{}/gen_full.png'.format(folder), (gr*255).astype("uint8") )
|
277 |
|
278 |
if tsne:
|
279 |
from sklearn.manifold import TSNE
|
|
|
301 |
print('fit tsne...')
|
302 |
ah = sne.fit_transform(ah)
|
303 |
print('grid embedding...')
|
304 |
+
assert nb_generate >= 450
|
305 |
asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
|
306 |
ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
|
307 |
rows = grid_embedding(ahsmall)
|
308 |
asmall = asmall[rows]
|
309 |
gr = grid_of_images_default(asmall)
|
310 |
+
imsave('{}/sne_grid.png'.format(folder), (gr*255).astype("uint8") )
|
311 |
|
312 |
fig = plt.figure(figsize=(10, 10))
|
313 |
plot_dataset(ah, labels)
|
viz.py
CHANGED
@@ -116,8 +116,8 @@ def grid_of_images(M, border=0, bordercolor=[0.0, 0.0, 0.0], shape=None, normali
|
|
116 |
height, width, color = M[0].shape
|
117 |
assert color == 3, 'Nb of color channels are {}'.format(color)
|
118 |
if shape is None:
|
119 |
-
n0 = np.
|
120 |
-
n1 = np.
|
121 |
else:
|
122 |
n0 = shape[0]
|
123 |
n1 = shape[1]
|
|
|
116 |
height, width, color = M[0].shape
|
117 |
assert color == 3, 'Nb of color channels are {}'.format(color)
|
118 |
if shape is None:
|
119 |
+
n0 = np.int32(np.ceil(np.sqrt(numimages)))
|
120 |
+
n1 = np.int32(np.ceil(np.sqrt(numimages)))
|
121 |
else:
|
122 |
n0 = shape[0]
|
123 |
n1 = shape[1]
|