prims.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import networkx as nx
  2. import matplotlib.pyplot as plt
  3. import sys
  4. #utility function that returns the minimum egde weight node
  5. def minDistance(dist, mstSet, V):
  6. min = sys.maxsize #assigning largest numeric value to min
  7. for v in range(V):
  8. if mstSet[v] == False and dist[v] < min:
  9. min = dist[v]
  10. min_index = v
  11. return min_index
  12. #function that performs prims algorithm on the graph G
  13. def prims(G, pos):
  14. V = len(G.nodes()) # V denotes the number of vertices in G
  15. dist = [] # dist[i] will hold the minimum weight edge value of node i to be included in MST
  16. parent = [None]*V # parent[i] will hold the vertex connected to i, in the MST edge
  17. mstSet = [] # mstSet[i] will hold true if vertex i is included in the MST
  18. #initially, for every node, dist[] is set to maximum value and mstSet[] is set to False
  19. for i in range(V):
  20. dist.append(sys.maxsize)
  21. mstSet.append(False)
  22. dist[0] = 0
  23. parent[0]= -1 #starting vertex is itself the root, and hence has no parent
  24. for count in range(V-1):
  25. u = minDistance(dist, mstSet, V) #pick the minimum distance vertex from the set of vertices
  26. mstSet[u] = True
  27. #update the vertices adjacent to the picked vertex
  28. for v in range(V):
  29. if (u, v) in G.edges():
  30. if mstSet[v] == False and G[u][v]['length'] < dist[v]:
  31. dist[v] = G[u][v]['length']
  32. parent[v] = u
  33. for X in range(V):
  34. if parent[X] != -1: #ignore the parent of the starting node
  35. if (parent[X], X) in G.edges():
  36. nx.draw_networkx_edges(G, pos, edgelist = [(parent[X], X)], width = 2.5, alpha = 0.6, edge_color = 'r')
  37. return
  38. #takes input from the file and creates a weighted graph
  39. def CreateGraph():
  40. G = nx.Graph()
  41. f = open('input.txt')
  42. n = int(f.readline())
  43. wtMatrix = []
  44. for i in range(n):
  45. list1 = map(int, (f.readline()).split())
  46. wtMatrix.append(list1)
  47. #Adds egdes along with their weights to the graph
  48. for i in range(n) :
  49. for j in range(n)[i:] :
  50. if wtMatrix[i][j] > 0 :
  51. G.add_edge(i, j, length = wtMatrix[i][j])
  52. return G
  53. #draws the graph and displays the weights on the edges
  54. def DrawGraph(G):
  55. pos = nx.spring_layout(G)
  56. nx.draw(G, pos, with_labels = True) #with_labels=true is to show the node number in the output graph
  57. edge_labels = nx.get_edge_attributes(G,'length')
  58. nx.draw_networkx_edge_labels(G, pos, edge_labels = edge_labels, font_size = 11) #prints weight on all the edges
  59. return pos
  60. #main function
  61. if __name__ == "__main__":
  62. G = CreateGraph()
  63. pos = DrawGraph(G)
  64. prims(G, pos)
  65. plt.show()