# augmentation_main.py (7.8 KB)

  1. from __future__ import print_function, unicode_literals
  2. import os
  3. from twisted.python import filepath
  4. from twisted.trial import unittest
  5. from .. import database
  6. from ..database import (CHANNELDB_TARGET_VERSION, USAGEDB_TARGET_VERSION,
  7. _get_db, dump_db, DBError)
  8. class Get(unittest.TestCase):
  9. def test_create_default(self):
  10. db_url = ":memory:"
  11. db = _get_db(db_url, "channel", CHANNELDB_TARGET_VERSION)
  12. rows = db.execute("SELECT * FROM version").fetchall()
  13. self.assertEqual(len(rows), 1)
  14. self.assertEqual(rows[0]["version"], CHANNELDB_TARGET_VERSION)
  15. def test_open_existing_file(self):
  16. basedir = self.mktemp()
  17. os.mkdir(basedir)
  18. fn = os.path.join(basedir, "normal.db")
  19. db = _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
  20. rows = db.execute("SELECT * FROM version").fetchall()
  21. self.assertEqual(len(rows), 1)
  22. self.assertEqual(rows[0]["version"], CHANNELDB_TARGET_VERSION)
  23. db2 = _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
  24. rows = db2.execute("SELECT * FROM version").fetchall()
  25. self.assertEqual(len(rows), 1)
  26. self.assertEqual(rows[0]["version"], CHANNELDB_TARGET_VERSION)
  27. def test_open_bad_version(self):
  28. basedir = self.mktemp()
  29. os.mkdir(basedir)
  30. fn = os.path.join(basedir, "old.db")
  31. db = _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
  32. db.execute("UPDATE version SET version=999")
  33. db.commit()
  34. with self.assertRaises(DBError) as e:
  35. _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
  36. self.assertIn("Unable to handle db version 999", str(e.exception))
  37. def test_open_corrupt(self):
  38. basedir = self.mktemp()
  39. os.mkdir(basedir)
  40. fn = os.path.join(basedir, "corrupt.db")
  41. with open(fn, "wb") as f:
  42. f.write(b"I am not a database")
  43. with self.assertRaises(DBError) as e:
  44. _get_db(fn, "channel", CHANNELDB_TARGET_VERSION)
  45. self.assertIn("not a database", str(e.exception))
  46. def test_failed_create_allows_subsequent_create(self):
  47. patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
  48. dbfile = filepath.FilePath(self.mktemp())
  49. self.assertRaises(Exception, lambda: _get_db(dbfile.path))
  50. patch.restore()
  51. _get_db(dbfile.path, "channel", CHANNELDB_TARGET_VERSION)
  52. def test_upgrade(self):
  53. basedir = self.mktemp()
  54. os.mkdir(basedir)
  55. fn = os.path.join(basedir, "upgrade.db")
  56. self.assertNotEqual(USAGEDB_TARGET_VERSION, 1)
  57. # create an old-version DB in a file
  58. db = _get_db(fn, "usage", 1)
  59. rows = db.execute("SELECT * FROM version").fetchall()
  60. self.assertEqual(len(rows), 1)
  61. self.assertEqual(rows[0]["version"], 1)
  62. del db
  63. # then upgrade the file to the latest version
  64. dbA = _get_db(fn, "usage", USAGEDB_TARGET_VERSION)
  65. rows = dbA.execute("SELECT * FROM version").fetchall()
  66. self.assertEqual(len(rows), 1)
  67. self.assertEqual(rows[0]["version"], USAGEDB_TARGET_VERSION)
  68. dbA_text = dump_db(dbA)
  69. del dbA
  70. # make sure the upgrades got committed to disk
  71. dbB = _get_db(fn, "usage", USAGEDB_TARGET_VERSION)
  72. dbB_text = dump_db(dbB)
  73. del dbB
  74. self.assertEqual(dbA_text, dbB_text)
  75. # The upgraded schema should be equivalent to that of a new DB.
  76. latest_db = _get_db(":memory:", "usage", USAGEDB_TARGET_VERSION)
  77. latest_text = dump_db(latest_db)
  78. with open("up.sql","w") as f: f.write(dbA_text)
  79. with open("new.sql","w") as f: f.write(latest_text)
  80. # debug with "diff -u _trial_temp/up.sql _trial_temp/new.sql"
  81. self.assertEqual(dbA_text, latest_text)
  82. def test_upgrade_fails(self):
  83. basedir = self.mktemp()
  84. os.mkdir(basedir)
  85. fn = os.path.join(basedir, "upgrade.db")
  86. self.assertNotEqual(USAGEDB_TARGET_VERSION, 1)
  87. # create an old-version DB in a file
  88. db = _get_db(fn, "usage", 1)
  89. rows = db.execute("SELECT * FROM version").fetchall()
  90. self.assertEqual(len(rows), 1)
  91. self.assertEqual(rows[0]["version"], 1)
  92. del db
  93. # then upgrade the file to a too-new version, for which we have no
  94. # upgrader
  95. with self.assertRaises(DBError):
  96. _get_db(fn, "usage", USAGEDB_TARGET_VERSION+1)
  97. class CreateChannel(unittest.TestCase):
  98. def test_memory(self):
  99. db = database.create_channel_db(":memory:")
  100. latest_text = dump_db(db)
  101. self.assertIn("CREATE TABLE", latest_text)
  102. def test_preexisting(self):
  103. basedir = self.mktemp()
  104. os.mkdir(basedir)
  105. fn = os.path.join(basedir, "preexisting.db")
  106. with open(fn, "w"):
  107. pass
  108. with self.assertRaises(database.DBAlreadyExists):
  109. database.create_channel_db(fn)
  110. def test_create(self):
  111. basedir = self.mktemp()
  112. os.mkdir(basedir)
  113. fn = os.path.join(basedir, "created.db")
  114. db = database.create_channel_db(fn)
  115. latest_text = dump_db(db)
  116. self.assertIn("CREATE TABLE", latest_text)
  117. def test_create_or_upgrade(self):
  118. basedir = self.mktemp()
  119. os.mkdir(basedir)
  120. fn = os.path.join(basedir, "created.db")
  121. db = database.create_or_upgrade_channel_db(fn)
  122. latest_text = dump_db(db)
  123. self.assertIn("CREATE TABLE", latest_text)
  124. class CreateUsage(unittest.TestCase):
  125. def test_memory(self):
  126. db = database.create_usage_db(":memory:")
  127. latest_text = dump_db(db)
  128. self.assertIn("CREATE TABLE", latest_text)
  129. def test_preexisting(self):
  130. basedir = self.mktemp()
  131. os.mkdir(basedir)
  132. fn = os.path.join(basedir, "preexisting.db")
  133. with open(fn, "w"):
  134. pass
  135. with self.assertRaises(database.DBAlreadyExists):
  136. database.create_usage_db(fn)
  137. def test_create(self):
  138. basedir = self.mktemp()
  139. os.mkdir(basedir)
  140. fn = os.path.join(basedir, "created.db")
  141. db = database.create_usage_db(fn)
  142. latest_text = dump_db(db)
  143. self.assertIn("CREATE TABLE", latest_text)
  144. def test_create_or_upgrade(self):
  145. basedir = self.mktemp()
  146. os.mkdir(basedir)
  147. fn = os.path.join(basedir, "created.db")
  148. db = database.create_or_upgrade_usage_db(fn)
  149. latest_text = dump_db(db)
  150. self.assertIn("CREATE TABLE", latest_text)
  151. def test_create_or_upgrade_disabled(self):
  152. db = database.create_or_upgrade_usage_db(None)
  153. self.assertIs(db, None)
  154. class OpenChannel(unittest.TestCase):
  155. def test_open(self):
  156. basedir = self.mktemp()
  157. os.mkdir(basedir)
  158. fn = os.path.join(basedir, "created.db")
  159. db1 = database.create_channel_db(fn)
  160. latest_text = dump_db(db1)
  161. self.assertIn("CREATE TABLE", latest_text)
  162. db2 = database.open_existing_db(fn)
  163. self.assertIn("CREATE TABLE", dump_db(db2))
  164. def test_doesnt_exist(self):
  165. basedir = self.mktemp()
  166. os.mkdir(basedir)
  167. fn = os.path.join(basedir, "created.db")
  168. with self.assertRaises(database.DBDoesntExist):
  169. database.open_existing_db(fn)
  170. class OpenUsage(unittest.TestCase):
  171. def test_open(self):
  172. basedir = self.mktemp()
  173. os.mkdir(basedir)
  174. fn = os.path.join(basedir, "created.db")
  175. db1 = database.create_usage_db(fn)
  176. latest_text = dump_db(db1)
  177. self.assertIn("CREATE TABLE", latest_text)
  178. db2 = database.open_existing_db(fn)
  179. self.assertIn("CREATE TABLE", dump_db(db2))
  180. def test_doesnt_exist(self):
  181. basedir = self.mktemp()
  182. os.mkdir(basedir)
  183. fn = os.path.join(basedir, "created.db")
  184. with self.assertRaises(database.DBDoesntExist):
  185. database.open_existing_db(fn)