# testNpState.py
# -*- coding: utf-8 -*-
"""
Created on Thu May 28 17:57:07 2020
@author: Ian Lim
"""
  6. #import sys
  7. from statefuncs import State, npState
  8. import numpy as np
  9. import unittest
  10. class TestStates(unittest.TestCase):
  11. def test_state_setup(self):
  12. occs = [0,1,0,1,0]
  13. nmax = 2
  14. #if L or m is not provided, raise an error
  15. with self.assertRaises(TypeError):
  16. State(occs,nmax,L=None,m=None)
  17. with self.assertRaises(TypeError):
  18. State(occs,nmax,L=1.,m=None)
  19. with self.assertRaises(TypeError):
  20. State(occs,nmax,L=None,m=1.)
  21. #should raise ValueError: state not at rest when nmax is shifted
  22. with self.assertRaises(ValueError):
  23. State(occs,1,m=1.,L=1.)
  24. def test_npstate_setup(self):
  25. occs = [0,1,0,1,0]
  26. nmax = 2
  27. with self.assertRaises(TypeError):
  28. npState(occs,nmax)
  29. with self.assertRaises(TypeError):
  30. npState(occs,nmax,L=1.)
  31. with self.assertRaises(TypeError):
  32. npState(occs,nmax,m=1.)
  33. #should raise ValueError: state not at rest when nmax is shifted
  34. with self.assertRaises(ValueError):
  35. npState(occs,1,m=1.,L=1.)
  36. def test_state_equality(self):
  37. occs = [2,1,0,1,2]
  38. nmax = 2
  39. occsList = [[2,1,0,1,2],[0,0,0,0,0],[0,5,0,5,0]]
  40. myState = State(occs,nmax, m=2,L=1.)
  41. for occs in occsList:
  42. for mass in np.arange(5):
  43. myState = State(occs, nmax, m=mass, L=1.)
  44. mynpState = npState(occs, nmax, m=mass, L=1.)
  45. self.assertEqual(myState.energy, mynpState.energy)
  46. self.assertEqual(myState.isParityEigenstate, mynpState.isParityEigenstate)
  47. self.assertEqual(myState.momentum,mynpState.momentum)
  48. self.assertEqual(myState.totalWN,mynpState.totalWN)
  49. if __name__ == "__main__":
  50. unittest.main()