import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np

def prep_CIFAR(root, bs):
    '''
    Preps the CIFAR-10 dataset from root/datasets/, loading in all
    the classes, using a batch size of bs.
    Outputs Data without saving.
    '''
    # Load the data: standard CIFAR-10 augmentation and normalization
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = datasets.CIFAR10(root=root + '/datasets/CIFAR/', train=True,
                                download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs,
                                               shuffle=True, num_workers=2)
    testset = datasets.CIFAR10(root=root + '/datasets/CIFAR/', train=False,
                               download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=bs,
                                              shuffle=False, num_workers=2)
    # Finally compile the loaders into the Data structure
    Data = {}
    Data['bs'] = bs
    Data['x_test'] = []
    Data['y_test'] = []
    Data['x_train'] = []
    Data['y_train'] = []

    # Go through the loaders collecting batches
    for _, data in enumerate(train_loader, 0):
        Data['x_train'].append(data[0])
        Data['y_train'].append(data[1])
    for _, data in enumerate(test_loader, 0):
        Data['x_test'].append(data[0])
        Data['y_test'].append(data[1])
    # Keep an untouched copy of the training data
    Data['x_og'] = Data['x_train'].copy()
    Data['y_og'] = Data['y_train'].copy()
    return Data
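
# Illustrative usage (a sketch, left commented so it does not run on import;
# the root path and batch size are assumptions, not values from this project):
#   Data = prep_CIFAR(root='.', bs=128)
#   print(len(Data['x_train']))       # number of training batches
#   print(Data['x_train'][0].shape)   # torch.Size([128, 3, 32, 32])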

def prep_MNIST(root, bs):
    '''
    Preps the MNIST dataset from root/datasets/, loading in all
    the classes, using a batch size of bs.
    Outputs Data without saving.
    The images are flattened.
    '''
    # Load the data
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    trainset = datasets.MNIST(root=root + '/datasets/', train=True,
                              download=True, transform=transform)
    testset = datasets.MNIST(root=root + '/datasets/', train=False,
                             download=True, transform=transform)
    # Now the loaders
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=True)
    # Finally compile the loaders into the Data structure
    Data = {}
    Data['bs'] = bs
    Data['x_test'] = []
    Data['y_test'] = []
    Data['x_train'] = []
    Data['y_train'] = []

    # Go through the loaders collecting batches; each image is flattened
    # from [1, 28, 28] into a [1, 1, 784] row vector
    for _, data in enumerate(train_loader, 0):
        Data['x_train'].append(data[0].reshape(-1, 1, 1, 28 * 28))
        Data['y_train'].append(data[1])
    for _, data in enumerate(test_loader, 0):
        Data['x_test'].append(data[0].reshape(-1, 1, 1, 28 * 28))
        Data['y_test'].append(data[1])
    # Keep an untouched copy of the training data
    Data['x_og'] = Data['x_train'].copy()
    Data['y_og'] = Data['y_train'].copy()
    return Data
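
# Illustrative usage (a sketch; the root path is an assumption):
#   Data = prep_MNIST(root='.', bs=256)
#   print(Data['x_train'][0].shape)   # torch.Size([256, 1, 1, 784])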

class fast_trunc(nn.Module):
    def __init__(self, in_features, out_features, k):
        '''
        Custom truncation layer. Returns the output of a Linear (FC)
        layer with truncation applied to each vector dot product.
        Inputs:
            in_features  - Dimension of input vector
            out_features - Dimension of output vector
            k            - truncation parameter
        '''
        super(fast_trunc, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k = k
        # Initialize weight matrix and bias vector uniformly in [-k0, k0]
        # with k0 = 1/sqrt(in_features), the same range nn.Linear uses
        k0 = torch.sqrt(torch.tensor(1 / in_features))
        w = -1 * ((-2 * k0) * torch.rand(out_features, in_features) + k0)
        b = -1 * ((-2 * k0) * torch.rand(out_features) + k0)
        self.weight = torch.nn.Parameter(data=w, requires_grad=True)
        self.bias = torch.nn.Parameter(data=b, requires_grad=True)

    def forward(self, x):
        # Compute the regular linear-layer output, but save a copy of x
        x_vals = x.clone().detach()
        x = torch.matmul(x, self.weight.T)
        temp = x_vals.view(-1, 1, self.in_features) * self.weight
        # temp shape is (bs, out_dim, in_dim)
        val_1, _ = torch.topk(temp, self.k)
        val_2, _ = torch.topk(-1 * temp, self.k)
        # val shapes are (bs, out_dim, self.k): subtract the k largest
        # products and add back the k most negative ones
        x -= val_1.sum(axis=-1)
        x += val_2.sum(axis=-1)
        x += self.bias
        ####
        # THIS MORE EFFICIENT IMPLEMENTATION WAS ADDED AFTER THE REPORT.
        # IT IS IDENTICAL TO THE ORIGINAL IMPLEMENTATION BELOW, BUT IF ANY
        # ERROR OCCURS, FEEL FREE TO REVERT BY COMMENTING OUT THE BLOCK ABOVE
        # AND UNCOMMENTING THE BLOCK BELOW. The speedup is roughly 30%.
        ####
        # OLD IMPLEMENTATION BEGINS
        # x_vals = x.clone().detach()
        # x = torch.matmul(x, self.weight.T)
        # # iterate over the result to apply truncation after
        # for i in range(x.shape[0]):
        #     temp = x_vals[i, :] * self.weight
        #     val_1, _ = torch.topk(temp, self.k)
        #     val_2, _ = torch.topk(-1 * temp, self.k)
        #     x[i] -= torch.sum(val_1, dim=1)
        #     x[i] += torch.sum(val_2, dim=1)
        # x += self.bias
        # OLD IMPLEMENTATION ENDS

        return x
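
# Illustrative sketch (hypothetical shapes, not from the original code):
#   layer = fast_trunc(in_features=784, out_features=10, k=2)
#   x = torch.rand(32, 784)           # a flattened MNIST-sized batch
#   out = layer(x)                    # torch.Size([32, 10]); each dot
#                                     # product has its 2 largest and 2
#                                     # most negative terms removed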

def trunc(x, k):
    '''
    Takes input x and removes the top and bottom k features per channel
    by zeroing them.
    Inputs:
        x - 4-dim: [bs, in_ch, out_dim, out_dim]
        k - truncation parameter
    Outputs:
        x - truncated version of the input
    '''
    x_vals = x.clone().detach()
    out_dim = x.shape[3]
    with torch.no_grad():
        for i in range(x.shape[0]):
            temp = x_vals[i, :]
            # Flatten each channel and find the k largest / k smallest entries
            _, idx_1 = torch.topk(temp.view(-1, out_dim ** 2), k)
            _, idx_2 = torch.topk(-1 * temp.view(-1, out_dim ** 2), k)
            for ch in range(x.shape[1]):
                # Convert the flat indices back to (row, col) and zero them
                x[i, ch, idx_1[ch] // out_dim, idx_1[ch] % out_dim] = 0
                x[i, ch, idx_2[ch] // out_dim, idx_2[ch] % out_dim] = 0
    return x
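
# Illustrative sketch (hypothetical tensor, not from the original code):
#   fmap = torch.randn(8, 16, 14, 14)   # [bs, in_ch, out_dim, out_dim]
#   fmap = trunc(fmap, k=5)             # zeroes the 5 largest and 5 most
#                                       # negative activations per channel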

def mu_sigma(beta, CIFAR=False):
    '''
    Rescales the original domain given a beta value.
    Inputs:
        beta  - The magnitude by which we scale the domain
        CIFAR - bool telling whether we are using CIFAR or MNIST
    Outputs:
        mu    - New mean for the data
        sigma - New std for the data
    '''
    # Min/max pixel values of the normalized MNIST/CIFAR datasets
    if CIFAR:
        MIN = -2.429065704345703
        MAX = 2.7537312507629395
    else:
        MIN = -0.42421296
        MAX = 2.8214867
    # Transformation: mu + sigma * u for u in [0, 1] spans a domain with
    # the original center and beta times the original width
    mu = MIN - (0.5 - (1 / (2 * beta))) * (MAX - MIN) * beta
    sigma = (MAX - MIN) * beta
    return mu, sigma
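
# Worked example of the identity above: with beta = 1 the original domain
# is recovered exactly, since mu = MIN and mu + sigma = MAX:
#   mu, sigma = mu_sigma(1.0)         # MNIST: mu ≈ -0.4242, sigma ≈ 3.2457
#   mu, sigma = mu_sigma(2.0)         # same center, twice the width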

class Logger():
    '''
    Logger used within sparse_rs.py to record details to a text file.
    '''
    def __init__(self, log_path):
        self.log_path = log_path

    def log(self, str_to_log):
        with open(self.log_path, 'a') as f:
            f.write(str_to_log + '\n')
            f.flush()
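
# Illustrative usage (hypothetical path):
#   logger = Logger('./results/log.txt')
#   logger.log('epoch 1: loss = 0.42')  # appends one line to the file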

class flatten(object):
    '''
    Transform that flattens an image tensor into one dimension.
    '''
    def __call__(self, sample):
        image = sample[0]  # first (and, for MNIST, only) channel
        new_image = torch.flatten(image)
        return new_image
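
# Illustrative usage inside a transform pipeline (a sketch; assumes the
# sample reaching flatten() is already a tensor, e.g. after ToTensor()):
#   transform = transforms.Compose([transforms.ToTensor(), flatten()])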