# dijsktras.py (2.7 KB)
  1. import networkx as nx
  2. import matplotlib.pyplot as plt
  3. import sys
  4. #utility function that returns the minimum distance node
  5. def minDistance(dist, sptSet, V):
  6. min = sys.maxsize #assigning largest numeric value to min
  7. for v in range(V):
  8. if sptSet[v] == False and dist[v] <= min:
  9. min = dist[v]
  10. min_index = v
  11. return min_index
  12. #function that performs dijsktras algorithm on the graph G,with source vertex as source
  13. def dijsktras(G, source, pos):
  14. V = len(G.nodes()) # V denotes the number of vertices in G
  15. dist = [] # dist[i] will hold the shortest distance from source to i
  16. parent = [None]*V # parent[i] will hold the node from which i is reached to, in the shortest path from source
  17. sptSet = [] # sptSet[i] will hold true if vertex i is included in shortest path tree
  18. #initially, for every node, dist[] is set to maximum value and sptSet[] is set to False
  19. for i in range(V):
  20. dist.append(sys.maxsize)
  21. sptSet.append(False)
  22. dist[source] = 0
  23. parent[source]= -1 #source is itself the root, and hence has no parent
  24. for count in range(V-1):
  25. u = minDistance(dist, sptSet, V) #pick the minimum distance vectex from the set of vertices
  26. sptSet[u] = True
  27. #update the vertices adjacent to the picked vertex
  28. for v in range(V):
  29. if (u, v) in G.edges():
  30. if sptSet[v] == False and dist[u] != sys.maxsize and dist[u] + G[u][v]['length'] < dist[v]:
  31. dist[v] = dist[u] + G[u][v]['length']
  32. parent[v] = u
  33. #marking the shortest path from source to each of the vertex with red, using parent[]
  34. for X in range(V):
  35. if parent[X] != -1: #ignore the parent of root node
  36. if (parent[X], X) in G.edges():
  37. nx.draw_networkx_edges(G, pos, edgelist = [(parent[X], X)], width = 2.5, alpha = 0.6, edge_color = 'r')
  38. return
  39. #takes input from the file and creates a weighted graph
  40. def CreateGraph():
  41. G = nx.DiGraph()
  42. f = open('input.txt')
  43. n = int(f.readline())
  44. wtMatrix = []
  45. for i in range(n):
  46. list1 = map(int, (f.readline()).split())
  47. wtMatrix.append(list1)
  48. source = int(f.readline()) #source vertex for dijsktra's algo
  49. #Adds egdes along with their weights to the graph
  50. for i in range(n) :
  51. for j in range(n) :
  52. if wtMatrix[i][j] > 0 :
  53. G.add_edge(i, j, length = wtMatrix[i][j])
  54. return G, source
  55. #draws the graph and displays the weights on the edges
  56. def DrawGraph(G):
  57. pos = nx.spring_layout(G)
  58. nx.draw(G, pos, with_labels = True) #with_labels=true is to show the node number in the output graph
  59. edge_labels = dict([((u, v), d['length']) for u, v, d in G.edges(data = True)])
  60. nx.draw_networkx_edge_labels(G, pos, edge_labels = edge_labels, label_pos = 0.3, font_size = 11) #prints weight on all the edges
  61. return pos
  62. #main function
  63. if __name__ == "__main__":
  64. G,source = CreateGraph()
  65. pos = DrawGraph(G)
  66. dijsktras(G, source, pos)
  67. plt.show()