test_attribute.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. from math import isclose
  2. from numpy import random, array_equal
  3. from pandas import Series
  4. from ds4ml.attribute import Attribute
  5. from ds4ml.utils import randomize_string
  # Number of random samples each test generates; it is passed as the sample
  # count to random.randint / random.uniform and to range() in the tests.
  6. size = 30
  def test_integer_attribute():
      """Check type, name, range, bins and probabilities of an int attribute."""
      samples = Series(random.randint(1, 100, size))
      attribute = Attribute(samples, name='ID', categorical=False)

      assert attribute.type == 'integer'
      assert attribute.name == 'ID'
      # The reported range must stay inside the interval the samples came from.
      assert attribute.min_ >= 1
      assert attribute.max_ <= 100
      assert len(attribute.bins) == 20
      # Bin probabilities are expected to sum to one (within float tolerance).
      assert isclose(sum(attribute.prs), 1.0)

      # A column taken from the bundled test data is also typed as integer.
      from .testdata import adults01
      age = Attribute(adults01['age'])
      assert age.type == 'integer'
  def test_float_attribute():
      """Check type, range, bins and probabilities of a float attribute."""
      samples = Series(random.uniform(1, 100, size), name='Float')
      attribute = Attribute(samples)

      assert attribute.type == 'float'
      # The reported range must stay inside the interval the samples came from.
      assert attribute.min_ >= 1
      assert attribute.max_ <= 100
      assert len(attribute.bins) == 20
      assert isclose(sum(attribute.prs), 1.0)
  def test_string_attribute():
      """Check type, ``min_`` and the categorical flag of a string attribute."""
      words = [randomize_string(5) for _ in range(size)]
      attribute = Attribute(Series(words, name='String'), categorical=True)

      assert attribute.type == 'string'
      # Strings were requested with randomize_string(5); min_ is expected to
      # be 5 for them.
      assert attribute.min_ == 5
      assert attribute.categorical
  def test_set_domain_for_integer_attribute():
      """Assigning ``domain`` on an integer attribute resets ``min_``/``max_``."""
      samples = Series(random.randint(1, 100, size), name='Integer')
      attribute = Attribute(samples)

      # Before the override the range follows the sampled data.
      assert attribute.min_ >= 1
      assert attribute.max_ <= 100

      attribute.domain = [-2, 120]

      # After the override the range equals the assigned domain bounds.
      assert attribute.min_ == -2
      assert attribute.max_ == 120
  def test_set_domain_for_integer_categorical_attribute():
      """Assigning ``domain`` on a categorical int attribute extends its bins."""
      samples = Series(random.randint(1, 100, size), name='Integer')
      attribute = Attribute(samples, categorical=True)

      # Before the override the first/last bins lie within the sampled range.
      assert attribute.bins[0] >= 1
      assert attribute.bins[-1] <= 100

      attribute.domain = [-2, 120]

      # After the override the outermost bins are the assigned domain values.
      assert attribute.bins[0] == -2
      assert attribute.bins[-1] == 120
  def test_set_domain_for_float_attribute():
      """Assigning ``domain`` on a float attribute resets ``min_``/``max_``."""
      samples = Series(random.uniform(1, 100, size), name='Float')
      attribute = Attribute(samples)

      # Before the override the range follows the sampled data.
      assert attribute.min_ >= 1
      assert attribute.max_ <= 100

      attribute.domain = [-2, 120]

      # After the override the range equals the assigned domain bounds.
      assert attribute.min_ == -2
      assert attribute.max_ == 120
  def test_set_domain_for_string_attribute():
      """Assigning a string ``domain`` is expected to add four bins."""
      words = [randomize_string(5) for _ in range(size)]
      attribute = Attribute(Series(words, name='String'), categorical=True)

      # Keep a reference to the bins as they were before the domain override.
      previous_bins = attribute.bins
      attribute.domain = ['a', 'b', 'China', 'USA']

      assert len(previous_bins) + 4 == len(attribute.bins)
  def test_set_domain_for_datetime_attribute():
      """Assigning a date ``domain`` is expected to add three bins."""
      date_strings = [
          '05/29/1988', '06/22/1988', '07/30/1992', '07/30/1992',
          '11/12/2000', '01/02/2001', '01/02/2001', '12/03/2001',
          '07/09/2002', '10/22/2002',
      ]
      series = Series(date_strings, name='String')
      attribute = Attribute(series, categorical=True)

      # Keep a reference to the bins as they were before the domain override.
      previous_bins = attribute.bins
      attribute.domain = ['07/01/1997', '12/20/1999', '01/01/2004']

      assert len(previous_bins) + 3 == len(attribute.bins)
  def test_counts_numerical_attribute():
      """Check ``counts`` for numerical and categorical integer attributes.

      Expected totals are derived from ``size`` and from the sample list
      instead of repeating their values as literals, so the assertions stay
      correct if the sample count or the sample data is changed.
      """
      ints = random.randint(1, 100, size)
      attr = Attribute(Series(ints, name='Integer'))

      # Un-normalized counts over the attribute's own bins cover every sample.
      counts = attr.counts(normalize=False)
      assert sum(counts) == size
      assert len(counts) == 20

      # Five explicit edges are expected to give four counts; all samples lie
      # between the outermost edges, so the total is still ``size``.
      counts = attr.counts(bins=[0, 10, 20, 30, 100], normalize=False)
      assert sum(counts) == size
      assert len(counts) == 4

      # categorical ints
      samples = [1, 10, 11, 10, 20, 15, 16, 25]
      attr = Attribute(Series(samples, name='Integer'), categorical=True)
      counts = attr.counts(normalize=False)
      assert sum(counts) == len(samples)
      # One count per distinct value is expected.
      assert len(counts) == len(set(samples))

      # Only the listed categories are counted: 10 occurs twice, 15 once and
      # 5 never, hence a total of 3 across 3 entries.
      counts = attr.counts(bins=[5, 10, 15], normalize=False)
      assert sum(counts) == 3
      assert len(counts) == 3
  def test_decimals_float_attribute():
      """Floats rounded to two places are expected to report two decimals."""
      # Materialise the rounded values as a list: a bare ``map`` object is a
      # lazy one-shot iterator, and handing it to ``Series`` relies on how the
      # installed pandas version treats iterators.
      floats = [round(v, 2) for v in random.uniform(1, 10, size)]
      attr = Attribute(Series(floats, name='Float'))
      assert attr.decimals() == 2
  def test_counts_datetimes():
      """Check ``counts`` of a categorical attribute built from date strings."""
      date_strings = [
          '05/29/1988', '06/22/1988', '07/30/1992', '07/30/1992',
          '11/12/2000', '01/02/2001', '01/02/2001', '12/03/2001',
          '07/09/2002', '10/22/2002',
      ]
      series = Series(date_strings, name='DateTime')
      attribute = Attribute(series, categorical=True)

      # Two dates occur twice, so ten samples give eight counts.
      expected = [1, 1, 2, 1, 2, 1, 1, 1]
      totals = attribute.counts(normalize=False)
      assert sum(totals) == len(date_strings)
      assert array_equal(totals, expected)

      # Restricting to two dates that each occur once yields [1, 1].
      selected = ['12/03/2001', '10/22/2002']
      totals = attribute.counts(bins=selected, normalize=False)
      assert array_equal(totals, [1, 1])
  def test_counts_categorical_attribute():
      """Counts of a categorical integer attribute sum to the sample count.

      The expected total is ``size`` instead of a repeated literal, so the
      assertion stays correct if the module-level sample count changes.
      """
      ints = random.randint(1, 10, size)
      attr = Attribute(Series(ints, name='Integer'), categorical=True)
      assert sum(attr.counts()) == size
  def test_choice_integers():
      """``choice`` on an integer attribute returns ``size`` values."""
      samples = Series(random.randint(1, 100, size), name='Integer')
      attribute = Attribute(samples)
      assert len(attribute.bins) == 20

      drawn = attribute.choice()
      assert len(drawn) == size
  def test_choice_floats():
      """``choice`` on a float attribute returns ``size`` values."""
      samples = Series(random.uniform(1, 10, size), name='Float')
      attribute = Attribute(samples)
      assert len(attribute.bins) == 20

      drawn = attribute.choice()
      assert len(drawn) == size
  def test_choice_strings():
      """``choice`` on a string attribute returns ``size`` values."""
      words = [randomize_string(5) for _ in range(size)]
      attribute = Attribute(Series(words, name='String'))

      drawn = attribute.choice()
      assert len(drawn) == size
  def test_choice_datetimes():
      """``choice`` on a date-string attribute returns one value per input."""
      date_strings = [
          '05/29/1988', '06/22/1988', '07/30/1992', '01/02/2001',
          '11/12/2000', '07/09/2002', '08/30/1998', '06/03/1997',
          '10/22/2002', '12/03/2001',
      ]
      attribute = Attribute(Series(date_strings, name='DateTime'))

      drawn = attribute.choice()
      assert len(drawn) == len(date_strings)
  def test_bin_indexes_ints():
      """``bin_indexes`` yields one entry per integer sample."""
      values = [3, 5, 7, 8, 7, 1, 10, 30, 16, 19]
      attribute = Attribute(Series(values), name='ID', categorical=False)

      positions = attribute.bin_indexes()
      assert len(positions) == len(values)
  def test_bin_indexes_datetimes():
      """``bin_indexes`` yields one entry per date-string sample."""
      date_strings = [
          '05/29/1988', '06/22/1988', '07/30/1992', '07/30/1992',
          '11/12/2000', '01/02/2001', '01/02/2001', '12/03/2001',
          '07/09/2002', '10/22/2002',
      ]
      attribute = Attribute(Series(date_strings, name='DateTime'))

      positions = attribute.bin_indexes()
      assert len(positions) == len(date_strings)
  def test_pseudonymize_strings():
      """Pseudonymized strings keep the value-frequency profile of the input."""
      strings = Series(['Abc', 'edf', 'Abc', 'take', '中国', 'edf', 'Abc'])
      attribute = Attribute(strings, name='String')
      masked = attribute.pseudonymize()

      # Only the frequency counts are compared, not the values themselves.
      original_counts = strings.value_counts().values
      masked_counts = masked.value_counts().values
      assert array_equal(original_counts, masked_counts)
  def test_pseudonymize_ints():
      """Pseudonymized ints keep the value-frequency profile of the input."""
      ints = Series([11, 2, 3, 4, 5, 4, 3, 2, 3, 4, 11])
      attribute = Attribute(ints, name='Integer')
      masked = attribute.pseudonymize()

      # Only the frequency counts are compared, not the values themselves.
      original_counts = ints.value_counts().values
      masked_counts = masked.value_counts().values
      assert array_equal(original_counts, masked_counts)
  def test_pseudonymize_floats():
      """Pseudonymized floats keep the value-frequency profile of the input."""
      floats = Series(
          [11.5, 2.6, 3.0, 4.3, 5, 4.3, 3.0, 2.6, 3.0, 4.3, 11.6])
      attribute = Attribute(floats, name='Float')
      masked = attribute.pseudonymize()

      # Only the frequency counts are compared, not the values themselves.
      original_counts = floats.value_counts().values
      masked_counts = masked.value_counts().values
      assert array_equal(original_counts, masked_counts)
  def test_pseudonym_dates():
      """Pseudonymized dates keep the value-frequency profile of the input."""
      dates = Series([
          '07/15/2019', '07/24/2019', '07/23/2019', '07/22/2019',
          '07/21/2019', '07/22/2019', '07/23/2019', '07/24/2019',
          '07/23/2019', '07/22/2019', '07/15/2019',
      ])
      attribute = Attribute(dates, name='Date')
      masked = attribute.pseudonymize()

      # Only the frequency counts are compared, not the values themselves.
      original_counts = dates.value_counts().values
      masked_counts = masked.value_counts().values
      assert array_equal(original_counts, masked_counts)
  def test_random_ints():
      """``random`` on an integer attribute returns one value per input."""
      values = [3, 5, 7, 8, 7, 1, 10, 30, 16, 19]
      # The plain list is passed directly, without wrapping it in a Series.
      attribute = Attribute(values, name='Integer')

      generated = attribute.random()
      assert len(generated) == len(values)
  def test_random_datetimes():
      """``random`` on a date-string attribute returns one value per input."""
      date_strings = [
          '07/15/2019', '07/24/2019', '07/23/2019', '07/22/2019',
          '07/21/2019', '07/22/2019', '07/23/2019', '07/24/2019',
          '07/23/2019', '07/22/2019', '07/15/2019',
      ]
      # The plain list is passed directly, without wrapping it in a Series.
      attribute = Attribute(date_strings, name='Date')

      generated = attribute.random()
      assert len(generated) == len(date_strings)
  def test_random_strings():
      """``random`` on a string attribute returns ``size`` values."""
      words = [randomize_string(5) for _ in range(size)]
      attribute = Attribute(Series(words, name='String'))

      generated = attribute.random()
      assert len(generated) == size
  def test_retain_ints():
      """``retain`` keeps the original values, also when a larger size is asked."""
      values = [3, 5, 7, 8, 7, 1, 10, 30, 16, 19]
      attribute = Attribute(values, name='Integer')

      kept = attribute.retain()
      assert len(kept) == len(values)

      # With a larger requested size the leading entries must still equal the
      # original values, in order.
      kept = attribute.retain(size=15)
      leading = kept.head(len(values)).tolist()
      assert array_equal(leading, values)
  def test_encode_numerical_attributes():
      """``encode`` returns one row per record for a numerical attribute."""
      from .testdata import adults01
      ages = adults01['age']
      attribute = Attribute(ages)

      # The bins are expected to span at least the range 19..56.
      assert attribute.bins[0] <= 19
      assert attribute.bins[-1] >= 56
      assert len(attribute.encode()) == len(attribute)

      # Encoding an explicitly supplied subset yields one row per subset entry.
      from sklearn.model_selection import train_test_split
      train_part, _test_part = train_test_split(adults01['age'])
      assert len(attribute.encode(data=train_part)) == len(train_part)
  def test_encode_categorical_attributes():
      """``encode`` of a categorical attribute has one column per category."""
      from pandas import DataFrame
      from .testdata import adults01
      table = DataFrame(adults01)
      attribute = Attribute(table['education'], categorical=True)

      expected = [
          '11th', '7th-8th', '9th', 'Assoc-acdm', 'Bachelors', 'Doctorate',
          'HS-grad', 'Masters', 'Some-college',
      ]
      # Both the bins and the encoded column labels must match, in this order.
      assert array_equal(attribute.bins, expected)
      assert array_equal(attribute.encode().columns, expected)
  def test_encode_datetime_attributes():
      """``encode`` returns one row per record for the ``birth`` column."""
      from pandas import DataFrame
      from .testdata import adults01
      table = DataFrame(adults01)
      attribute = Attribute(table['birth'])

      # assert other information
      encoded = attribute.encode()
      assert len(encoded) == len(attribute)