cleanup/refactor bipartite matching with hungarian method
[libfirm] / ir / adt / hungarian.c
1 /********************************************************************
2  ********************************************************************
3  **
4  ** libhungarian by Cyrill Stachniss, 2004
5  **
6  ** Added and adapted to libFirm by Christian Wuerdig, 2006
7  **
8  ** Solving the Minimum Assignment Problem using the
9  ** Hungarian Method.
10  **
11  ** ** This file may be freely copied and distributed! **
12  **
13  ** Parts of the used code were originally provided by the
14  ** "Stanford GraphBase", but I made changes to this code.
15  ** As asked by the copyright note of the "Stanford GraphBase",
16  ** I hereby proclaim that this file is *NOT* part of the
17  ** "Stanford GraphBase" distribution!
18  **
19  ** This file is distributed in the hope that it will be useful,
20  ** but WITHOUT ANY WARRANTY; without even the implied
21  ** warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
22  ** PURPOSE.
23  **
24  ********************************************************************
25  ********************************************************************/
26
27 /**
28  * @file
29  * @brief   Solving the Minimum Assignment Problem using the Hungarian Method.
30  * @version $Id$
31  */
32 #include "config.h"
33
34 #include <stdio.h>
35 #include <stdlib.h>
36 #include <assert.h>
37
38 #include "irtools.h"
39 #include "xmalloc.h"
40 #include "debug.h"
41 #include "obst.h"
42 #include "bitset.h"
43 #include "error.h"
44
45 #include "hungarian.h"
46
47 struct hungarian_problem_t {
48         unsigned num_rows;          /**< number of rows */
49         unsigned num_cols;          /**< number of columns */
50         int      **cost;            /**< the cost matrix */
51         int      max_cost;          /**< the maximal costs in the matrix */
52         int      match_type;        /**< PERFECT or NORMAL matching */
53         bitset_t *missing_left;     /**< left side nodes having no edge to the right side */
54         bitset_t *missing_right;    /**< right side nodes having no edge to the left side */
55         struct obstack obst;
56         DEBUG_ONLY(firm_dbg_module_t *dbg);
57 };
58
/**
 * Writes a matrix to a stream, one bracketed row per line.
 *
 * The dump is preceded and followed by an empty line. Each entry is
 * formatted with "%*d" using @p width as the field width.
 *
 * @param f       stream to write to
 * @param C       the matrix, accessed as C[row][col]
 * @param n_rows  number of rows to print
 * @param n_cols  number of entries to print per row
 * @param width   field width passed to "%*d" for every entry
 */
