nutszebra_optimizer.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import chainer
  2. from chainer import optimizers
  3. import nutszebra_basic_print
  4. class Optimizer(object):
  5. def __init__(self, model=None):
  6. self.model = model
  7. self.optimizer = None
  8. def __call__(self, i):
  9. pass
  10. def update(self):
  11. self.optimizer.update()
  12. class OptimizerWideResBinaryTree(Optimizer):
  13. def __init__(self, model=None, schedule=(60, 120, 160), lr=0.1, momentum=0.9, weight_decay=5.0e-4):
  14. super(OptimizerWideResBinaryTree, self).__init__(model)
  15. optimizer = optimizers.MomentumSGD(lr, momentum)
  16. weight_decay = chainer.optimizer.WeightDecay(weight_decay)
  17. optimizer.setup(self.model)
  18. optimizer.add_hook(weight_decay)
  19. self.optimizer = optimizer
  20. self.schedule = schedule
  21. self.lr = lr
  22. self.momentum = momentum
  23. self.weight_decay = weight_decay
  24. def __call__(self, i):
  25. if i in self.schedule:
  26. lr = self.optimizer.lr * 0.2
  27. print('lr is changed: {} -> {}'.format(self.optimizer.lr, lr))
  28. self.optimizer.lr = lr