cleanup, simplify hungarian algorithm implementation
[libfirm] / ir / adt / hungarian.c
1 /********************************************************************
2  ********************************************************************
3  **
4  ** libhungarian by Cyrill Stachniss, 2004
5  **
6  ** Added and adapted to libFirm by Christian Wuerdig, 2006
7  **
8  ** Solving the Minimum Assignment Problem using the
9  ** Hungarian Method.
10  **
11  ** ** This file may be freely copied and distributed! **
12  **
 ** Parts of the code used here were originally provided by the
 ** "Stanford GraphBase", but I made changes to this code.
 ** As asked by the copyright notice of the "Stanford GraphBase",
 ** I hereby proclaim that this file is *NOT* part of the
 ** "Stanford GraphBase" distribution!
18  **
19  ** This file is distributed in the hope that it will be useful,
20  ** but WITHOUT ANY WARRANTY; without even the implied
21  ** warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
22  ** PURPOSE.
23  **
24  ********************************************************************
25  ********************************************************************/
26
27 /**
28  * @file
29  * @brief   Solving the Minimum Assignment Problem using the Hungarian Method.
30  * @version $Id$
31  */
#include "config.h"

#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "irtools.h"
#include "xmalloc.h"
#include "debug.h"
#include "bitset.h"
#include "error.h"

#include "hungarian.h"
45
/** Debug module of this file; registered as "firm.hungarian" in hungarian_new(). */
DEBUG_ONLY(static firm_dbg_module_t *dbg);

/**
 * One assignment problem.
 *
 * hungarian_new() always makes the matrix square (num_rows == num_cols),
 * padding the smaller side with zero-cost entries.  The matrix is stored
 * row-major: the entry for (row, col) is cost[row*num_cols + col].
 */
struct hungarian_problem_t {
        unsigned      num_rows;      /**< number of rows (padded, see above) */
        unsigned      num_cols;      /**< number of columns (padded, see above) */
        unsigned     *cost;          /**< the cost matrix (row-major) */
        unsigned      max_cost;      /**< the maximal cost ever passed to
                                          hungarian_add(); never lowered by
                                          hungarian_remove() */
        match_type_t  match_type;    /**< PERFECT or NORMAL matching */
        unsigned     *missing_left;  /**< bitset: left side nodes having no edge to
                                          the right side (only allocated for
                                          HUNGARIAN_MATCH_NORMAL) */
        unsigned     *missing_right; /**< bitset: right side nodes having no edge to
                                          the left side (only allocated for
                                          HUNGARIAN_MATCH_NORMAL) */
};
59
/**
 * Prints a row-major num_rows x num_cols matrix to @p f.
 *
 * Every row is written as " [" followed by the entries (each right-aligned
 * in @p width characters) and "]".  The matrix is surrounded by empty lines.
 */
