# helpers.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np

def prep_CIFAR(root, bs):
    '''
    Preps the CIFAR-10 dataset from root/datasets/, loading in all
    the classes and using a batch size of bs.
    Outputs the Data dict without saving it.
    '''
    # Load the data, with standard augmentation on the training set
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = datasets.CIFAR10(root=root + '/datasets/CIFAR/', train=True,
                                download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs,
                                               shuffle=True, num_workers=2)
    testset = datasets.CIFAR10(root=root + '/datasets/CIFAR/', train=False,
                               download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=bs,
                                              shuffle=False, num_workers=2)
    # Finally, compile the loaders into the Data structure
    Data = {}
    Data['bs'] = bs
    Data['x_test'] = []
    Data['y_test'] = []
    Data['x_train'] = []
    Data['y_train'] = []
    # Go through the loaders collecting the batches
    for _, data in enumerate(train_loader, 0):
        Data['x_train'].append(data[0])
        Data['y_train'].append(data[1])
    for _, data in enumerate(test_loader, 0):
        Data['x_test'].append(data[0])
        Data['y_test'].append(data[1])
    Data['x_og'] = Data['x_train'].copy()
    Data['y_og'] = Data['y_train'].copy()
    return Data
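
# A minimal usage sketch (hypothetical root/bs values; CIFAR-10 is
# downloaded automatically under root/datasets/CIFAR/):
#
#   Data = prep_CIFAR(root='.', bs=128)
#   x0, y0 = Data['x_train'][0], Data['y_train'][0]
#   # x0 has shape (128, 3, 32, 32), y0 has shape (128,);
#   # the final batch may be smaller than bs.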

def prep_MNIST(root, bs):
    '''
    Preps the MNIST dataset from root/datasets/, loading in all
    the classes and using a batch size of bs.
    Outputs the Data dict without saving it.
    The images are flattened.
    '''
    # Load the data
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    trainset = datasets.MNIST(root=root + '/datasets/', train=True,
                              download=True, transform=transform)
    testset = datasets.MNIST(root=root + '/datasets/', train=False,
                             download=True, transform=transform)
    # now the loaders
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=True)
    # Finally, compile the loaders into the Data structure
    Data = {}
    Data['bs'] = bs
    Data['x_test'] = []
    Data['y_test'] = []
    Data['x_train'] = []
    Data['y_train'] = []
    # Go through the loaders collecting the batches
    for _, data in enumerate(train_loader, 0):
        Data['x_train'].append(data[0].reshape(-1, 1, 1, 28 * 28))
        Data['y_train'].append(data[1])
    for _, data in enumerate(test_loader, 0):
        Data['x_test'].append(data[0].reshape(-1, 1, 1, 28 * 28))
        Data['y_test'].append(data[1])
    Data['x_og'] = Data['x_train'].copy()
    Data['y_og'] = Data['y_train'].copy()
    return Data
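
# A minimal usage sketch (hypothetical root/bs values):
#
#   Data = prep_MNIST(root='.', bs=256)
#   x0 = Data['x_train'][0]
#   # x0 has shape (256, 1, 1, 784): each 28x28 image is flattened.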

class fast_trunc(nn.Module):
    def __init__(self, in_features, out_features, k):
        '''
        Custom truncation layer. Returns the output of a fully-connected
        (Linear) layer, with truncation applied to each vector dot product.
        Inputs:
            in_features  - Dimension of the input vector
            out_features - Dimension of the output vector
            k            - truncation parameter
        '''
        super(fast_trunc, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k = k
        # Initialize the weight matrix and bias vector uniformly on
        # [-k0, k0] with k0 = 1/sqrt(in_features), as nn.Linear does
        k0 = torch.sqrt(torch.tensor(1 / in_features))
        w = -1 * ((-2 * k0) * torch.rand(out_features, in_features) + k0)
        b = -1 * ((-2 * k0) * torch.rand(out_features) + k0)
        self.weight = torch.nn.Parameter(data=w, requires_grad=True)
        self.bias = torch.nn.Parameter(data=b, requires_grad=True)

    def forward(self, x):
        # compute the regular linear layer output, but save a copy of x
        x_vals = x.clone().detach()
        x = torch.matmul(x, self.weight.T)
        # use in_features rather than a hard-coded 784 so non-MNIST sizes work
        temp = x_vals.view(-1, 1, self.in_features) * self.weight
        # temp shape is (bs, out_dim, in_dim)
        val_1, _ = torch.topk(temp, self.k)
        val_2, _ = torch.topk(-1 * temp, self.k)
        # val shapes are (bs, out_dim, self.k)
        x -= val_1.sum(axis=-1)
        x += val_2.sum(axis=-1)
        x += self.bias
        ####
        # THE MORE EFFICIENT IMPLEMENTATION ABOVE WAS ADDED AFTER THE REPORT.
        # IT IS IDENTICAL TO THE ORIGINAL IMPLEMENTATION BELOW, BUT IF ANY
        # ERROR OCCURS, FEEL FREE TO REVERT BY COMMENTING OUT THE BLOCK ABOVE
        # AND UNCOMMENTING THE BLOCK BELOW. The speedup is roughly 30%.
        ####
        # OLD IMPLEMENTATION BEGINS
        # x_vals = x.clone().detach()
        # x = torch.matmul(x, self.weight.T)
        # # iterate over the result to apply truncation afterwards
        # for i in range(x.shape[0]):
        #     temp = x_vals[i, :] * self.weight
        #     val_1, _ = torch.topk(temp, self.k)
        #     val_2, _ = torch.topk(-1 * temp, self.k)
        #     x[i] -= torch.sum(val_1, dim=1)
        #     x[i] += torch.sum(val_2, dim=1)
        # x += self.bias
        # OLD IMPLEMENTATION ENDS
        return x
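
# A minimal usage sketch (hypothetical sizes):
#
#   layer = fast_trunc(in_features=784, out_features=10, k=5)
#   out = layer(torch.randn(32, 784))
#   # out has shape (32, 10); for each output, the k largest and k smallest
#   # products x_j * w_ij were dropped from the dot product before the bias.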

def trunc(x, k):
    '''
    Takes input x and removes the top and bottom k features per channel
    by zeroing them.
    Inputs:
        x - 4-dim: [bs, in_ch, out_dim, out_dim]
        k - truncation parameter
    Outputs:
        x - truncated version of the input
    '''
    x_vals = x.clone().detach()
    out_dim = x.shape[3]
    # Now x is dimension [bs, in_ch, out_dim, out_dim]
    with torch.no_grad():
        for i in range(x.shape[0]):
            temp = x_vals[i, :]
            _, idx_1 = torch.topk(temp.view(-1, out_dim**2), k)
            _, idx_2 = torch.topk(-1 * temp.view(-1, out_dim**2), k)
            for ch in range(x.shape[1]):
                # map the flattened top-k indices back to 2-D coordinates
                x[i, ch, idx_1[ch] // out_dim, idx_1[ch] % out_dim] = 0
                x[i, ch, idx_2[ch] // out_dim, idx_2[ch] % out_dim] = 0
    return x
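
# A minimal usage sketch (hypothetical shapes):
#
#   fmap = torch.randn(8, 16, 14, 14)   # e.g. a conv feature map
#   fmap = trunc(fmap, k=3)
#   # per sample and channel, the 3 largest and 3 smallest activations
#   # are now zeroed in place.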

def mu_sigma(beta, CIFAR=False):
    '''
    Rescales the original domain given a beta value.
    Inputs:
        beta  - The magnitude by which we scale the domain
        CIFAR - bool telling whether we are using CIFAR or MNIST
    Outputs:
        mu    - New mean for the data
        sigma - New std for the data
    '''
    # min/max pixel values for the MNIST/CIFAR datasets
    if CIFAR:
        MIN = -2.429065704345703
        MAX = 2.7537312507629395
    else:
        MIN = -0.42421296
        MAX = 2.8214867
    # transformation
    mu = MIN - (0.5 - (1 / (2 * beta))) * (MAX - MIN) * beta
    sigma = (MAX - MIN) * beta
    return mu, sigma
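
# A worked example of the rescaling: with beta = 1 the formula gives
# mu = MIN and sigma = MAX - MIN, so normalizing as (x - mu) / sigma maps
# [MIN, MAX] onto [0, 1]; in general it maps onto
# [0.5 - 1/(2*beta), 0.5 + 1/(2*beta)], a width-1/beta interval about 0.5.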

class Logger():
    '''
    Logger used within sparse_rs.py to record details to a txt file.
    '''
    def __init__(self, log_path):
        self.log_path = log_path

    def log(self, str_to_log):
        with open(self.log_path, 'a') as f:
            f.write(str_to_log + '\n')
            f.flush()
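
# A minimal usage sketch (hypothetical path):
#
#   logger = Logger('./logs/run_0.txt')
#   logger.log('epoch 1: loss 0.42')   # appends one line to the log file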

class flatten(object):
    '''
    Transform that flattens a sample's image into one dimension.
    '''
    def __call__(self, sample):
        new_image = torch.flatten(sample[0])
        return new_image
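
# A minimal usage sketch. Since __call__ indexes sample[0], applying this
# after ToTensor() on MNIST selects the single channel of the (1, 28, 28)
# tensor and returns a flat 784-element vector:
#
#   transform = transforms.Compose([transforms.ToTensor(), flatten()])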