# tsp_christofides.py

import sys

import matplotlib.pyplot as plt
import networkx as nx
  4. # A utility function that return the smallest unprocessed edge
  5. def getMin(G, mstFlag):
  6. min = sys.maxsize # assigning largest numeric value to min
  7. for i in [(u, v, edata['length']) for u, v, edata in G.edges( data = True) if 'length' in edata ]:
  8. if mstFlag[i] == False and i[2] < min:
  9. min = i[2]
  10. min_edge = i
  11. return min_edge
  12. # A utility function to find root or origin of the node i in MST
  13. def findRoot(parent, i):
  14. if parent[i] == i:
  15. return i
  16. return findRoot(parent, parent[i])
  17. # A function that does union of set x and y based on the order
  18. def union(parent, order, x, y):
  19. xRoot = findRoot(parent, x)
  20. yRoot = findRoot(parent, y)
  21. # Attach smaller order tree under root of high order tree
  22. if order[xRoot] < order[yRoot]:
  23. parent[xRoot] = yRoot
  24. elif order[xRoot] > order[yRoot]:
  25. parent[yRoot] = xRoot
  26. # If orders are same, then make any one as root and increment its order by one
  27. else :
  28. parent[yRoot] = xRoot
  29. order[xRoot] += 1
  30. #function that performs kruskals algorithm on the graph G
  31. def genMinimumSpanningTree(G):
  32. MST = nx.Graph()
  33. eLen = len(G.edges()) # eLen denotes the number of edges in G
  34. vLen = len(G.nodes()) # vLen denotes the number of vertices in G
  35. mst = [] # mst contains the MST edges
  36. mstFlag = {} # mstFlag[i] will hold true if the edge i has been processed for MST
  37. for i in [ (u, v, edata['length']) for u, v, edata in G.edges(data = True) if 'length' in edata ]:
  38. mstFlag[i] = False
  39. parent = [None] * vLen # parent[i] will hold the vertex connected to i, in the MST
  40. order = [None] * vLen # order[i] will hold the order of appearance of the node in the MST
  41. for v in range(vLen):
  42. parent[v] = v
  43. order[v] = 0
  44. while len(mst) < vLen - 1 :
  45. curr_edge = getMin(G, mstFlag) # pick the smallest egde from the set of edges
  46. mstFlag[curr_edge] = True # update the flag for the current edge
  47. y = findRoot(parent, curr_edge[1])
  48. x = findRoot(parent, curr_edge[0])
  49. # adds the edge to MST, if including it doesn't form a cycle
  50. if x != y:
  51. mst.append(curr_edge)
  52. union(parent, order, x, y)
  53. # Else discard the edge
  54. for X in mst:
  55. if (X[0], X[1]) in G.edges():
  56. MST.add_edge(X[0], X[1], length = G[X[0]][X[1]]['length'])
  57. return MST
  58. #utility function that adds minimum weight matching edges to MST
  59. def minimumWeightedMatching(MST, G, odd_vert):
  60. while odd_vert:
  61. v = odd_vert.pop()
  62. length = float("inf")
  63. u = 1
  64. closest = 0
  65. for u in odd_vert:
  66. if G[v][u]['length'] < length :
  67. length = G[v][u]['length']
  68. closest = u
  69. MST.add_edge(v, closest, length = length)
  70. odd_vert.remove(closest)
  71. def christofedes(G ,pos):
  72. opGraph=nx.DiGraph()
  73. #optimal_dist = 0
  74. MST = genMinimumSpanningTree(G) # generates minimum spanning tree of graph G, using Prim's algo
  75. odd_vert = [] #list containing vertices with odd degree
  76. for i in MST.nodes():
  77. if MST.degree(i)%2 != 0:
  78. odd_vert.append(i) #if the degree of the vertex is odd, then append it to odd_vert list
  79. minimumWeightedMatching(MST, G, odd_vert) #adds minimum weight matching edges to MST
  80. # now MST has the Eulerian circuit
  81. start = MST.nodes()[0]
  82. visited = [False] * len(MST.nodes())
  83. # finds the hamiltonian circuit
  84. curr = start
  85. visited[curr] = True
  86. for nd in MST.neighbors(curr):
  87. if visited[nd] == False or nd == start:
  88. next = nd
  89. break
  90. while next != start:
  91. visited[next]=True
  92. opGraph.add_edge(curr,next,length = G[curr][next]['length'])
  93. nx.draw_networkx_edges(G, pos, arrows = True, edgelist = [(curr, next)], width = 2.5, alpha = 0.6, edge_color = 'r')
  94. # optimal_dist = optimal_dist + G[curr][next]['length']
  95. # finding the shortest Eulerian path from MST
  96. curr = next
  97. for nd in MST.neighbors(curr):
  98. if visited[nd] == False:
  99. next = nd
  100. break
  101. if next == curr:
  102. for nd in G.neighbors(curr):
  103. if visited[nd] == False:
  104. next = nd
  105. break
  106. if next == curr:
  107. next = start
  108. opGraph.add_edge(curr,next,length = G[curr][next]['length'])
  109. nx.draw_networkx_edges(G, pos, edgelist = [(curr, next)], width = 2.5, alpha = 0.6, edge_color = 'r')
  110. # optimal_dist = optimal_dist + G[curr][next]['length']
  111. # print optimal_dist
  112. return opGraph
  113. #takes input from the file and creates a weighted undirected graph
  114. def CreateGraph():
  115. G = nx.Graph()
  116. f = open('input.txt')
  117. n = int(f.readline())
  118. wtMatrix = []
  119. for i in range(n):
  120. list1 = map(int, (f.readline()).split())
  121. wtMatrix.append(list1)
  122. #Adds egdes along with their weights to the graph
  123. for i in range(n) :
  124. for j in range(n)[i:] :
  125. if wtMatrix[i][j] > 0 :
  126. G.add_edge(i, j, length = wtMatrix[i][j])
  127. return G
  128. def DrawGraph(G,color):
  129. pos = nx.spring_layout(G)
  130. nx.draw(G, pos, with_labels = True, edge_color = color) #with_labels=true is to show the node number in the output graph
  131. edge_labels = nx.get_edge_attributes(G,'length')
  132. nx.draw_networkx_edge_labels(G, pos, edge_labels = edge_labels, font_size = 11) #prints weight on all the edges
  133. return pos
  134. #main function
  135. if __name__ == "__main__":
  136. G = CreateGraph()
  137. plt.figure(1)
  138. pos = DrawGraph(G,'black')
  139. opGraph = christofedes(G, pos)
  140. plt.figure(2)
  141. pos1 = DrawGraph(opGraph,'r')
  142. plt.show()