12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import torch
- from torch.autograd import Variable
- from shutil import copyfile
- from torchsummary import summary
- from dataloader.dataloader import PrivacyDataloader
- from dataloader.dataset import Dataset
def train(epoch, dataloader, net, criterion, optimizer, opt, writer):
    """Run one training epoch over ``dataloader`` and checkpoint the model.

    Each batch is unpacked as ``(adj_matrix, embedding_matrix, target)``. The
    embedding matrix is fed to ``net`` as its initial input together with the
    adjacency matrix.

    Args:
        epoch: Current epoch number. Used in progress output, as the
            TensorBoard step, and as the suffix of the checkpoint copy.
        dataloader: Iterable supporting ``len()`` that yields
            ``(adj_matrix, embedding_matrix, target)`` tensor triples.
        net: Module called as ``net(init_input, adj_matrix)``.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepping ``net``'s parameters.
        opt: Options namespace; reads ``cuda``, ``verbal``, ``niter`` and
            ``model_path``.
        writer: Object with ``add_scalar(tag, value, step)``
            (e.g. a TensorBoard ``SummaryWriter``).

    Returns:
        The mean batch loss of the epoch as a float, or ``None`` when the
        dataloader yielded no batches (nothing is logged in that case).

    Side effects:
        Saves the whole ``net`` object to ``opt.model_path`` and copies that
        file to ``"<model_path>.<epoch>"``.
    """
    print("------------training_epoch: ", epoch, "----------------------------")
    # Put the module in training mode (matters for dropout / batch norm, if
    # the model has any) in case a caller left it in eval mode.
    net.train()

    # Report progress roughly ten times per epoch; ``+ 1`` keeps the interval
    # non-zero for dataloaders shorter than ten batches.
    log_every = len(dataloader) // 10 + 1
    total_loss = 0.0
    n_batches = 0

    for i, (adj_matrix, embedding_matrix, target) in enumerate(dataloader):
        net.zero_grad()
        # The precomputed node embeddings serve as the network's initial input.
        init_input = embedding_matrix
        if opt.cuda:
            init_input = init_input.cuda()
            adj_matrix = adj_matrix.cuda()
            target = target.cuda()

        output = net(init_input, adj_matrix)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1
        if opt.verbal and i % log_every == 0:
            print('[%d/%d][%d/%d] Loss: %.4f'
                  % (epoch, opt.niter, i, len(dataloader), batch_loss))

    mean_loss = None
    if n_batches:
        mean_loss = total_loss / n_batches
        # One point per epoch: logging every batch at step ``epoch`` would
        # stack many values on the same step of the curve.
        writer.add_scalar('loss', mean_loss, int(epoch))

    torch.save(net, opt.model_path)
    copyfile(opt.model_path, "{}.{}".format(opt.model_path, epoch))
    return mean_loss
if __name__ == '__main__':
    # Executing this module directly currently does nothing. The wiring that
    # was used before is kept below for reference only; note that the dataset
    # path is an absolute, machine-specific location.
    #
    #   traindata = Dataset(
    #       "/Users/liufan/program/PYTHON/sap2nd/GnnForPrivacyScan/data/traindata/train", True)
    #   dataloader = PrivacyDataloader(traindata, batch_size=5, shuffle=True, num_workers=2)
    ...
|