static void hungarian_dump_f(FILE *f, const unsigned *cost,
                             unsigned num_rows, unsigned num_cols, int width)
{
	const unsigned *entry = cost;  /* walks the matrix in storage order */
	unsigned        row;

	fprintf(f, "\n");
	for (row = 0; row < num_rows; ++row) {
		unsigned col;

		fprintf(f, " [");
		for (col = 0; col < num_cols; ++col, ++entry)
			fprintf(f, "%*u", width, *entry);
		fprintf(f, "]\n");
	}
	fprintf(f, "\n");
}
75
76 void hungarian_print_cost_matrix(hungarian_problem_t *p, int width)
77 {
78         hungarian_dump_f(stderr, p->cost, p->num_rows, p->num_cols, width);
79 }
80
81 hungarian_problem_t *hungarian_new(unsigned num_rows, unsigned num_cols,
82                                    match_type_t match_type)
83 {
84         hungarian_problem_t *p = XMALLOCZ(hungarian_problem_t);
85
86         FIRM_DBG_REGISTER(dbg, "firm.hungarian");
87
88         /*
89                 Is the number of cols  not equal to number of rows ?
90                 If yes, expand with 0 - cols / 0 - cols
91         */
92         num_rows = MAX(num_cols, num_rows);
93         num_cols = num_rows;
94
95         p->num_rows   = num_rows;
96         p->num_cols   = num_cols;
97         p->match_type = match_type;
98
99         /*
100                 In case of normal matching, we have to keep
101                 track of nodes without edges to kill them in
102                 the assignment later.
103         */
104         if (match_type == HUNGARIAN_MATCH_NORMAL) {
105                 p->missing_left  = rbitset_malloc(num_rows);
106                 p->missing_right = rbitset_malloc(num_cols);
107                 rbitset_set_all(p->missing_left,  num_rows);
108                 rbitset_set_all(p->missing_right, num_cols);
109         }
110
111         /* allocate space for cost matrix */
112         p->cost = XMALLOCNZ(unsigned, num_rows * num_cols);
113         return p;
114 }
115
116 void hungarian_prepare_cost_matrix(hungarian_problem_t *p,
117                                    hungarian_mode_t mode)
118 {
119         if (mode == HUNGARIAN_MODE_MAXIMIZE_UTIL) {
120                 unsigned  r, c;
121                 unsigned  num_cols = p->num_cols;
122                 unsigned *cost     = p->cost;
123                 unsigned  max_cost = p->max_cost;
124                 for (r = 0; r < p->num_rows; r++) {
125                         for (c = 0; c < p->num_cols; c++) {
126                                 cost[r*num_cols + c] = max_cost - cost[r*num_cols + c];
127                         }
128                 }
129         } else if (mode == HUNGARIAN_MODE_MINIMIZE_COST) {
130                 /* nothing to do */
131         } else {
132                 panic("Unknown hungarian problem mode\n");
133         }
134 }
135
136 void hungarian_add(hungarian_problem_t *p, unsigned left, unsigned right,
137                    unsigned cost)
138 {
139         assert(p->num_rows > left  && "Invalid row selected.");
140         assert(p->num_cols > right && "Invalid column selected.");
141
142         p->cost[left*p->num_cols + right] = cost;
143         p->max_cost                       = MAX(p->max_cost, cost);
144
145         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
146                 rbitset_clear(p->missing_left, left);
147                 rbitset_clear(p->missing_right, right);
148         }
149 }
150
151 void hungarian_remove(hungarian_problem_t *p, unsigned left, unsigned right)
152 {
153         assert(p->num_rows > left  && "Invalid row selected.");
154         assert(p->num_cols > right && "Invalid column selected.");
155
156         p->cost[left*p->num_cols + right] = 0;
157
158         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
159                 rbitset_set(p->missing_left, left);
160                 rbitset_set(p->missing_right, right);
161         }
162 }
163
164 void hungarian_free(hungarian_problem_t* p)
165 {
166         xfree(p->missing_left);
167         xfree(p->missing_right);
168         xfree(p->cost);
169         xfree(p);
170 }
171
172 int hungarian_solve(hungarian_problem_t* p, unsigned *assignment,
173                     unsigned *final_cost, unsigned cost_threshold)
174 {
175         unsigned  res_cost     = 0;
176         unsigned  num_rows     = p->num_rows;
177         unsigned  num_cols     = p->num_cols;
178         unsigned *cost         = p->cost;
179         unsigned *col_mate     = XMALLOCNZ(unsigned, num_rows);
180         unsigned *row_mate     = XMALLOCNZ(unsigned, num_cols);
181         unsigned *parent_row   = XMALLOCNZ(unsigned, num_cols);
182         unsigned *unchosen_row = XMALLOCNZ(unsigned, num_rows);
183         int      *row_dec      = XMALLOCNZ(int, num_rows);
184         int      *col_inc      = XMALLOCNZ(int, num_cols);
185         int      *slack        = XMALLOCNZ(int, num_cols);
186         unsigned *slack_row    = XMALLOCNZ(unsigned, num_rows);
187         unsigned  r;
188         unsigned  c;
189         unsigned  t;
190         unsigned  unmatched;
191
192         memset(assignment, -1, num_rows * sizeof(assignment[0]));
193
194         /* Begin subtract column minima in order to start with lots of zeros 12 */
195         DBG((dbg, LEVEL_1, "Using heuristic\n"));
196
197         for (c = 0; c < num_cols; ++c) {
198                 unsigned col_mininum = cost[0*num_cols + c];
199
200                 for (r = 1; r < num_rows; ++r) {
201                         if (cost[r*num_cols + c] < col_mininum)
202                                 col_mininum = cost[r*num_cols + c];
203                 }
204
205                 if (col_mininum == 0)
206                         continue;
207
208                 res_cost += col_mininum;
209                 for (r = 0; r < num_rows; ++r)
210                         cost[r*num_cols + c] -= col_mininum;
211         }
212         /* End subtract column minima in order to start with lots of zeros 12 */
213
214         /* Begin initial state 16 */
215         unmatched = 0;
216         for (c = 0; c < num_cols; ++c) {
217                 row_mate[c]   = (unsigned) -1;
218                 parent_row[c] = (unsigned) -1;
219                 col_inc[c]    = 0;
220                 slack[c]      = INT_MAX;
221         }
222
223         for (r = 0; r < num_rows; ++r) {
224                 unsigned row_minimum = cost[r*num_cols + 0];
225
226                 for (c = 1; c < num_cols; ++c) {
227                         if (cost[r*num_cols + c] < row_minimum)
228                                 row_minimum = cost[r*num_cols + c];
229                 }
230
231                 row_dec[r] = row_minimum;
232
233                 for (c = 0; c < num_cols; ++c) {
234                         if (cost[r*num_cols + c] != row_minimum)
235                                 continue;
236                         if (row_mate[c] != (unsigned)-1)
237                                 continue;
238
239                         col_mate[r] = c;
240                         row_mate[c] = r;
241                         DBG((dbg, LEVEL_1, "matching col %u == row %u\n", c, r));
242                         goto row_done;
243                 }
244
245                 col_mate[r] = (unsigned)-1;
246                 DBG((dbg, LEVEL_1, "node %u: unmatched row %u\n", unmatched, r));
247                 unchosen_row[unmatched++] = r;
248 row_done: ;
249         }
250         /* End initial state 16 */
251
252         /* Begin Hungarian algorithm 18 */
253         if (unmatched == 0)
254                 goto done;
255
256         t = unmatched;
257         for (;;) {
258                 unsigned q = 0;
259                 unsigned j;
260                 DBG((dbg, LEVEL_1, "Matched %u rows.\n", num_rows - t));
261
262                 for (;;) {
263                         int s;
264                         while (q < t) {
265                                 /* Begin explore node q of the forest 19 */
266                                 r = unchosen_row[q];
267                                 s = row_dec[r];
268
269                                 for (c = 0; c < num_cols; ++c) {
270                                         if (slack[c]) {
271                                                 int del = cost[r*num_cols + c] - s + col_inc[c];
272
273                                                 if (del < slack[c]) {
274                                                         if (del == 0) {
275                                                                 if (row_mate[c] == (unsigned)-1)
276                                                                         goto breakthru;
277
278                                                                 slack[c]      = 0;
279                                                                 parent_row[c] = r;
280                                                                 DBG((dbg, LEVEL_1, "node %u: row %u == col %u -- row %u\n", t, row_mate[c], c, r));
281                                                                 unchosen_row[t++] = row_mate[c];
282                                                         } else {
283                                                                 slack[c]     = del;
284                                                                 slack_row[c] = r;
285                                                         }
286                                                 }
287                                         }
288                                 }
289                                 /* End explore node q of the forest 19 */
290                                 q++;
291                         }
292
293                         /* Begin introduce a new zero into the matrix 21 */
294                         s = INT_MAX;
295                         for (c = 0; c < num_cols; ++c) {
296                                 if (slack[c] && slack[c] < s)
297                                         s = slack[c];
298                         }
299
300                         for (q = 0; q < t; ++q)
301                                 row_dec[unchosen_row[q]] += s;
302
303                         for (c = 0; c < num_cols; ++c) {
304                                 if (slack[c]) {
305                                         slack[c] -= s;
306                                         if (slack[c] == 0) {
307                                                 /* Begin look at a new zero 22 */
308                                                 r = slack_row[c];
309                                                 DBG((dbg, LEVEL_1, "Decreasing uncovered elements by %d produces zero at [%u, %u]\n", s, r, c));
310                                                 if (row_mate[c] == (unsigned)-1) {
311                                                         for (j = c + 1; j < num_cols; ++j) {
312                                                                 if (slack[j] == 0)
313                                                                         col_inc[j] += s;
314                                                         }
315                                                         goto breakthru;
316                                                 } else {
317                                                         parent_row[c] = r;
318                                                         DBG((dbg, LEVEL_1, "node %u: row %u == col %u -- row %u\n", t, row_mate[c], c, r));
319                                                         unchosen_row[t++] = row_mate[c];
320                                                 }
321                                                 /* End look at a new zero 22 */
322                                         }
323                                 } else {
324                                         col_inc[c] += s;
325                                 }
326                         }
327                         /* End introduce a new zero into the matrix 21 */
328                 }
329 breakthru:
330                 /* Begin update the matching 20 */
331                 DBG((dbg, LEVEL_1, "Breakthrough at node %u of %u.\n", q, t));
332                 for (;;) {
333                         j           = col_mate[r];
334                         col_mate[r] = c;
335                         row_mate[c] = r;
336
337                         DBG((dbg, LEVEL_1, "rematching col %u == row %u\n", c, r));
338                         if (j == (unsigned)-1)
339                                 break;
340
341                         r = parent_row[j];
342                         c = j;
343                 }
344                 /* End update the matching 20 */
345
346                 if (--unmatched == 0)
347                         goto done;
348
349                 /* Begin get ready for another stage 17 */
350                 t = 0;
351                 for (c = 0; c < num_cols; ++c) {
352                         parent_row[c] = (unsigned) -1;
353                         slack[c]      = INT_MAX;
354                 }
355
356                 for (r = 0; r < num_rows; ++r) {
357                         if (col_mate[r] == (unsigned)-1) {
358                                 DBG((dbg, LEVEL_1, "node %u: unmatched row %u\n", t, r));
359                                 unchosen_row[t++] = r;
360                         }
361                 }
362                 /* End get ready for another stage 17 */
363         }
364 done:
365
366         /* Begin double check the solution 23 */
367         for (r = 0; r < num_rows; ++r) {
368                 for (c = 0; c < num_cols; ++c) {
369                         if ((int) cost[r*num_cols + c] < row_dec[r] - col_inc[c])
370                                 return -1;
371                 }
372         }
373
374         for (r = 0; r < num_rows; ++r) {
375                 c = col_mate[r];
376                 if (c == (unsigned)-1
377                     || cost[r*num_cols + c] != (unsigned) (row_dec[r] - col_inc[c]))
378                         return -2;
379         }
380
381         for (r = c = 0; c < num_cols; ++c) {
382                 if (col_inc[c])
383                         r++;
384         }
385
386         if (r > num_rows)
387                 return -3;
388         /* End double check the solution 23 */
389
390         /* End Hungarian algorithm 18 */
391
392         /* collect the assigned values */
393         for (r = 0; r < num_rows; ++r) {
394                 if (cost_threshold > 0
395                     && cost[r*num_cols + col_mate[r]] >= cost_threshold)
396                         assignment[r] = -1; /* remove matching having cost > threshold */
397                 else
398                         assignment[r] = col_mate[r];
399         }
400
401         /* In case of normal matching: remove impossible ones */
402         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
403                 for (r = 0; r < num_rows; ++r) {
404                         if (rbitset_is_set(p->missing_left, r)
405                                 || rbitset_is_set(p->missing_right, col_mate[r]))
406                                 assignment[r] = -1;
407                 }
408         }
409
410         for (r = 0; r < num_rows; ++r) {
411                 for (c = 0; c < num_cols; ++c) {
412                         cost[r*num_cols + c] = cost[r*num_cols + c] - row_dec[r] + col_inc[c];
413                 }
414         }
415
416         for (r = 0; r < num_rows; ++r)
417                 res_cost += row_dec[r];
418
419         for (c = 0; c < num_cols; ++c)
420                 res_cost -= col_inc[c];
421
422         DBG((dbg, LEVEL_1, "Cost is %d\n", res_cost));
423
424         xfree(slack);
425         xfree(col_inc);
426         xfree(parent_row);
427         xfree(row_mate);
428         xfree(slack_row);
429         xfree(row_dec);
430         xfree(unchosen_row);
431         xfree(col_mate);
432
433         if (final_cost != NULL)
434                 *final_cost = res_cost;
435
436         return 0;
437 }