943567edc14cdc274f34422b29b336e9b398018f
[libfirm] / ir / adt / hungarian.c
1 /********************************************************************
2  ********************************************************************
3  **
4  ** libhungarian by Cyrill Stachniss, 2004
5  **
6  ** Added and adapted to libFirm by Christian Wuerdig, 2006
7  **
8  ** Solving the Minimum Assignment Problem using the
9  ** Hungarian Method.
10  **
11  ** ** This file may be freely copied and distributed! **
12  **
** Parts of the code used here were originally provided by the
** "Stanford GraphBase", but I made changes to this code.
** As requested by the copyright notice of the "Stanford GraphBase",
** I hereby proclaim that this file is *NOT* part of the
** "Stanford GraphBase" distribution!
18  **
19  ** This file is distributed in the hope that it will be useful,
20  ** but WITHOUT ANY WARRANTY; without even the implied
21  ** warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
22  ** PURPOSE.
23  **
24  ********************************************************************
25  ********************************************************************/
26
27 /* $Id$ */
28
#ifdef HAVE_CONFIG_H
# include "config.h"
#endif

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "irtools.h"
#include "xmalloc.h"
#include "debug.h"
#include "obst.h"
#include "bitset.h"

#include "hungarian.h"
44
/** "Infinite" slack: initial value of every slack entry and start value of the
 *  minimum search in hungarian_solve(). 0x7FFFFFFF is the largest 32-bit signed int. */
#define INF (0x7FFFFFFF)
46
/**
 * Internal state of one assignment problem. Created by hungarian_new() and
 * released by hungarian_free(); the cost matrix and the bitsets are allocated
 * on the embedded obstack, the struct itself with xmalloc().
 */
struct _hungarian_problem_t {
	int      num_rows;          /**< number of rows (hungarian_new() makes this equal to num_cols) */
	int      num_cols;          /**< number of columns */
	int      **cost;            /**< the cost matrix, indexed cost[row][col] */
	int      width;             /**< the field width used by the cost matrix dumper */
	int      max_cost;          /**< the largest cost passed to hungarian_add() so far (starts at 0) */
	int      match_type;        /**< PERFECT or NORMAL matching */
	bitset_t *missing_left;     /**< left side nodes having no edge to the right side (NORMAL matching only, otherwise NULL) */
	bitset_t *missing_right;    /**< right side nodes having no edge to the left side (NORMAL matching only, otherwise NULL) */
	struct obstack obst;        /**< backing storage for the cost matrix and the bitsets */
	DEBUG_ONLY(firm_dbg_module_t *dbg);  /**< debug module, registered as "firm.hungarian" */
};
59
60 static INLINE void *get_init_mem(struct obstack *obst, long sz) {
61         void *p = obstack_alloc(obst, sz);
62         memset(p, 0, sz);
63         return p;
64 }
65
/**
 * Print a rows x cols integer matrix to @p f: a leading and a trailing empty
 * line, and each row as " [" + entries right-aligned in @p width chars + "]".
 */
