# kruskals_quick_union.py (3.3 KB)
  1. import networkx as nx
  2. import matplotlib.pyplot as plt
  3. import sys
  4. # A utility function that return the smallest unprocessed edge
  5. def getMin(G, mstFlag):
  6. min = sys.maxsize # assigning largest numeric value to min
  7. for i in [(u, v, edata['length']) for u, v, edata in G.edges( data = True) if 'length' in edata ]:
  8. if mstFlag[i] == False and i[2] < min:
  9. min = i[2]
  10. min_edge = i
  11. return min_edge
  12. # A utility function to find root or origin of the node i in MST
  13. def findRoot(parent, i):
  14. if parent[i] == i:
  15. return i
  16. return findRoot(parent, parent[i])
  17. # A function that does union of set x and y based on the order
  18. def union(parent, order, x, y):
  19. xRoot = findRoot(parent, x)
  20. yRoot = findRoot(parent, y)
  21. # Attach smaller order tree under root of high order tree
  22. if order[xRoot] < order[yRoot]:
  23. parent[xRoot] = yRoot
  24. elif order[xRoot] > order[yRoot]:
  25. parent[yRoot] = xRoot
  26. # If orders are same, then make any one as root and increment its order by one
  27. else :
  28. parent[yRoot] = xRoot
  29. order[xRoot] += 1
  30. # function that performs Kruskals algorithm on the graph G
  31. def kruskals(G, pos):
  32. eLen = len(G.edges()) # eLen denotes the number of edges in G
  33. vLen = len(G.nodes()) # vLen denotes the number of vertices in G
  34. mst = [] # mst contains the MST edges
  35. mstFlag = {} # mstFlag[i] will hold true if the edge i has been processed for MST
  36. for i in [ (u, v, edata['length']) for u, v, edata in G.edges(data = True) if 'length' in edata ]:
  37. mstFlag[i] = False
  38. parent = [None] * vLen # parent[i] will hold the vertex connected to i, in the MST
  39. order = [None] * vLen # order[i] will hold the order of appearance of the node in the MST
  40. for v in range(vLen):
  41. parent[v] = v
  42. order[v] = 0
  43. while len(mst) < vLen - 1 :
  44. curr_edge = getMin(G, mstFlag) # pick the smallest egde from the set of edges
  45. mstFlag[curr_edge] = True # update the flag for the current edge
  46. y = findRoot(parent, curr_edge[1])
  47. x = findRoot(parent, curr_edge[0])
  48. # adds the edge to MST, if including it doesn't form a cycle
  49. if x != y:
  50. mst.append(curr_edge)
  51. union(parent, order, x, y)
  52. # Else discard the edge
  53. # marks the MST edges with red
  54. for X in mst:
  55. if (X[0], X[1]) in G.edges():
  56. nx.draw_networkx_edges(G, pos, edgelist = [(X[0], X[1])], width = 2.5, alpha = 0.6, edge_color = 'r')
  57. return
  58. # takes input from the file and creates a weighted graph
  59. def CreateGraph():
  60. G = nx.Graph()
  61. f = open('input.txt')
  62. n = int(f.readline())
  63. wtMatrix = []
  64. for i in range(n):
  65. list1 = map(int, (f.readline()).split())
  66. wtMatrix.append(list1)
  67. # Adds egdes along with their weights to the graph
  68. for i in range(n) :
  69. for j in range(n)[i:] :
  70. if wtMatrix[i][j] > 0 :
  71. G.add_edge(i, j, length = wtMatrix[i][j])
  72. return G
  73. # draws the graph and displays the weights on the edges
  74. def DrawGraph(G):
  75. pos = nx.spring_layout(G)
  76. nx.draw(G, pos, with_labels = True) # with_labels=true is to show the node number in the output graph
  77. edge_labels = nx.get_edge_attributes(G, 'length')
  78. nx.draw_networkx_edge_labels(G, pos, edge_labels = edge_labels, font_size = 11) # prints weight on all the edges
  79. return pos
  80. # main function
  81. if __name__ == "__main__":
  82. G = CreateGraph()
  83. pos = DrawGraph(G)
  84. kruskals(G, pos)
  85. plt.show()