static void hungarian_dump_f(FILE *f, int **C, unsigned n_rows, unsigned n_cols,
                             int width)
{
        unsigned row;

        fputc('\n', f);
        for (row = 0; row < n_rows; ++row) {
                unsigned col;

                fputs(" [", f);
                for (col = 0; col < n_cols; ++col)
                        fprintf(f, "%*d", width, C[row][col]);
                fputs("]\n", f);
        }
        fputc('\n', f);
}
75
76 void hungarian_print_cost_matrix(hungarian_problem_t *p, int width)
77 {
78         hungarian_dump_f(stderr, p->cost, p->num_rows, p->num_cols, width);
79 }
80
81 hungarian_problem_t *hungarian_new(unsigned n_rows, unsigned n_cols,
82                                    match_type_t match_type)
83 {
84         unsigned r;
85         hungarian_problem_t *p = XMALLOCZ(hungarian_problem_t);
86
87         FIRM_DBG_REGISTER(p->dbg, "firm.hungarian");
88
89         /*
90                 Is the number of cols  not equal to number of rows ?
91                 If yes, expand with 0 - cols / 0 - cols
92         */
93         n_rows = MAX(n_cols, n_rows);
94         n_cols = n_rows;
95
96         obstack_init(&p->obst);
97
98         p->num_rows   = n_rows;
99         p->num_cols   = n_cols;
100         p->match_type = match_type;
101
102         /*
103                 In case of normal matching, we have to keep
104                 track of nodes without edges to kill them in
105                 the assignment later.
106         */
107         if (match_type == HUNGARIAN_MATCH_NORMAL) {
108                 p->missing_left  = bitset_obstack_alloc(&p->obst, n_rows);
109                 p->missing_right = bitset_obstack_alloc(&p->obst, n_cols);
110                 bitset_set_all(p->missing_left);
111                 bitset_set_all(p->missing_right);
112         }
113
114         /* allocate space for cost matrix */
115         p->cost = OALLOCNZ(&p->obst, int*, n_rows);
116         for (r = 0; r < p->num_rows; r++)
117                 p->cost[r] = OALLOCNZ(&p->obst, int, n_cols);
118
119         return p;
120 }
121
122 void hungarian_prepare_cost_matrix(hungarian_problem_t *p,
123                                    hungarian_mode_t mode)
124 {
125         unsigned r, c;
126
127         if (mode == HUNGARIAN_MODE_MAXIMIZE_UTIL) {
128                 for (r = 0; r < p->num_rows; r++) {
129                         for (c = 0; c < p->num_cols; c++) {
130                                 p->cost[r][c] = p->max_cost - p->cost[r][c];
131                         }
132                 }
133         } else if (mode == HUNGARIAN_MODE_MINIMIZE_COST) {
134                 /* nothing to do */
135         } else {
136                 panic("Unknown hungarian problem mode\n");
137         }
138 }
139
140 void hungarian_add(hungarian_problem_t *p, unsigned left, unsigned right,
141                    int cost)
142 {
143         assert(p->num_rows > left  && "Invalid row selected.");
144         assert(p->num_cols > right && "Invalid column selected.");
145         assert(cost >= 0);
146
147         p->cost[left][right] = cost;
148         p->max_cost          = MAX(p->max_cost, cost);
149
150         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
151                 bitset_clear(p->missing_left, left);
152                 bitset_clear(p->missing_right, right);
153         }
154 }
155
156 void hungarian_remove(hungarian_problem_t *p, unsigned left, unsigned right)
157 {
158         assert(p->num_rows > left  && "Invalid row selected.");
159         assert(p->num_cols > right && "Invalid column selected.");
160
161         /* Set cost[left][right] to 0. */
162         p->cost[left][right] = 0;
163
164         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
165                 bitset_set(p->missing_left, left);
166                 bitset_set(p->missing_right, right);
167         }
168 }
169
170 void hungarian_free(hungarian_problem_t* p)
171 {
172         obstack_free(&p->obst, NULL);
173         xfree(p);
174 }
175
176 int hungarian_solve(hungarian_problem_t* p, unsigned *assignment,
177                     int *final_cost, int cost_threshold)
178 {
179         int       cost          = 0;
180         unsigned  num_rows      = p->num_rows;
181         unsigned  num_cols      = p->num_cols;
182         unsigned *col_mate      = XMALLOCNZ(unsigned, num_rows);
183         unsigned *row_mate      = XMALLOCNZ(unsigned, num_cols);
184         unsigned *parent_row    = XMALLOCNZ(unsigned, num_cols);
185         unsigned *unchosen_row  = XMALLOCNZ(unsigned, num_rows);
186         int      *row_dec       = XMALLOCNZ(int, num_rows);
187         int      *col_inc       = XMALLOCNZ(int, num_cols);
188         int      *slack         = XMALLOCNZ(int, num_cols);
189         unsigned *slack_row     = XMALLOCNZ(unsigned, num_rows);
190         unsigned  r;
191         unsigned  c;
192         unsigned  t;
193         unsigned  unmatched;
194
195         memset(assignment, -1, num_rows * sizeof(assignment[0]));
196
197         /* Begin subtract column minima in order to start with lots of zeros 12 */
198         DBG((p->dbg, LEVEL_1, "Using heuristic\n"));
199
200         for (c = 0; c < num_cols; ++c) {
201                 int s = p->cost[0][c];
202
203                 for (r = 1; r < num_rows; ++r) {
204                         if (p->cost[r][c] < s)
205                                 s = p->cost[r][c];
206                 }
207
208                 cost += s;
209                 if (s == 0)
210                         continue;
211
212                 for (r = 0; r < num_rows; ++r)
213                         p->cost[r][c] -= s;
214         }
215         /* End subtract column minima in order to start with lots of zeros 12 */
216
217         /* Begin initial state 16 */
218         t = 0;
219         for (c = 0; c < num_cols; ++c) {
220                 row_mate[c]   = (unsigned) -1;
221                 parent_row[c] = (unsigned) -1;
222                 col_inc[c]    = 0;
223                 slack[c]      = INT_MAX;
224         }
225
226         for (r = 0; r < num_rows; ++r) {
227                 int s = p->cost[r][0];
228
229                 for (c = 1; c < num_cols; ++c) {
230                         if (p->cost[r][c] < s)
231                                 s = p->cost[r][c];
232                 }
233
234                 row_dec[r] = s;
235
236                 for (c = 0; c < num_cols; ++c) {
237                         if (s == p->cost[r][c] && row_mate[c] == (unsigned)-1) {
238                                 col_mate[r] = c;
239                                 row_mate[c] = r;
240                                 DBG((p->dbg, LEVEL_1, "matching col %d == row %d\n", c, r));
241                                 goto row_done;
242                         }
243                 }
244
245                 col_mate[r] = (unsigned)-1;
246                 DBG((p->dbg, LEVEL_1, "node %d: unmatched row %d\n", t, r));
247                 unchosen_row[t++] = r;
248 row_done: ;
249         }
250         /* End initial state 16 */
251
252         /* Begin Hungarian algorithm 18 */
253         if (t == 0)
254                 goto done;
255
256         unmatched = t;
257         for (;;) {
258                 unsigned q = 0;
259                 unsigned j;
260                 DBG((p->dbg, LEVEL_1, "Matched %d rows.\n", num_rows - t));
261
262                 for (;;) {
263                         int s;
264                         while (q < t) {
265                                 /* Begin explore node q of the forest 19 */
266                                 r = unchosen_row[q];
267                                 s = row_dec[r];
268
269                                 for (c = 0; c < num_cols; ++c) {
270                                         if (slack[c]) {
271                                                 int del = p->cost[r][c] - s + col_inc[c];
272
273                                                 if (del < slack[c]) {
274                                                         if (del == 0) {
275                                                                 if (row_mate[c] == (unsigned)-1)
276                                                                         goto breakthru;
277
278                                                                 slack[c]      = 0;
279                                                                 parent_row[c] = r;
280                                                                 DBG((p->dbg, LEVEL_1, "node %d: row %d == col %d -- row %d\n", t, row_mate[c], c, r));
281                                                                 unchosen_row[t++] = row_mate[c];
282                                                         } else {
283                                                                 slack[c]     = del;
284                                                                 slack_row[c] = r;
285                                                         }
286                                                 }
287                                         }
288                                 }
289                                 /* End explore node q of the forest 19 */
290                                 q++;
291                         }
292
293                         /* Begin introduce a new zero into the matrix 21 */
294                         s = INT_MAX;
295                         for (c = 0; c < num_cols; ++c) {
296                                 if (slack[c] && slack[c] < s)
297                                         s = slack[c];
298                         }
299
300                         for (q = 0; q < t; ++q)
301                                 row_dec[unchosen_row[q]] += s;
302
303                         for (c = 0; c < num_cols; ++c) {
304                                 if (slack[c]) {
305                                         slack[c] -= s;
306                                         if (slack[c] == 0) {
307                                                 /* Begin look at a new zero 22 */
308                                                 r = slack_row[c];
309                                                 DBG((p->dbg, LEVEL_1, "Decreasing uncovered elements by %d produces zero at [%d, %d]\n", s, r, c));
310                                                 if (row_mate[c] == (unsigned)-1) {
311                                                         for (j = c + 1; j < num_cols; ++j) {
312                                                                 if (slack[j] == 0)
313                                                                         col_inc[j] += s;
314                                                         }
315                                                         goto breakthru;
316                                                 } else {
317                                                         parent_row[c] = r;
318                                                         DBG((p->dbg, LEVEL_1, "node %d: row %d == col %d -- row %d\n", t, row_mate[c], c, r));
319                                                         unchosen_row[t++] = row_mate[c];
320                                                 }
321                                                 /* End look at a new zero 22 */
322                                         }
323                                 } else {
324                                         col_inc[c] += s;
325                                 }
326                         }
327                         /* End introduce a new zero into the matrix 21 */
328                 }
329 breakthru:
330                 /* Begin update the matching 20 */
331                 DBG((p->dbg, LEVEL_1, "Breakthrough at node %d of %d.\n", q, t));
332                 for (;;) {
333                         j           = col_mate[r];
334                         col_mate[r] = c;
335                         row_mate[c] = r;
336
337                         DBG((p->dbg, LEVEL_1, "rematching col %d == row %d\n", c, r));
338                         if (j == (unsigned)-1)
339                                 break;
340
341                         r = parent_row[j];
342                         c = j;
343                 }
344                 /* End update the matching 20 */
345
346                 if (--unmatched == 0)
347                         goto done;
348
349                 /* Begin get ready for another stage 17 */
350                 t = 0;
351                 for (c = 0; c < num_cols; ++c) {
352                         parent_row[c] = -1;
353                         slack[c]      = INT_MAX;
354                 }
355
356                 for (r = 0; r < num_rows; ++r) {
357                         if (col_mate[r] == (unsigned)-1) {
358                                 DBG((p->dbg, LEVEL_1, "node %d: unmatched row %d\n", t, r));
359                                 unchosen_row[t++] = r;
360                         }
361                 }
362                 /* End get ready for another stage 17 */
363         }
364 done:
365
366         /* Begin double check the solution 23 */
367         for (r = 0; r < num_rows; ++r) {
368                 for (c = 0; c < num_cols; ++c) {
369                         if (p->cost[r][c] < row_dec[r] - col_inc[c])
370                                 return -1;
371                 }
372         }
373
374         for (r = 0; r < num_rows; ++r) {
375                 c = col_mate[r];
376                 if (c == (unsigned)-1 || p->cost[r][c] != row_dec[r] - col_inc[c])
377                         return -2;
378         }
379
380         for (r = c = 0; c < num_cols; ++c) {
381                 if (col_inc[c])
382                         r++;
383         }
384
385         if (r > num_rows)
386                 return -3;
387         /* End double check the solution 23 */
388
389         /* End Hungarian algorithm 18 */
390
391         /* collect the assigned values */
392         for (r = 0; r < num_rows; ++r) {
393                 if (cost_threshold > 0 && p->cost[r][col_mate[r]] >= cost_threshold)
394                         assignment[r] = -1; /* remove matching having cost > threshold */
395                 else
396                         assignment[r] = col_mate[r];
397         }
398
399         /* In case of normal matching: remove impossible ones */
400         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
401                 for (r = 0; r < num_rows; ++r) {
402                         if (bitset_is_set(p->missing_left, r)
403                                 || bitset_is_set(p->missing_right, col_mate[r]))
404                                 assignment[r] = -1;
405                 }
406         }
407
408         for (r = 0; r < num_rows; ++r) {
409                 for (c = 0; c < num_cols; ++c) {
410                         p->cost[r][c] = p->cost[r][c] - row_dec[r] + col_inc[c];
411                 }
412         }
413
414         for (r = 0; r < num_rows; ++r)
415                 cost += row_dec[r];
416
417         for (c = 0; c < num_cols; ++c)
418                 cost -= col_inc[c];
419
420         DBG((p->dbg, LEVEL_1, "Cost is %d\n", cost));
421
422         xfree(slack);
423         xfree(col_inc);
424         xfree(parent_row);
425         xfree(row_mate);
426         xfree(slack_row);
427         xfree(row_dec);
428         xfree(unchosen_row);
429         xfree(col_mate);
430
431         if (final_cost != NULL)
432                 *final_cost = cost;
433
434         return 0;
435 }