# assignment_prob_hungarian.py
  1. import networkx as nx
  2. import matplotlib.pyplot as plt
  3. from networkx.algorithms import bipartite
  4. from string import ascii_lowercase
  5. def init_labels(cost):
  6. n = len(cost)
  7. lx = [0] * n
  8. ly = [0] * n
  9. for x in range(n):
  10. for y in range(n):
  11. lx[x] = max(lx[x], cost[x][y])
  12. return lx,ly
  13. def update_labels(T, slack, S, lx, ly, n):
  14. delta = float("inf");
  15. for y in range(n):
  16. if T[y] == 0:
  17. delta = min(delta, slack[y])
  18. for x in range(n):
  19. if S[x] != 0:
  20. lx[x] -= delta
  21. for y in range(n):
  22. if T[y] != 0:
  23. ly[y] += delta
  24. for y in range(n):
  25. if T[y] == 0:
  26. slack[y] -= delta
  27. def add_to_tree(x, prevx, S, prev, lx, ly, slack, slackx, cost):
  28. n = len(cost)
  29. S[x] = True
  30. prev[x] = prevx
  31. for y in range(n):
  32. if (lx[x] + ly[y] - cost[x][y]) < slack[y]:
  33. slack[y] = lx[x] + ly[y] - cost[x][y]
  34. slackx[y] = x
  35. def augment(cost, max_match, xy, yx, lx, ly, slack, slackx):
  36. n = len(cost)
  37. if max_match == n:
  38. return;
  39. q = [0] * n
  40. wr = 0
  41. rd = 0
  42. root = 0
  43. S = [False] * n
  44. T = [False] * n
  45. prev = [-1] * n
  46. for x in range(n):
  47. if xy[x] == -1:
  48. q[wr] = x
  49. wr = wr+1
  50. root = x
  51. prev[x] = -2
  52. S[x] = True
  53. break
  54. for y in range(n):
  55. slack[y] = lx[root] + ly[y] - cost[root][y]
  56. slackx[y] = root
  57. while True:
  58. while rd < wr:
  59. x = q[rd]
  60. rd = rd+1
  61. for y in range(n):
  62. if (cost[x][y] == lx[x] + ly[y] and T[y] == 0):
  63. if yx[y] == -1:
  64. break
  65. T[y] = True
  66. q[wr] = yx[y]
  67. wr = wr+1
  68. add_to_tree(yx[y], x, S, prev, lx, ly, slack, slackx, cost)
  69. if y < n:
  70. break
  71. if y < n:
  72. break
  73. update_labels(T, slack, S, lx, ly, n)
  74. wr = 0
  75. rd = 0
  76. for y in range(n):
  77. if T[y] == 0 and slack[y] == 0:
  78. if yx[y] == -1:
  79. x = slackx[y]
  80. break
  81. else:
  82. T[y] = true
  83. if S[yx[y]] == 0:
  84. q[wr] = yx[y]
  85. wr = wr+1
  86. add_to_tree(yx[y], slackx[y], S, prev, lx, ly, slack, slackx, cost)
  87. if y < n:
  88. break
  89. if y < n:
  90. max_match = max_match+1
  91. cx = x
  92. cy = y
  93. ty = 0
  94. flag = 0
  95. if cx != -2:
  96. ty = xy[cx];
  97. yx[cy] = cx;
  98. xy[cx] = cy;
  99. cx = prev[cx]
  100. cy = ty
  101. while cx != -2:
  102. ty = xy[cx]
  103. yx[cy] = cx
  104. xy[cx] = cy
  105. cx = prev[cx]
  106. cy = ty
  107. augment(cost, max_match, xy, yx, lx, ly, slack, slackx)
  108. def hungarian(B ,pos ,cost):
  109. n = len(cost)
  110. ret = 0;
  111. max_match = 0
  112. xy = [-1] * n
  113. yx = [-1] * n
  114. slack = [0] * n
  115. slackx = [0] * n
  116. lx, ly = init_labels(cost)
  117. augment(cost, max_match, xy, yx, lx, ly, slack, slackx)
  118. for x in range(n):
  119. if (x, chr(xy[x]+97)) in B.edges():
  120. nx.draw_networkx_edges(B, pos, edgelist = [(x, chr(xy[x]+97))], width = 2.5, alpha = 0.6, edge_color = 'r')
  121. #takes input from the file and creates a weighted bipartite graph
  122. def CreateGraph():
  123. B = nx.DiGraph();
  124. f = open('input.txt')
  125. n = int(f.readline())
  126. cost = []
  127. for i in range(n):
  128. list1 = map(int, (f.readline()).split())
  129. cost.append(list1)
  130. people = []
  131. for i in range(n):
  132. people.append(i)
  133. job = []
  134. for c in ascii_lowercase[:n]:
  135. job.append(c)
  136. B.add_nodes_from(people, bipartite=0) # Add the node attribute "bipartite"
  137. B.add_nodes_from(job, bipartite=1)
  138. for i in range(n) :
  139. for c in ascii_lowercase[:n] :
  140. if cost[i][ord(c)-97] > 0 :
  141. B.add_edge(i, c, length = cost[i][ord(c)-97])
  142. return B,cost
  143. def DrawGraph(B):
  144. l, r = nx.bipartite.sets(B)
  145. pos = {}
  146. # Update position for node from each group
  147. pos.update((node, (1, index)) for index, node in enumerate(l))
  148. pos.update((node, (2, index)) for index, node in enumerate(r))
  149. nx.draw(B, pos, with_labels = True) #with_labels=true is to show the node number in the output graph
  150. edge_labels = dict([((u, v), d['length']) for u, v, d in B.edges(data = True)])
  151. nx.draw_networkx_edge_labels(B, pos, edge_labels = edge_labels, label_pos = 0.2, font_size = 11) #prints weight on all the edges
  152. return pos
  153. #main function
  154. if __name__ == "__main__":
  155. B, cost = CreateGraph();
  156. pos = DrawGraph(B)
  157. hungarian(B, pos, cost)
  158. plt.show()