# models.py
  1. import torch
  2. from utils.helpers import *
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. class L_Net(nn.Module):
  6. '''
  7. 5-layer Linear Network
  8. Takes input a batch of flattened images from MNIST
  9. Layers:
  10. lin 1,2 - increase the dimension of the input
  11. fc 1,2,3 - decrease the dimension of the input
  12. NOTE softmax is not used at the output as nn.CrossEntropyLoss()
  13. takes care of this.
  14. '''
  15. def __init__(self):
  16. super(L_Net, self).__init__()
  17. self.lin1 = nn.Linear(784, 784*2)
  18. self.lin2 = nn.Linear(784*2, 3136)
  19. self.fc1 = nn.Linear(3136, 500)
  20. self.fc2 = nn.Linear(500, 100)
  21. self.fc3 = nn.Linear(100, 10)
  22. def forward(self, x):
  23. x = x.view(-1,28*28)
  24. x = F.relu(self.lin1(x))
  25. x = F.relu(self.lin2(x))
  26. x = F.relu(self.fc1(x))
  27. x = F.relu(self.fc2(x))
  28. x = self.fc3(x)
  29. return x
  30. class r_L_Net(nn.Module):
  31. '''
  32. 5-layer k-Truncated Linear Network (robust)
  33. Takes input a batch of flattened images from MNIST
  34. Inputs:
  35. k - truncation parameter
  36. Layers:
  37. lin 1,2 - increase the dimension of the input
  38. fc 1,2,3 - decrease the dimension of the input
  39. NOTE lin1 is the layer that preforms truncation
  40. NOTE softmax is not used at the output as nn.CrossEntropyLoss()
  41. takes care of this.
  42. '''
  43. def __init__(self, k):
  44. super(r_L_Net, self).__init__()
  45. self.lin1 = fast_trunc(784,784*2,k)
  46. self.lin2 = nn.Linear(784*2, 3136)
  47. self.fc1 = nn.Linear(3136, 500)
  48. self.fc2 = nn.Linear(500, 100)
  49. self.fc3 = nn.Linear(100, 10)
  50. def forward(self, x):
  51. x = x.view(-1,28*28)
  52. x = F.relu(self.lin1(x))
  53. x = F.relu(self.lin2(x))
  54. x = F.relu(self.fc1(x))
  55. x = F.relu(self.fc2(x))
  56. x = self.fc3(x)
  57. return x
  58. # EVAL NETWORKS FOR ATTACK SCRIPTS
  59. class L_Net_eval(nn.Module):
  60. '''
  61. Eval Version for L_Net() that transforms
  62. the original domain using mu and sigma.
  63. Inputs:
  64. mu, sigma - transform the original domain
  65. '''
  66. def __init__(self, mu, sigma):
  67. super(L_Net_eval, self).__init__()
  68. self.lin1 = nn.Linear(784,784*2)
  69. self.lin2 = nn.Linear(784*2, 3136)
  70. self.fc1 = nn.Linear(3136, 500)
  71. self.fc2 = nn.Linear(500, 100)
  72. self.fc3 = nn.Linear(100, 10)
  73. self.sigma = torch.tensor(sigma)
  74. self.mu = torch.tensor(mu)
  75. def forward(self, x):
  76. x = x.view(-1,28*28)
  77. x = (x*self.sigma)+self.mu
  78. x = F.relu(self.lin1(x))
  79. x = F.relu(self.lin2(x))
  80. x = F.relu(self.fc1(x))
  81. x = F.relu(self.fc2(x))
  82. x = self.fc3(x)
  83. return x
  84. def __call__(self,x):
  85. return self.forward(x)
  86. class r_L_Net_eval(nn.Module):
  87. '''
  88. Eval Version for r_L_Net() that transforms
  89. the original domain using mu and sigma.
  90. Inputs:
  91. mu, sigma - transform the original domain
  92. k - truncation paramater
  93. '''
  94. def __init__(self, k, mu, sigma):
  95. super(r_L_Net_eval, self).__init__()
  96. self.lin1 = fast_trunc(784,784*2,k)
  97. self.lin2 = nn.Linear(784*2, 3136)
  98. self.fc1 = nn.Linear(3136, 500)
  99. self.fc2 = nn.Linear(500, 100)
  100. self.fc3 = nn.Linear(100, 10)
  101. self.sigma = torch.tensor(sigma)
  102. self.mu = torch.tensor(mu)
  103. def forward(self, x):
  104. x = x.view(-1,28*28)
  105. x = (x*self.sigma)+self.mu
  106. x = F.relu(self.lin1(x))
  107. x = F.relu(self.lin2(x))
  108. x = F.relu(self.fc1(x))
  109. x = F.relu(self.fc2(x))
  110. x = self.fc3(x)
  111. return x
  112. def __call__(self,x):
  113. return self.forward(x)