k_centers_problem.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import networkx as nx
  2. import matplotlib.pyplot as plt
  3. import operator
  4. def k_centers(G, n):
  5. centers = []
  6. cities = G.nodes()
  7. #add an arbitrary node, here, the first node,to the centers list
  8. centers.append((G.nodes())[0])
  9. cities.remove(centers[0])
  10. n = n-1 #since we have already added one center
  11. #choose n-1 centers
  12. while n!= 0:
  13. city_dict = {}
  14. for cty in cities:
  15. min_dist = float("inf")
  16. for c in centers:
  17. min_dist = min(min_dist,G[cty][c]['length'])
  18. city_dict[cty] = min_dist
  19. #print city_dict
  20. new_center = max(city_dict, key = city_dict[i])
  21. #print new_center
  22. centers.append(new_center)
  23. cities.remove(new_center)
  24. n = n-1
  25. #print centers
  26. return centers
  27. #takes input from the file and creates a weighted undirected graph
  28. def CreateGraph():
  29. G = nx.Graph()
  30. f = open('input.txt')
  31. n = int(f.readline()) #n denotes the number of cities
  32. wtMatrix = []
  33. for i in range(n):
  34. list1 = map(int, (f.readline()).split())
  35. wtMatrix.append(list1)
  36. #Adds egdes along with their weights to the graph
  37. for i in range(n) :
  38. for j in range(n)[i:] :
  39. G.add_edge(i, j, length = wtMatrix[i][j])
  40. noc = int(f.readline()) #noc,here,denotes the number of centers
  41. return G, noc
  42. #draws the graph and displays the weights on the edges
  43. def DrawGraph(G, centers):
  44. pos = nx.spring_layout(G)
  45. color_map = ['blue'] * len(G.nodes())
  46. #all the center nodes are marked with 'red'
  47. for c in centers:
  48. color_map[c] = 'red'
  49. nx.draw(G, pos, node_color = color_map, with_labels = True) #with_labels=true is to show the node number in the output graph
  50. edge_labels = nx.get_edge_attributes(G, 'length')
  51. nx.draw_networkx_edge_labels(G, pos, edge_labels = edge_labels, font_size = 11) #prints weight on all the edges
  52. #main function
  53. if __name__ == "__main__":
  54. G,n = CreateGraph()
  55. centers = k_centers(G, n)
  56. DrawGraph(G, centers)
  57. plt.show()