interval.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # interval.py: class implementing interval arithmetic
  2. # Copyright 2022 Romain Serra
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from typing import Union, Optional, Callable
  16. import numpy as np
  17. from swiftt.algebraic_abstract import AlgebraicAbstract
# Operand type accepted by several Interval methods (addition, subtraction,
# multiplication, equality, containment): a plain scalar or another Interval.
# "Interval" is a string forward reference because the class is defined below.
IntervalOrScalar = Union[float, "Interval"]
  19. class Interval(AlgebraicAbstract):
  20. """Class representing intervals of real numbers.
  21. Attributes:
  22. _lb (float): lower bound.
  23. _ub (float): upper bound.
  24. """
  25. def __init__(self, lb: float, ub: float) -> None:
  26. self._lb = lb # no call to property on purpose as no sanity check is needed
  27. self.ub = ub
  28. def __len__(self) -> float:
  29. return self._ub - self._lb
  30. @property
  31. def ub(self) -> float:
  32. return self._ub
  33. @ub.setter
  34. def ub(self, ub: float) -> None:
  35. if ub < self.lb:
  36. raise ValueError("The upper bound cannot be strictly less than the lower one.")
  37. self._ub = ub
  38. @property
  39. def lb(self) -> float:
  40. return self._lb
  41. @lb.setter
  42. def lb(self, lb: float) -> None:
  43. if lb > self.ub:
  44. raise ValueError("The lower bound cannot be strictly greater than the upper one.")
  45. self._lb = lb
  46. def copy(self) -> "Interval":
  47. return Interval(self._lb, self._ub)
  48. @staticmethod
  49. def singleton(point: float) -> "Interval":
  50. return Interval(point, point)
  51. def __add__(self, other: IntervalOrScalar) -> "Interval":
  52. if isinstance(other, Interval):
  53. return Interval(self._lb + other.lb, self._ub + other.ub)
  54. # scalar case
  55. return Interval(self._lb + other, self._ub + other)
  56. def contains(self, other: IntervalOrScalar) -> bool:
  57. if isinstance(other, Interval):
  58. return other.lb >= self._lb and other.ub <= self._ub
  59. # scalar case
  60. return self._ub >= other >= self._lb
  61. def contains_zero(self) -> bool:
  62. return self.contains(0.)
  63. def __neg__(self) -> "Interval":
  64. return Interval(-self._ub, -self._lb)
  65. def __sub__(self, other: IntervalOrScalar) -> "Interval":
  66. if isinstance(other, Interval):
  67. return self + other.__neg__()
  68. # scalar case
  69. return Interval(self._lb - other, self._ub - other)
  70. def __mul__(self, other: IntervalOrScalar) -> "Interval":
  71. if isinstance(other, Interval):
  72. candidates = np.array([self._lb * other.lb, self._ub * other.lb, self._ub * other.ub,
  73. self._lb * other.ub])
  74. return Interval(np.min(candidates), np.max(candidates))
  75. # scalar case
  76. return self * Interval.singleton(other)
  77. def reciprocal(self) -> "Interval":
  78. if self.contains_zero():
  79. return Interval(-np.inf, np.inf)
  80. if self._ub != 0. and self._lb != 0.:
  81. inter = np.sort(1. / np.array([self._lb, self._ub]))
  82. return Interval(inter[0], inter[1])
  83. if self._ub != 0.:
  84. return Interval(1. / self._ub, np.inf)
  85. return Interval(-np.inf, 1. / self._lb)
  86. def __pow__(self, power: Union[int, float], modulo: Optional[float] = None) -> "Interval":
  87. if isinstance(power, int):
  88. if int(power / 2) == power / 2.:
  89. if self._lb >= 0:
  90. return Interval(self._lb**power, self._ub**power)
  91. if self._ub < 0.:
  92. return Interval(self._ub**power, self._lb**power)
  93. return Interval(0., max(self._lb**power, self._ub**power))
  94. return Interval(self._lb**power, self._ub**power)
  95. raise NotImplementedError
  96. def __str__(self) -> str:
  97. return "[" + str(self._lb) + ", " + str(self._ub) + "]"
  98. def __eq__(self, other: IntervalOrScalar) -> bool:
  99. if isinstance(other, Interval):
  100. return self._ub == other.ub and self._lb == other.lb
  101. # scalar case
  102. return self == self.singleton(other)
  103. def __abs__(self) -> "Interval":
  104. fabs = np.fabs([self._lb, self._ub])
  105. if self.contains_zero():
  106. return Interval(0., np.max(fabs))
  107. return Interval(np.min(fabs), np.max(fabs))
  108. def increasing_intrinsic(self, func: Callable) -> "Interval":
  109. return Interval(func(self.lb), func(self.ub))
  110. def sqrt(self) -> "Interval":
  111. return self.increasing_intrinsic(math.sqrt)
  112. def exp(self) -> "Interval":
  113. return self.increasing_intrinsic(math.exp)
  114. def log(self) -> "Interval":
  115. return self.increasing_intrinsic(math.log)