import torch
from shutil import copyfile


def train(epoch, dataloader, net, criterion, optimizer, opt, writer):
    """Run one training epoch over ``dataloader`` and checkpoint the model.

    Args:
        epoch: Zero-based index of the current epoch (used for logging and
            for naming the per-epoch checkpoint copy).
        dataloader: Iterable yielding ``(adj_matrix, embedding_matrix, target)``
            batches of tensors; must also support ``len()``.
        net: Model called as ``net(init_input, adj_matrix)``.
        criterion: Loss function applied as ``criterion(output, target)``.
        optimizer: Optimizer whose ``step()`` updates ``net``'s parameters.
        opt: Options namespace; fields read here: ``cuda`` (bool), ``verbal``
            (bool), ``niter`` (total epochs, for progress printout) and
            ``model_path`` (checkpoint file path).
        writer: TensorBoard-style writer exposing ``add_scalar``.

    Side effects:
        Saves the full model to ``opt.model_path`` and copies it to
        ``"<model_path>.<epoch>"`` at the end of the epoch.
    """
    print("------------training_epoch: ", epoch, "----------------------------")
    for i, (adj_matrix, embedding_matrix, target) in enumerate(dataloader, 0):
        net.zero_grad()
        # The node-embedding matrix is fed directly as the initial node state
        # (no zero-padding / annotation concatenation, unlike the classic
        # GGNN setup that the commented-out original code hinted at).
        init_input = embedding_matrix
        if opt.cuda:
            init_input = init_input.cuda()
            adj_matrix = adj_matrix.cuda()
            target = target.cuda()
        # NOTE: torch.autograd.Variable is a no-op since PyTorch 0.4 and was
        # removed here; tensors are used directly.
        output = net(init_input, adj_matrix)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Log with a per-batch global step: the original logged every batch
        # at step=epoch, so batches within one epoch overwrote each other.
        writer.add_scalar('loss', loss.item(), epoch * len(dataloader) + i)
        # "+ 1" guards against division producing 0 for short dataloaders.
        if i % int(len(dataloader) / 10 + 1) == 0 and opt.verbal:
            print('[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, len(dataloader), loss.item()))
    # Checkpoint once per epoch (the original appeared to save and copy the
    # full model inside the batch loop — redundant and expensive). The extra
    # copy keeps a per-epoch history alongside the "latest" checkpoint.
    torch.save(net, opt.model_path)
    copyfile(opt.model_path, "{}.{}".format(opt.model_path, epoch))


if __name__ == '__main__':
    # Example wiring (requires the project-local dataloader package):
    #   from dataloader.dataloader import PrivacyDataloader
    #   from dataloader.dataset import Dataset
    #   traindata = Dataset("<path>/data/traindata/train", True)
    #   dataloader = PrivacyDataloader(traindata, batch_size=5, shuffle=True,
    #                                  num_workers=2)
    pass