123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- import networkx as nx
- import matplotlib.pyplot as plt
- from networkx.algorithms import bipartite
- from string import ascii_lowercase
def init_labels(cost):
    """Build the starting feasible labeling for the Hungarian method.

    Args:
        cost: square weight matrix; only the first ``len(cost)`` columns of
            each row are read.

    Returns:
        ``(lx, ly)`` where ``lx[x]`` is the largest of 0 and the entries of
        row ``x``, and every ``ly[y]`` is 0.
    """
    size = len(cost)
    # Seeding the maximum with 0 mirrors a running max that starts at 0, so a
    # row whose entries are all negative gets label 0.
    row_labels = [
        max(0, *(cost[row][col] for col in range(size)))
        for row in range(size)
    ]
    col_labels = [0 for _ in range(size)]
    return row_labels, col_labels
def update_labels(T, slack, S, lx, ly, n):
    """Shift the labels by the smallest slack outside the alternating tree.

    Let ``delta`` be the minimum of ``slack[y]`` over columns whose ``T[y]``
    is false (``float("inf")`` if there are none).  Every ``lx[x]`` with
    ``S[x]`` set drops by ``delta``.  Every ``ly[y]`` with ``T[y]`` set rises
    by ``delta``.  Every ``slack[y]`` with ``T[y]`` clear drops by ``delta``.
    All lists are modified in place; nothing is returned.
    """
    outside = [col for col in range(n) if not T[col]]
    delta = min([slack[col] for col in outside]) if outside else float("inf")
    for row in range(n):
        if S[row]:
            lx[row] -= delta
    for col in range(n):
        if T[col]:
            ly[col] += delta
        else:
            slack[col] -= delta
def add_to_tree(x, prevx, S, prev, lx, ly, slack, slackx, cost):
    """Put X-vertex ``x`` into the alternating tree under parent ``prevx``.

    The function marks ``S[x]`` and records ``prev[x] = prevx``.  For every
    column ``y`` whose reduced cost ``lx[x] + ly[y] - cost[x][y]`` is strictly
    below the current ``slack[y]``, it stores that value in ``slack[y]`` and
    remembers ``x`` in ``slackx[y]``.  The list arguments are modified in
    place.
    """
    S[x] = True
    prev[x] = prevx
    for col in range(len(cost)):
        reduced = lx[x] + ly[col] - cost[x][col]
        if reduced < slack[col]:
            slack[col] = reduced
            slackx[col] = x
def augment(cost, max_match, xy, yx, lx, ly, slack, slackx):
    """Extend the matching in ``xy``/``yx`` until it is perfect (Kuhn-Munkres).

    Each pass starts an alternating tree at the first unmatched row.  The tree
    grows by BFS over the equality graph (edges with
    ``cost[x][y] == lx[x] + ly[y]``).  When it cannot grow, ``update_labels``
    is called.  Growth continues until an unmatched column is reached, and the
    matching is then flipped along that path.

    Args:
        cost: square weight matrix indexed ``cost[x][y]``.
        max_match: number of pairs already matched in ``xy``/``yx``.
        xy: ``xy[x]`` is the column matched to row ``x``, or -1.
        yx: ``yx[y]`` is the row matched to column ``y``, or -1.
        lx, ly: feasible vertex labels; updated in place.
        slack, slackx: scratch lists of length ``len(cost)``; overwritten.

    Returns:
        None.  The matching is written into ``xy`` and ``yx``.
    """
    from collections import deque

    n = len(cost)
    # The previous version recursed once per new matched pair; a loop avoids
    # Python's recursion limit on large matrices.
    while max_match < n:
        root = next((x for x in range(n) if xy[x] == -1), None)
        if root is None:
            # max_match disagreed with xy: every row is already matched.
            return
        S = [False] * n  # rows currently in the alternating tree
        T = [False] * n  # columns currently in the alternating tree
        prev = [-1] * n  # parent row of each tree row; -2 marks the root
        S[root] = True
        prev[root] = -2
        queue = deque([root])
        for y in range(n):
            slack[y] = lx[root] + ly[y] - cost[root][y]
            slackx[y] = root

        # (row, column) endpoint of the augmenting path once one is found.
        # The previous version tested ``y < n`` after ``for y in range(n)``.
        # In Python that test is always true (the loop variable stops at
        # n - 1), so it "augmented" along arbitrary edges even when no free
        # column had been reached.
        end = None
        while end is None:
            # Grow the tree by BFS over the current equality graph.
            while queue and end is None:
                x = queue.popleft()
                for y in range(n):
                    if T[y] or cost[x][y] != lx[x] + ly[y]:
                        continue
                    if yx[y] == -1:
                        end = (x, y)
                        break
                    T[y] = True
                    queue.append(yx[y])
                    add_to_tree(yx[y], x, S, prev, lx, ly, slack, slackx, cost)
            if end is not None:
                break
            # No augmenting path in the equality graph.  Improving the
            # labeling brings every column with zero slack into it via the
            # edge (slackx[y], y).
            update_labels(T, slack, S, lx, ly, n)
            for y in range(n):
                if T[y] or slack[y] != 0:
                    continue
                if yx[y] == -1:
                    end = (slackx[y], y)
                    break
                T[y] = True
                if not S[yx[y]]:
                    queue.append(yx[y])
                    add_to_tree(yx[y], slackx[y], S, prev, lx, ly, slack, slackx, cost)

        # Flip matched/unmatched edges along the path back to the root.
        cx, cy = end
        while cx != -2:
            ty = xy[cx]
            yx[cy] = cx
            xy[cx] = cy
            cx, cy = prev[cx], ty
        max_match += 1
