grg's picture
Cleaned old git history
be5548b
raw
history blame
806 Bytes
import random
import numpy
import torch
import collections
def seed(seed):
    """Seed every random number generator this module relies on.

    Applies the same ``seed`` value to Python's ``random`` module, NumPy's
    global RNG (``numpy.random.seed``) and PyTorch's RNG
    (``torch.manual_seed``). When CUDA is available, all CUDA devices are
    seeded as well via ``torch.cuda.manual_seed_all``.

    Args:
        seed: The seed value passed unchanged to each generator.
    """
    for seeder in (random.seed, numpy.random.seed, torch.manual_seed):
        seeder(seed)
    # Only touch the CUDA RNGs when a GPU backend is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def synthesize(array):
    """Summarize ``array`` as an ordered mapping of basic statistics.

    Args:
        array: Any input accepted by ``numpy.mean``/``numpy.std``/
            ``numpy.amin``/``numpy.amax``.

    Returns:
        collections.OrderedDict: keys ``"mean"``, ``"std"``, ``"min"``,
        ``"max"`` in that order, computed with ``numpy.mean``,
        ``numpy.std``, ``numpy.amin`` and ``numpy.amax`` respectively.
    """
    # (key, reducer) pairs; the tuple order fixes the key order of the result.
    reducers = (
        ("mean", numpy.mean),
        ("std", numpy.std),
        ("min", numpy.amin),
        ("max", numpy.amax),
    )
    return collections.OrderedDict(
        (name, reduce_fn(array)) for name, reduce_fn in reducers
    )
# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def init_params(m):
    """Initialize a module's parameters in place if it is a "Linear" layer.

    Intended to be passed to ``nn.Module.apply``. Modules whose class name
    does not contain ``"Linear"`` are left untouched. For matching modules,
    the weight is drawn from a standard normal distribution and each row
    (dimension 1 summed out) is rescaled to unit L2 norm. The bias, if
    present, is zeroed.

    Args:
        m: The module to initialize.
    """
    # Name-based match (not isinstance), so any class with "Linear" in its
    # name is treated as a linear layer.
    if "Linear" not in m.__class__.__name__:
        return
    weight = m.weight.data
    weight.normal_(0, 1)
    # In-place division of each row by its L2 norm.
    weight *= 1 / torch.sqrt(weight.pow(2).sum(1, keepdim=True))
    if m.bias is not None:
        m.bias.data.fill_(0)