test.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import torch
  2. from torch.autograd import Variable
  3. from shutil import copyfile
  4. from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
  5. from torchsummary import summary
  6. def test(dataloader, net, criterion, optimizer, opt):
  7. test_loss = 0
  8. correct = 0
  9. net.eval()
  10. all_targets = []
  11. all_predicted = []
  12. print("len", len(dataloader))
  13. for i, (adj_matrix, embedding_matrix, target) in enumerate(dataloader, 0):
  14. # padding = torch.zeros(len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim).double()
  15. # init_input = torch.cat((annotation, padding), 2)
  16. # init_input = torch.zeros(len(adj_matrix), opt.n_node, opt.state_dim).double()
  17. print(adj_matrix.shape)
  18. print("target.shape", target.shape)
  19. init_input = embedding_matrix
  20. if opt.cuda:
  21. init_input = init_input.cuda()
  22. adj_matrix = adj_matrix.cuda()
  23. # annotation = annotation.cuda()
  24. target = target.cuda()
  25. init_input = Variable(init_input)
  26. adj_matrix = Variable(adj_matrix)
  27. # annotation = Variable(annotation)
  28. target = Variable(target)
  29. print("init_input_shape", init_input.shape)
  30. print("target", target)
  31. # summary(net, init_input.shape, batch_size=5)
  32. output = net(init_input, adj_matrix)
  33. print("output", output)
  34. # test_loss += criterion(output, target).data[0]
  35. test_loss += criterion(output, target).item()
  36. pred = output.data.max(1, keepdim=True)[1]
  37. print("pred", pred)
  38. all_predicted.extend(pred.data.view_as(target).cpu().numpy())
  39. all_targets.extend(target.cpu().numpy())
  40. correct += pred.eq(target.data.view_as(pred)).cpu().sum()
  41. test_loss /= len(dataloader.dataset)
  42. print('Accuracy:', accuracy_score(all_targets, all_predicted))
  43. print(classification_report(all_targets, all_predicted))
  44. print(confusion_matrix(all_targets, all_predicted))
  45. print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
  46. test_loss, correct, len(dataloader.dataset),
  47. 100. * correct / len(dataloader.dataset)))