def hungarian(B, pos, cost):
    """Find a maximum-weight assignment for ``cost`` and highlight it on the plot.

    Args:
        B: the drawn graph.  Worker ``x`` is node ``x`` and job ``y`` is the
            letter ``chr(ord("a") + y)``.
        pos: node positions used for drawing (as returned by ``DrawGraph``).
        cost: square weight matrix, ``cost[worker][job]``.

    Returns:
        The total weight of the chosen assignment.  The previous version
        initialised an unused ``ret`` and returned None; callers that ignore
        the result are unaffected.
    """
    n = len(cost)
    xy = [-1] * n  # xy[worker] -> matched job index
    yx = [-1] * n  # yx[job] -> matched worker index
    slack = [0] * n
    slackx = [0] * n
    lx, ly = init_labels(cost)
    augment(cost, 0, xy, yx, lx, ly, slack, slackx)

    # Only pairs that exist as edges in B can be highlighted; -1 entries
    # (unmatched) are skipped instead of being mapped to chr(96).
    chosen = [
        (x, chr(ord("a") + y))
        for x, y in enumerate(xy)
        if y != -1 and B.has_edge(x, chr(ord("a") + y))
    ]
    if chosen:
        nx.draw_networkx_edges(B, pos, edgelist=chosen, width=2.5, alpha=0.6, edge_color='r')
    return sum(cost[x][y] for x, y in enumerate(xy) if y != -1)
def CreateGraph(path='input.txt'):
    """Read a cost matrix from a text file and build a weighted bipartite graph.

    Expected file layout: the first line holds ``n``, and each of the next
    ``n`` lines holds at least ``n`` whitespace-separated integers.  Extra
    columns are kept in ``cost`` but ignored for the graph.

    Worker ``i`` becomes node ``i`` (attribute ``bipartite=0``).  Job ``j``
    becomes the letter ``ascii_lowercase[j]`` (``bipartite=1``).  A directed
    edge ``i -> letter`` with attribute ``length=cost[i][j]`` is added only
    when that cost is positive.

    Args:
        path: file to read; defaults to ``'input.txt'`` as before.

    Returns:
        ``(B, cost)``: the ``nx.DiGraph`` and the matrix as a list of int lists.

    Raises:
        ValueError: if ``n`` exceeds 26 (jobs are single lowercase letters)
            or a row has fewer than ``n`` entries.
    """
    # ``with`` closes the file; the previous version leaked the handle.
    with open(path) as f:
        n = int(f.readline())
        if n > len(ascii_lowercase):
            raise ValueError(
                "at most %d jobs are supported, got %d" % (len(ascii_lowercase), n))
        # list(...) is required on Python 3, where map() returns a lazy,
        # non-subscriptable iterator and cost[i][j] would raise TypeError.
        cost = [list(map(int, f.readline().split())) for _ in range(n)]

    for i, row in enumerate(cost):
        if len(row) < n:
            raise ValueError("row %d has %d entries, expected %d" % (i, len(row), n))

    people = list(range(n))
    jobs = list(ascii_lowercase[:n])
    B = nx.DiGraph()
    B.add_nodes_from(people, bipartite=0)  # Add the node attribute "bipartite"
    B.add_nodes_from(jobs, bipartite=1)
    for i, row in enumerate(cost):
        for job, weight in zip(jobs, row):
            if weight > 0:
                B.add_edge(i, job, length=weight)
    return B, cost
def DrawGraph(B):
    """Draw ``B`` as two columns with edge weights and return the node positions.

    When every node carries a ``bipartite`` attribute, nodes with value 0 go
    in column x=1 and all others in column x=2.  This works even when the
    graph is disconnected, where ``nx.bipartite.sets`` raises
    AmbiguousSolution.  Otherwise the function falls back to
    ``nx.bipartite.sets``.  Within a column, nodes keep the graph's insertion
    order, so the layout is deterministic.

    Args:
        B: bipartite graph; edges may carry a ``length`` attribute.

    Returns:
        dict mapping node -> ``(column, row)`` position.
    """
    sides = nx.get_node_attributes(B, "bipartite")
    if len(sides) == len(B):
        left = {node for node, side in sides.items() if side == 0}
    else:
        left, _ = nx.bipartite.sets(B)
    left_nodes = [node for node in B if node in left]
    right_nodes = [node for node in B if node not in left]

    pos = {}
    # Update position for node from each group
    pos.update((node, (1, index)) for index, node in enumerate(left_nodes))
    pos.update((node, (2, index)) for index, node in enumerate(right_nodes))
    nx.draw(B, pos, with_labels=True)  # with_labels=True shows the node names
    # Label only edges that actually carry a weight instead of raising KeyError.
    edge_labels = {(u, v): d['length'] for u, v, d in B.edges(data=True) if 'length' in d}
    nx.draw_networkx_edge_labels(B, pos, edge_labels=edge_labels, label_pos=0.2, font_size=11)  # prints weight on the edges
    return pos
# Script entry point: read input.txt, draw the graph, overlay the optimal
# assignment in red, then show the figure.
if __name__ == "__main__":
    graph, weights = CreateGraph()
    layout = DrawGraph(graph)
    hungarian(graph, layout, weights)
    plt.show()
-
|