dfs_visual.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import networkx as nx
  2. import matplotlib.pyplot as plt
  3. #utility fucntion used by DFS which does recursive depth first search
  4. def DFSUtil(G, v, visited, sl):
  5. visited[v] = True
  6. sl.append(v)
  7. for i in G[v]:
  8. if visited[i] == False:
  9. DFSUtil(G, i, visited, sl)
  10. return sl
  11. #DFS traversal
  12. def DFS(G, source):
  13. visited = [False]*(len(G.nodes()))
  14. sl = [] #a list that stores dfs forest starting with source node
  15. dfs_stk = [] #A nested list that stores all the DFS Forest's
  16. dfs_stk.append(DFSUtil(G, source, visited, sl))
  17. for i in range(len(G.nodes())):
  18. if visited[i] == False:
  19. sl = []
  20. dfs_stk.append(DFSUtil(G, i, visited, sl))
  21. return dfs_stk
  22. #takes input from the file and creates a weighted graph
  23. def CreateGraph():
  24. G = nx.DiGraph()
  25. f = open('input.txt')
  26. n = int(f.readline())
  27. wtMatrix = []
  28. for i in range(n):
  29. list1 = map(int,(f.readline()).split())
  30. wtMatrix.append(list1)
  31. source = int(f.readline()) #source vertex from where DFS has to start
  32. #Adds egdes along with their weights to the graph
  33. for i in range(n):
  34. for j in range(n):
  35. if wtMatrix[i][j] > 0:
  36. G.add_edge(i, j, length = wtMatrix[i][j])
  37. return G,source
  38. #marks all edges traversed through DFS with red
  39. def DrawDFSPath(G, dfs_stk):
  40. pos = nx.spring_layout(G)
  41. nx.draw(G, pos, with_labels = True) #with_labels=true is to show the node number in the output graph
  42. edge_labels = dict([((u,v,), d['length']) for u, v, d in G.edges(data = True)])
  43. nx.draw_networkx_edge_labels(G, pos, edge_labels = edge_labels, label_pos = 0.3, font_size = 11) #prints weight on all the edges
  44. for i in dfs_stk:
  45. #if there is more than one node in the dfs-forest, then print the corresponding edges
  46. if len(i) > 1:
  47. for j in i[ :(len(i)-1)]:
  48. if i[i.index(j)+1] in G[j]:
  49. nx.draw_networkx_edges(G, pos, edgelist = [(j,i[i.index(j)+1])], width = 2.5, alpha = 0.6, edge_color = 'r')
  50. else:
  51. #if in case the path was reversed because all the possible neighbours were visited, we need to find the adj node to it.
  52. for k in i[1::-1]:
  53. if k in G[j]:
  54. nx.draw_networkx_edges(G, pos, edgelist = [(j,k)], width = 2.5, alpha = 0.6, edge_color = 'r')
  55. break
  56. #main function
  57. if __name__ == "__main__":
  58. G, source = CreateGraph()
  59. dfs_stk = DFS(G, source)
  60. DrawDFSPath(G, dfs_stk)
  61. plt.show()