static void hungarian_dump_f(FILE *f, int **C, int rows, int cols, int width) {
	int r;

	fputs("\n", f);
	for (r = 0; r < rows; ++r) {
		const int *row = C[r];
		int c;

		fputs(" [", f);
		for (c = 0; c < cols; ++c)
			fprintf(f, "%*d", width, row[c]);
		fputs("]\n", f);
	}
	fputs("\n", f);
}
79
80 void hungarian_print_costmatrix(hungarian_problem_t *p) {
81         hungarian_dump_f(stderr, p->cost, p->num_rows, p->num_cols, p->width);
82 }
83
84 /**
85  * Create the object and allocate memory for the data structures.
86  */
87 hungarian_problem_t *hungarian_new(int rows, int cols, int width, int match_type) {
88         int i;
89         hungarian_problem_t *p = xmalloc(sizeof(*p));
90
91         memset(p, 0, sizeof(p[0]));
92
93         FIRM_DBG_REGISTER(p->dbg, "firm.hungarian");
94
95         /*
96                 Is the number of cols  not equal to number of rows ?
97                 If yes, expand with 0 - cols / 0 - cols
98         */
99         rows = MAX(cols, rows);
100         cols = rows;
101
102         obstack_init(&p->obst);
103
104         p->num_rows   = rows;
105         p->num_cols   = cols;
106         p->width      = width;
107         p->match_type = match_type;
108
109         /*
110                 In case of normal matching, we have to keep
111                 track of nodes without edges to kill them in
112                 the assignment later.
113         */
114         if (match_type == HUNGARIAN_MATCH_NORMAL) {
115                 p->missing_left  = bitset_obstack_alloc(&p->obst, rows);
116                 p->missing_right = bitset_obstack_alloc(&p->obst, cols);
117                 bitset_set_all(p->missing_left);
118                 bitset_set_all(p->missing_right);
119         }
120
121         /* allocate space for cost matrix */
122         p->cost = (int **)get_init_mem(&p->obst, rows * sizeof(p->cost[0]));
123         for (i = 0; i < p->num_rows; i++)
124                 p->cost[i] = (int *)get_init_mem(&p->obst, cols * sizeof(p->cost[0][0]));
125
126         return p;
127 }
128
129 /**
130  * Prepare the cost matrix.
131  */
132 void hungarian_prepare_cost_matrix(hungarian_problem_t *p, int mode) {
133         int i, j;
134
135         if (mode == HUNGARIAN_MODE_MAXIMIZE_UTIL) {
136                 for (i = 0; i < p->num_rows; i++) {
137                         for (j = 0; j < p->num_cols; j++) {
138                                 p->cost[i][j] = p->max_cost - p->cost[i][j];
139                         }
140                 }
141         }
142         else if (mode == HUNGARIAN_MODE_MINIMIZE_COST) {
143                 /* nothing to do */
144         }
145         else
146                 fprintf(stderr, "Unknown mode. Mode was set to HUNGARIAN_MODE_MINIMIZE_COST.\n");
147 }
148
149 /**
150  * Set cost[left][right] to cost.
151  */
152 void hungarian_add(hungarian_problem_t *p, int left, int right, int cost) {
153         assert(p->num_rows > left  && "Invalid row selected.");
154         assert(p->num_cols > right && "Invalid column selected.");
155         assert(cost >= 0);
156
157         p->cost[left][right] = cost;
158         p->max_cost          = MAX(p->max_cost, cost);
159
160         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
161                 bitset_clear(p->missing_left, left);
162                 bitset_clear(p->missing_right, right);
163         }
164 }
165
166 /**
167  * Set cost[left][right] to 0.
168  */
169 void hungarian_remv(hungarian_problem_t *p, int left, int right) {
170         assert(p->num_rows > left  && "Invalid row selected.");
171         assert(p->num_cols > right && "Invalid column selected.");
172
173         p->cost[left][right] = 0;
174
175         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
176                 bitset_set(p->missing_left, left);
177                 bitset_set(p->missing_right, right);
178         }
179 }
180
181 /**
182  * Frees all allocated memory.
183  */
184 void hungarian_free(hungarian_problem_t* p) {
185         obstack_free(&p->obst, NULL);
186         xfree(p);
187 }
188
189 /**
190  * Do the assignment.
191  */
192 int hungarian_solve(hungarian_problem_t* p, int *assignment, int *final_cost, int cost_threshold) {
193         int i, j, m, n, k, l, s, t, q, unmatched, cost;
194         int *col_mate;
195         int *row_mate;
196         int *parent_row;
197         int *unchosen_row;
198         int *row_dec;
199         int *col_inc;
200         int *slack;
201         int *slack_row;
202
203         cost = 0;
204         m    = p->num_rows;
205         n    = p->num_cols;
206
207         col_mate     = xcalloc(p->num_rows, sizeof(col_mate[0]));
208         unchosen_row = xcalloc(p->num_rows, sizeof(unchosen_row[0]));
209         row_dec      = xcalloc(p->num_rows, sizeof(row_dec[0]));
210         slack_row    = xcalloc(p->num_rows, sizeof(slack_row[0]));
211
212         row_mate     = xcalloc(p->num_cols, sizeof(row_mate[0]));
213         parent_row   = xcalloc(p->num_cols, sizeof(parent_row[0]));
214         col_inc      = xcalloc(p->num_cols, sizeof(col_inc[0]));
215         slack        = xcalloc(p->num_cols, sizeof(slack[0]));
216
217         memset(assignment, -1, m * sizeof(assignment[0]));
218
219         /* Begin subtract column minima in order to start with lots of zeros 12 */
220         DBG((p->dbg, LEVEL_1, "Using heuristic\n"));
221
222         for (l = 0; l < n; ++l) {
223                 s = p->cost[0][l];
224
225                 for (k = 1; k < m; ++k) {
226                         if (p->cost[k][l] < s)
227                                 s = p->cost[k][l];
228                 }
229
230                 cost += s;
231
232                 if (s != 0) {
233                         for (k = 0; k < m; ++k)
234                                 p->cost[k][l] -= s;
235                 }
236         }
237         /* End subtract column minima in order to start with lots of zeros 12 */
238
239         /* Begin initial state 16 */
240         t = 0;
241         for (l = 0; l < n; ++l) {
242                 row_mate[l]   = -1;
243                 parent_row[l] = -1;
244                 col_inc[l]    = 0;
245                 slack[l]      = INF;
246         }
247
248         for (k = 0; k < m; ++k) {
249                 s = p->cost[k][0];
250
251                 for (l = 1; l < n; ++l) {
252                         if (p->cost[k][l] < s)
253                                 s = p->cost[k][l];
254                 }
255
256                 row_dec[k] = s;
257
258                 for (l = 0; l < n; ++l) {
259                         if (s == p->cost[k][l] && row_mate[l] < 0) {
260                                 col_mate[k] = l;
261                                 row_mate[l] = k;
262                                 DBG((p->dbg, LEVEL_1, "matching col %d == row %d\n", l, k));
263                                 goto row_done;
264                         }
265                 }
266
267                 col_mate[k] = -1;
268                 DBG((p->dbg, LEVEL_1, "node %d: unmatched row %d\n", t, k));
269                 unchosen_row[t++] = k;
270 row_done: ;
271         }
272         /* End initial state 16 */
273
274         /* Begin Hungarian algorithm 18 */
275         if (t == 0)
276                 goto done;
277
278         unmatched = t;
279         while (1) {
280                 DBG((p->dbg, LEVEL_1, "Matched %d rows.\n", m - t));
281                 q = 0;
282
283                 while (1) {
284                         while (q < t) {
285                                 /* Begin explore node q of the forest 19 */
286                                 k = unchosen_row[q];
287                                 s = row_dec[k];
288
289                                 for (l = 0; l < n; ++l) {
290                                         if (slack[l]) {
291                                                 int del = p->cost[k][l] - s + col_inc[l];
292
293                                                 if (del < slack[l]) {
294                                                         if (del == 0) {
295                                                                 if (row_mate[l] < 0)
296                                                                         goto breakthru;
297
298                                                                 slack[l]      = 0;
299                                                                 parent_row[l] = k;
300                                                                 DBG((p->dbg, LEVEL_1, "node %d: row %d == col %d -- row %d\n", t, row_mate[l], l, k));
301                                                                 unchosen_row[t++] = row_mate[l];
302                                                         }
303                                                         else {
304                                                                 slack[l]     = del;
305                                                                 slack_row[l] = k;
306                                                         }
307                                                 }
308                                         }
309                                 }
310                                 /* End explore node q of the forest 19 */
311                                 q++;
312                         }
313
314                         /* Begin introduce a new zero into the matrix 21 */
315                         s = INF;
316                         for (l = 0; l < n; ++l) {
317                                 if (slack[l] && slack[l] < s)
318                                         s = slack[l];
319                         }
320
321                         for (q = 0; q < t; ++q)
322                                 row_dec[unchosen_row[q]] += s;
323
324                         for (l = 0; l < n; ++l) {
325                                 if (slack[l]) {
326                                         slack[l] -= s;
327                                         if (slack[l] == 0) {
328                                                 /* Begin look at a new zero 22 */
329                                                 k = slack_row[l];
330                                                 DBG((p->dbg, LEVEL_1, "Decreasing uncovered elements by %d produces zero at [%d, %d]\n", s, k, l));
331                                                 if (row_mate[l] < 0) {
332                                                         for (j = l + 1; j < n; ++j) {
333                                                                 if (slack[j] == 0)
334                                                                         col_inc[j] += s;
335                                                         }
336                                                         goto breakthru;
337                                                 }
338                                                 else {
339                                                         parent_row[l] = k;
340                                                         DBG((p->dbg, LEVEL_1, "node %d: row %d == col %d -- row %d\n", t, row_mate[l], l, k));
341                                                         unchosen_row[t++] = row_mate[l];
342                                                 }
343                                                 /* End look at a new zero 22 */
344                                         }
345                                 }
346                                 else {
347                                         col_inc[l] += s;
348                                 }
349                         }
350                         /* End introduce a new zero into the matrix 21 */
351                 }
352 breakthru:
353                 /* Begin update the matching 20 */
354                 DBG((p->dbg, LEVEL_1, "Breakthrough at node %d of %d.\n", q, t));
355                 while (1) {
356                         j           = col_mate[k];
357                         col_mate[k] = l;
358                         row_mate[l] = k;
359
360                         DBG((p->dbg, LEVEL_1, "rematching col %d == row %d\n", l, k));
361                         if (j < 0)
362                                 break;
363
364                         k = parent_row[j];
365                         l = j;
366                 }
367                 /* End update the matching 20 */
368
369                 if (--unmatched == 0)
370                         goto done;
371
372                 /* Begin get ready for another stage 17 */
373                 t = 0;
374                 for (l = 0; l < n; ++l) {
375                         parent_row[l] = -1;
376                         slack[l]      = INF;
377                 }
378
379                 for (k = 0; k < m; ++k) {
380                         if (col_mate[k] < 0) {
381                                 DBG((p->dbg, LEVEL_1, "node %d: unmatched row %d\n", t, k));
382                                 unchosen_row[t++] = k;
383                         }
384                 }
385                 /* End get ready for another stage 17 */
386         }
387 done:
388
389         /* Begin double check the solution 23 */
390         for (k = 0; k < m; ++k) {
391                 for (l = 0; l < n; ++l) {
392                         if (p->cost[k][l] < row_dec[k] - col_inc[l])
393                                 return -1;
394                 }
395         }
396
397         for (k = 0; k < m; ++k) {
398                 l = col_mate[k];
399                 if (l < 0 || p->cost[k][l] != row_dec[k] - col_inc[l])
400                         return -2;
401         }
402
403         for (k = l = 0; l < n; ++l) {
404                 if (col_inc[l])
405                         k++;
406         }
407
408         if (k > m)
409                 return -3;
410         /* End double check the solution 23 */
411
412         /* End Hungarian algorithm 18 */
413
414         /* collect the assigned values */
415         for (i = 0; i < m; ++i) {
416                 if (cost_threshold > 0 && p->cost[i][col_mate[i]] >= cost_threshold)
417                         assignment[i] = -1; /* remove matching having cost > threshold */
418                 else
419                         assignment[i] = col_mate[i];
420         }
421
422         /* In case of normal matching: remove impossible ones */
423         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
424                 for (i = 0; i < m; ++i) {
425                         if (bitset_is_set(p->missing_left, i) || bitset_is_set(p->missing_right, col_mate[i]))
426                                 assignment[i] = -1;
427                 }
428         }
429
430         for (k = 0; k < m; ++k) {
431                 for (l = 0; l < n; ++l) {
432                         p->cost[k][l] = p->cost[k][l] - row_dec[k] + col_inc[l];
433                 }
434         }
435
436         for (i = 0; i < m; ++i)
437                 cost += row_dec[i];
438
439         for (i = 0; i < n; ++i)
440                 cost -= col_inc[i];
441
442         DBG((p->dbg, LEVEL_1, "Cost is %d\n", cost));
443
444         xfree(slack);
445         xfree(col_inc);
446         xfree(parent_row);
447         xfree(row_mate);
448         xfree(slack_row);
449         xfree(row_dec);
450         xfree(unchosen_row);
451         xfree(col_mate);
452
453         *final_cost = cost;
454
455         return 0;
456 }