12345678910111213141516171819202122232425262728293031323334353637 |
- import chainer
- from chainer import optimizers
- import nutszebra_basic_print
class Optimizer(object):
    """Minimal base wrapper that pairs a model with an optimizer object.

    The base class stores ``model`` and leaves ``self.optimizer`` as ``None``;
    subclasses (such as ``OptimizerWideResBinaryTree``) assign a concrete
    optimizer to it.
    """

    def __init__(self, model=None):
        """Remember the model; no optimizer is attached yet."""
        self.model, self.optimizer = model, None

    def __call__(self, i):
        """Per-step hook receiving index ``i``; a no-op in the base class.

        NOTE(review): ``i`` is presumably the current epoch — confirm with
        the training loop that calls this.
        """
        return None

    def update(self):
        """Forward a single update step to ``self.optimizer``.

        Raises AttributeError if no optimizer has been assigned (it is still
        ``None``).
        """
        inner = self.optimizer
        inner.update()
class OptimizerWideResBinaryTree(Optimizer):
    """Momentum SGD with L2 weight decay and a step learning-rate schedule.

    Args:
        model: Chainer link to optimize; passed to ``optimizer.setup``.
        schedule: indices ``i`` at which ``__call__`` multiplies the current
            learning rate by ``gamma``.
        lr: initial learning rate for ``MomentumSGD``.
        momentum: momentum coefficient for ``MomentumSGD``.
        weight_decay: coefficient given to ``chainer.optimizer.WeightDecay``.
        gamma: multiplicative learning-rate decay applied at each schedule
            point (defaults to 0.2, the previously hard-coded factor).

    Attributes:
        optimizer: the configured ``MomentumSGD`` instance.
        weight_decay: the float weight-decay coefficient (consistent with
            ``lr`` and ``momentum``, which also store the raw values).
        weight_decay_hook: the ``WeightDecay`` hook registered on
            ``optimizer``.
    """

    def __init__(self, model=None, schedule=(60, 120, 160), lr=0.1,
                 momentum=0.9, weight_decay=5.0e-4, gamma=0.2):
        super(OptimizerWideResBinaryTree, self).__init__(model)
        optimizer = optimizers.MomentumSGD(lr, momentum)
        # Use a separate name for the hook: rebinding ``weight_decay`` here
        # used to make ``self.weight_decay`` hold the hook instead of the rate.
        weight_decay_hook = chainer.optimizer.WeightDecay(weight_decay)
        # setup() must run before add_hook(); Chainer rejects hooks otherwise.
        optimizer.setup(self.model)
        optimizer.add_hook(weight_decay_hook)
        self.optimizer = optimizer
        self.weight_decay_hook = weight_decay_hook
        self.schedule = schedule
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.gamma = gamma

    def __call__(self, i):
        """Scale the learning rate by ``gamma`` when ``i`` is in ``schedule``.

        NOTE(review): ``i`` is presumably the epoch index (the default
        schedule is 60/120/160) — confirm with the calling training loop.
        """
        if i in self.schedule:
            lr = self.optimizer.lr * self.gamma
            print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
            self.optimizer.lr = lr
|