test_against.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import unittest
  2. import numpy as np
  3. from swiftt.taylor import factory_taylor
  4. from swiftt.taylor.real_multivar_taylor import RealMultivarTaylor
  5. from swiftt.math_algebra import cos, sin, exp, sqrt
  6. tol_coeff = 1.e-12
  7. null_expansion_2var_order2 = factory_taylor.zero_expansion(2, 2)
  8. null_expansion_2var_order3 = factory_taylor.zero_expansion(2, 3)
  9. null_expansion_3var_order2 = factory_taylor.zero_expansion(3, 2)
  10. null_expansion_4var_order5 = factory_taylor.zero_expansion(4, 5)
  11. def intermediate(order: int) -> RealMultivarTaylor:
  12. x, y, z = factory_taylor.create_unknown_map(order=order, consts=[1., 2., -1.], var_names=["x", "y", "z"])
  13. g = exp((sin(x) * cos(y) + 1.) / sqrt(1. + x**2 + y**2 + z**2))
  14. g = g.deriv_once_wrt_var(1).integ_once_wrt_var(0).integ_once_wrt_var(2)
  15. return g
  16. class TestAgainst(unittest.TestCase):
  17. def test_non_regression(self):
  18. g = intermediate(order=4)
  19. regre_coeff = [0., 0., 0., 0., 0., 0.,
  20. -0.45942301, 0., 0., 0., 0., 0.,
  21. -0.02996078, 0., 0.57760928, -0.05369175, 0., 0.,
  22. 0., 0., 0., 0., 0.06033602, 0.,
  23. 0.04807805, 0.01270756, 0., -0.19353301, 0.11813837, 0.0072345,
  24. 0., 0., 0., 0., 0.]
  25. if not np.allclose(regre_coeff, g.coeff):
  26. self.fail()
  27. def test_pyaudi(self):
  28. try:
  29. order = 4
  30. g1 = intermediate(order)
  31. from pyaudi.core import gdual_double
  32. x = gdual_double(1., "x", order)
  33. y = gdual_double(2., "y", order)
  34. z = gdual_double(-1., "z", order)
  35. g2 = ((x.sin() * y.cos() + 1.) / ((x**2 + y**2 + z**2 + 1.).sqrt())).exp()
  36. g2 = g2.partial("y").integrate("x").integrate("z")
  37. for exponent in g1.get_mapping_monom().keys():
  38. dict_deriv = {"dx": exponent[0], "dy": exponent[1], "dz": exponent[2]}
  39. self.assertAlmostEqual(g1.get_partial_deriv(exponent), g2.get_derivative(dict_deriv), delta=1.e-15)
  40. except ImportError:
  41. pass
  42. def test_sympy(self):
  43. try:
  44. from sympy import poly
  45. str1 = "x**4 - 2 * x**3 + x**2 - x + 2 + x * y + 4 * y**2 + x * y * z +" \
  46. " z**3 + x * z**3 - 3 * x**2 * y**2 - 5 * y**4"
  47. str2 = "3 * x**4 - x**2 - z * x + 1 + x**2 * y + 4 * y**4 + x * y **2 * z +" \
  48. " y * z**3 - x * z**3 - 4 * x**2 * y * z + 3 * y**3"
  49. poly1, poly2 = poly(str1), poly(str2)
  50. poly_prod = poly1 * poly2
  51. expans1 = factory_taylor.from_string(str1, order=8)
  52. expans2 = factory_taylor.from_string(str2, order=8)
  53. expans_prod = expans1 * expans2
  54. coeff = expans_prod.coeff
  55. mapping = expans1.get_mapping_monom()
  56. for el, monom in zip(poly_prod.coeffs(), poly_prod.monoms()):
  57. if coeff[mapping[monom]] != el:
  58. self.fail()
  59. except ImportError:
  60. pass
  61. if __name__ == '__main__':
  62. unittest.main()