test_utils.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pytest
  4. from ds4ml.utils import (plot_histogram,
  5. plot_confusion_matrix,
  6. plot_heatmap, write_csv, mutual_information,
  7. normalize_range, is_datetime, str_to_list,
  8. normalize_distribution, has_header,
  9. read_data_from_csv, ends_with_json)
  10. def test_plot_confusion_matrix_output_string():
  11. from pandas import DataFrame
  12. df = DataFrame({'True': [2, 3], 'False': [5, 0]})
  13. res = plot_confusion_matrix(df)
  14. assert type(res) == str
  15. assert res.startswith('<svg')
  16. assert res.endswith('</svg>')
  17. @pytest.mark.skip(reason='Need manually test to check figures.')
  18. # Please remove the annotation when manually test
  19. def test_plot_figures_output_show_special_characters():
  20. bins = np.array(['你好', 'Self-る', '¥¶ĎǨД'])
  21. counts = np.array([[6, 2, 1], [6, 2, 1]])
  22. plot_histogram(bins, counts, otype='show')
  23. @pytest.mark.skip(reason='Need manually test to check figures.')
  24. def test_plot_figures_output_show():
  25. from pandas import DataFrame
  26. plot_confusion_matrix(DataFrame({'True': [2, 3],
  27. 'False': [5, 0]}),
  28. otype='show')
  29. plot_confusion_matrix(DataFrame({'7th-8th': [2, 3, 5, 0],
  30. 'Masters': [0, 4, 1, 0],
  31. '11th': [0, 1, 5, 2],
  32. 'Bachelors': [2, 0, 0, 6]}),
  33. otype='show')
  34. bins = np.array([28., 29.25, 30.5, 31.75, 33., 34.25, 35.5, 36.75, 38.,
  35. 39.25, 40.5, 41.75, 43., 44.25, 45.5, 46.75, 48., 49.25,
  36. 50.5])
  37. counts = np.array(
  38. [[1, 0, 1, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
  39. [1, 0, 1, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
  40. plot_histogram(bins, counts, otype='show')
  41. bins = np.array(['Private', 'Self-emp-not-inc', 'State-gov'])
  42. counts = np.array([[6, 2, 1], [6, 2, 1]])
  43. plot_histogram(bins, counts, otype='show')
  44. bins = np.array(['11th', '9th', 'Bachelors', 'HS-grad', 'Masters'])
  45. counts = np.array([[3, 2, 2, 1, 1], [3, 2, 2, 1, 1]])
  46. plot_histogram(bins, counts, otype='show')
  47. bins = np.array([5., 5.45, 5.9, 6.35, 6.8, 7.25, 7.7, 8.15, 8.6,
  48. 9.05, 9.5, 9.95, 10.4, 10.85, 11.3, 11.75, 12.2, 12.65,
  49. 13.1])
  50. counts = np.array(
  51. [[1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0],
  52. [1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0]])
  53. plot_histogram(bins, counts, otype='show')
  54. bins = np.array(['Female', 'Male'])
  55. counts = np.array([[5, 4], [5, 4]])
  56. plot_histogram(bins, counts, otype='show')
  57. from .testdata import adults01
  58. from ds4ml.metrics import pairwise_mutual_information
  59. data = pairwise_mutual_information(DataFrame(adults01))
  60. plot_heatmap(data, otype='show')
  61. @pytest.mark.skip(reason='TODO')
  62. def test_mutual_information():
  63. from pandas import DataFrame
  64. from .testdata import adults01
  65. frame = DataFrame(adults01)
  66. print(mutual_information(frame['age'], frame.drop('age', axis=1)))
  67. def test_write_csv():
  68. data = [['epsilon', 'c00', 'precision'], [0.2, 157, 0.4]]
  69. import os
  70. name = '__test.csv'
  71. if os.path.exists(name) and os.path.isfile(name):
  72. os.remove(name)
  73. write_csv(name, data)
  74. assert os.path.exists(name)
  75. assert os.path.isfile(name)
  76. with open(name, 'r') as file:
  77. assert file.readline().strip() == 'epsilon,c00,precision'
  78. assert file.readline().strip() == '0.2,157,0.4'
  79. file.close()
  80. os.remove(name)
  81. def test_normalize_range_ints():
  82. from numpy.random import randint
  83. for i in range(50):
  84. start = randint(0, 5)
  85. stop = randint(start + 1, 200)
  86. bins = randint(8, 30)
  87. ints = normalize_range(start, stop, bins)
  88. assert len(ints) <= bins + 1
  89. def test_normalize_range_floats():
  90. from numpy.random import randint, rand
  91. for i in range(50):
  92. start = round(randint(0, 5) * rand(), 4)
  93. stop = round(randint(0, 200) * rand(), 4) + 5
  94. bins = randint(8, 30)
  95. floats = normalize_range(start, stop, bins)
  96. assert len(floats) <= bins + 1
  97. def test_is_datetime():
  98. date = 'monday'
  99. idt = is_datetime(date)
  100. assert idt is False
  101. time = '2020-03-01'
  102. idt = is_datetime(time)
  103. assert idt is True
  104. value = 'high school'
  105. idt = is_datetime(value)
  106. assert idt is False
  107. def test_str_to_list():
  108. iva = '1,3,4,5'
  109. res = str_to_list(iva)
  110. assert res == ['1', '3', '4', '5']
  111. iva = 'name,age,weight,height'
  112. res = str_to_list(iva)
  113. assert res == ['name', 'age', 'weight', 'height']
  114. def test_normalize_distribution():
  115. frequencies = [3, 3, 2]
  116. res = normalize_distribution(frequencies)
  117. assert res[0] == 0.375
  118. assert res[1] == 0.375
  119. assert res[2] == 0.25
  120. def test_has_header():
  121. from .testdata import adult_with_head, adult_without_head
  122. import io
  123. hasheader = has_header(io.StringIO(adult_with_head))
  124. assert hasheader is True
  125. hasheader = has_header(io.StringIO(adult_without_head))
  126. assert hasheader is False
  127. def test_read_data_from_csv():
  128. from pandas import DataFrame
  129. from .testdata import adult_with_head, adult_with_head_res
  130. import io
  131. data = read_data_from_csv(io.StringIO(adult_with_head))
  132. assert data.equals(DataFrame(adult_with_head_res)) is True
  133. def test_ends_with_json():
  134. assert ends_with_json("d.json") is True
  135. assert ends_with_json("a.json") is True
  136. assert ends_with_json("data\ A.jSon") is True
  137. assert ends_with_json("data A.jSon") is True