simplify hungarian interface
[libfirm] / ir / adt / hungarian.c
1 /********************************************************************
2  ********************************************************************
3  **
4  ** libhungarian by Cyrill Stachniss, 2004
5  **
6  ** Added and adapted to libFirm by Christian Wuerdig, 2006
7  **
8  ** Solving the Minimum Assignment Problem using the
9  ** Hungarian Method.
10  **
11  ** ** This file may be freely copied and distributed! **
12  **
 ** Parts of the code used were originally provided by the
 ** "Stanford GraphBase", but I made changes to this code.
 ** As asked by the copyright notice of the "Stanford GraphBase",
 ** I hereby proclaim that this file is *NOT* part of the
 ** "Stanford GraphBase" distribution!
18  **
19  ** This file is distributed in the hope that it will be useful,
20  ** but WITHOUT ANY WARRANTY; without even the implied
21  ** warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
22  ** PURPOSE.
23  **
24  ********************************************************************
25  ********************************************************************/
26
27 /**
28  * @file
29  * @brief   Solving the Minimum Assignment Problem using the Hungarian Method.
30  * @version $Id$
31  */
32 #include "config.h"
33
34 #include <stdio.h>
35 #include <stdlib.h>
36 #include <assert.h>
37
38 #include "irtools.h"
39 #include "xmalloc.h"
40 #include "debug.h"
41 #include "obst.h"
42 #include "bitset.h"
43
44 #include "hungarian.h"
45
46 #define INF (0x7FFFFFFF)
47
/**
 * State of one assignment problem.
 *
 * hungarian_new() always makes the cost matrix square by setting both
 * dimensions to the larger of the requested row and column counts.
 */
struct _hungarian_problem_t {
        int      num_rows;          /**< number of rows */
        int      num_cols;          /**< number of columns */
        int      **cost;            /**< the cost matrix, indexed as cost[row][col] */
        int      max_cost;          /**< the largest cost passed to hungarian_add() so far */
        int      match_type;        /**< matching type as passed to hungarian_new() */
        bitset_t *missing_left;     /**< left side nodes having no edge to the right side (only allocated for HUNGARIAN_MATCH_NORMAL) */
        bitset_t *missing_right;    /**< right side nodes having no edge to the left side (only allocated for HUNGARIAN_MATCH_NORMAL) */
        struct obstack obst;        /**< obstack holding the cost matrix and the two bitsets */
        DEBUG_ONLY(firm_dbg_module_t *dbg); /**< debug module, registered in hungarian_new() */
};
59
/**
 * Allocate @p sz bytes on the obstack @p obst and return them zero-filled.
 */
static inline void *get_init_mem(struct obstack *obst, size_t sz) {
        /* memset() returns its destination, i.e. the freshly allocated block */
        return memset(obstack_alloc(obst, sz), 0, sz);
}
65
/**
 * Write the matrix @p C (@p rows x @p cols) to the stream @p f.
 *
 * The output is an empty line, then one line " [...]" per row with every
 * entry printed using the minimum field width @p width, then another
 * empty line.
 */
static void hungarian_dump_f(FILE *f, int **C, int rows, int cols, int width) {
        int r;

        fputc('\n', f);

        for (r = 0; r < rows; ++r) {
                int c;

                fputs(" [", f);
                for (c = 0; c < cols; ++c)
                        fprintf(f, "%*d", width, C[r][c]);
                fputs("]\n", f);
        }

        fputc('\n', f);
}
79
80 void hungarian_print_costmatrix(hungarian_problem_t *p, int width) {
81         hungarian_dump_f(stderr, p->cost, p->num_rows, p->num_cols, width);
82 }
83
84 /**
85  * Create the object and allocate memory for the data structures.
86  */
87 hungarian_problem_t *hungarian_new(int rows, int cols, int match_type) {
88         int i;
89         hungarian_problem_t *p = XMALLOCZ(hungarian_problem_t);
90
91         FIRM_DBG_REGISTER(p->dbg, "firm.hungarian");
92
93         /*
94                 Is the number of cols  not equal to number of rows ?
95                 If yes, expand with 0 - cols / 0 - cols
96         */
97         rows = MAX(cols, rows);
98         cols = rows;
99
100         obstack_init(&p->obst);
101
102         p->num_rows   = rows;
103         p->num_cols   = cols;
104         p->match_type = match_type;
105
106         /*
107                 In case of normal matching, we have to keep
108                 track of nodes without edges to kill them in
109                 the assignment later.
110         */
111         if (match_type == HUNGARIAN_MATCH_NORMAL) {
112                 p->missing_left  = bitset_obstack_alloc(&p->obst, rows);
113                 p->missing_right = bitset_obstack_alloc(&p->obst, cols);
114                 bitset_set_all(p->missing_left);
115                 bitset_set_all(p->missing_right);
116         }
117
118         /* allocate space for cost matrix */
119         p->cost = (int **)get_init_mem(&p->obst, rows * sizeof(p->cost[0]));
120         for (i = 0; i < p->num_rows; i++)
121                 p->cost[i] = (int *)get_init_mem(&p->obst, cols * sizeof(p->cost[0][0]));
122
123         return p;
124 }
125
126 /**
127  * Prepare the cost matrix.
128  */
129 void hungarian_prepare_cost_matrix(hungarian_problem_t *p, int mode) {
130         int i, j;
131
132         if (mode == HUNGARIAN_MODE_MAXIMIZE_UTIL) {
133                 for (i = 0; i < p->num_rows; i++) {
134                         for (j = 0; j < p->num_cols; j++) {
135                                 p->cost[i][j] = p->max_cost - p->cost[i][j];
136                         }
137                 }
138         }
139         else if (mode == HUNGARIAN_MODE_MINIMIZE_COST) {
140                 /* nothing to do */
141         }
142         else
143                 fprintf(stderr, "Unknown mode. Mode was set to HUNGARIAN_MODE_MINIMIZE_COST.\n");
144 }
145
146 /**
147  * Set cost[left][right] to cost.
148  */
149 void hungarian_add(hungarian_problem_t *p, int left, int right, int cost) {
150         assert(p->num_rows > left  && "Invalid row selected.");
151         assert(p->num_cols > right && "Invalid column selected.");
152         assert(cost >= 0);
153
154         p->cost[left][right] = cost;
155         p->max_cost          = MAX(p->max_cost, cost);
156
157         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
158                 bitset_clear(p->missing_left, left);
159                 bitset_clear(p->missing_right, right);
160         }
161 }
162
163 /**
164  * Set cost[left][right] to 0.
165  */
166 void hungarian_remv(hungarian_problem_t *p, int left, int right) {
167         assert(p->num_rows > left  && "Invalid row selected.");
168         assert(p->num_cols > right && "Invalid column selected.");
169
170         p->cost[left][right] = 0;
171
172         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
173                 bitset_set(p->missing_left, left);
174                 bitset_set(p->missing_right, right);
175         }
176 }
177
178 /**
179  * Frees all allocated memory.
180  */
181 void hungarian_free(hungarian_problem_t* p) {
182         obstack_free(&p->obst, NULL);
183         xfree(p);
184 }
185
186 /**
187  * Do the assignment.
188  */
189 int hungarian_solve(hungarian_problem_t* p, int *assignment, int *final_cost, int cost_threshold) {
190         int i, j, m, n, k, l, s, t, q, unmatched, cost;
191         int *col_mate;
192         int *row_mate;
193         int *parent_row;
194         int *unchosen_row;
195         int *row_dec;
196         int *col_inc;
197         int *slack;
198         int *slack_row;
199
200         cost = 0;
201         m    = p->num_rows;
202         n    = p->num_cols;
203
204         col_mate     = XMALLOCNZ(int, p->num_rows);
205         unchosen_row = XMALLOCNZ(int, p->num_rows);
206         row_dec      = XMALLOCNZ(int, p->num_rows);
207         slack_row    = XMALLOCNZ(int, p->num_rows);
208
209         row_mate     = XMALLOCNZ(int, p->num_cols);
210         parent_row   = XMALLOCNZ(int, p->num_cols);
211         col_inc      = XMALLOCNZ(int, p->num_cols);
212         slack        = XMALLOCNZ(int, p->num_cols);
213
214         memset(assignment, -1, m * sizeof(assignment[0]));
215
216         /* Begin subtract column minima in order to start with lots of zeros 12 */
217         DBG((p->dbg, LEVEL_1, "Using heuristic\n"));
218
219         for (l = 0; l < n; ++l) {
220                 s = p->cost[0][l];
221
222                 for (k = 1; k < m; ++k) {
223                         if (p->cost[k][l] < s)
224                                 s = p->cost[k][l];
225                 }
226
227                 cost += s;
228
229                 if (s != 0) {
230                         for (k = 0; k < m; ++k)
231                                 p->cost[k][l] -= s;
232                 }
233         }
234         /* End subtract column minima in order to start with lots of zeros 12 */
235
236         /* Begin initial state 16 */
237         t = 0;
238         for (l = 0; l < n; ++l) {
239                 row_mate[l]   = -1;
240                 parent_row[l] = -1;
241                 col_inc[l]    = 0;
242                 slack[l]      = INF;
243         }
244
245         for (k = 0; k < m; ++k) {
246                 s = p->cost[k][0];
247
248                 for (l = 1; l < n; ++l) {
249                         if (p->cost[k][l] < s)
250                                 s = p->cost[k][l];
251                 }
252
253                 row_dec[k] = s;
254
255                 for (l = 0; l < n; ++l) {
256                         if (s == p->cost[k][l] && row_mate[l] < 0) {
257                                 col_mate[k] = l;
258                                 row_mate[l] = k;
259                                 DBG((p->dbg, LEVEL_1, "matching col %d == row %d\n", l, k));
260                                 goto row_done;
261                         }
262                 }
263
264                 col_mate[k] = -1;
265                 DBG((p->dbg, LEVEL_1, "node %d: unmatched row %d\n", t, k));
266                 unchosen_row[t++] = k;
267 row_done: ;
268         }
269         /* End initial state 16 */
270
271         /* Begin Hungarian algorithm 18 */
272         if (t == 0)
273                 goto done;
274
275         unmatched = t;
276         while (1) {
277                 DBG((p->dbg, LEVEL_1, "Matched %d rows.\n", m - t));
278                 q = 0;
279
280                 while (1) {
281                         while (q < t) {
282                                 /* Begin explore node q of the forest 19 */
283                                 k = unchosen_row[q];
284                                 s = row_dec[k];
285
286                                 for (l = 0; l < n; ++l) {
287                                         if (slack[l]) {
288                                                 int del = p->cost[k][l] - s + col_inc[l];
289
290                                                 if (del < slack[l]) {
291                                                         if (del == 0) {
292                                                                 if (row_mate[l] < 0)
293                                                                         goto breakthru;
294
295                                                                 slack[l]      = 0;
296                                                                 parent_row[l] = k;
297                                                                 DBG((p->dbg, LEVEL_1, "node %d: row %d == col %d -- row %d\n", t, row_mate[l], l, k));
298                                                                 unchosen_row[t++] = row_mate[l];
299                                                         }
300                                                         else {
301                                                                 slack[l]     = del;
302                                                                 slack_row[l] = k;
303                                                         }
304                                                 }
305                                         }
306                                 }
307                                 /* End explore node q of the forest 19 */
308                                 q++;
309                         }
310
311                         /* Begin introduce a new zero into the matrix 21 */
312                         s = INF;
313                         for (l = 0; l < n; ++l) {
314                                 if (slack[l] && slack[l] < s)
315                                         s = slack[l];
316                         }
317
318                         for (q = 0; q < t; ++q)
319                                 row_dec[unchosen_row[q]] += s;
320
321                         for (l = 0; l < n; ++l) {
322                                 if (slack[l]) {
323                                         slack[l] -= s;
324                                         if (slack[l] == 0) {
325                                                 /* Begin look at a new zero 22 */
326                                                 k = slack_row[l];
327                                                 DBG((p->dbg, LEVEL_1, "Decreasing uncovered elements by %d produces zero at [%d, %d]\n", s, k, l));
328                                                 if (row_mate[l] < 0) {
329                                                         for (j = l + 1; j < n; ++j) {
330                                                                 if (slack[j] == 0)
331                                                                         col_inc[j] += s;
332                                                         }
333                                                         goto breakthru;
334                                                 }
335                                                 else {
336                                                         parent_row[l] = k;
337                                                         DBG((p->dbg, LEVEL_1, "node %d: row %d == col %d -- row %d\n", t, row_mate[l], l, k));
338                                                         unchosen_row[t++] = row_mate[l];
339                                                 }
340                                                 /* End look at a new zero 22 */
341                                         }
342                                 }
343                                 else {
344                                         col_inc[l] += s;
345                                 }
346                         }
347                         /* End introduce a new zero into the matrix 21 */
348                 }
349 breakthru:
350                 /* Begin update the matching 20 */
351                 DBG((p->dbg, LEVEL_1, "Breakthrough at node %d of %d.\n", q, t));
352                 while (1) {
353                         j           = col_mate[k];
354                         col_mate[k] = l;
355                         row_mate[l] = k;
356
357                         DBG((p->dbg, LEVEL_1, "rematching col %d == row %d\n", l, k));
358                         if (j < 0)
359                                 break;
360
361                         k = parent_row[j];
362                         l = j;
363                 }
364                 /* End update the matching 20 */
365
366                 if (--unmatched == 0)
367                         goto done;
368
369                 /* Begin get ready for another stage 17 */
370                 t = 0;
371                 for (l = 0; l < n; ++l) {
372                         parent_row[l] = -1;
373                         slack[l]      = INF;
374                 }
375
376                 for (k = 0; k < m; ++k) {
377                         if (col_mate[k] < 0) {
378                                 DBG((p->dbg, LEVEL_1, "node %d: unmatched row %d\n", t, k));
379                                 unchosen_row[t++] = k;
380                         }
381                 }
382                 /* End get ready for another stage 17 */
383         }
384 done:
385
386         /* Begin double check the solution 23 */
387         for (k = 0; k < m; ++k) {
388                 for (l = 0; l < n; ++l) {
389                         if (p->cost[k][l] < row_dec[k] - col_inc[l])
390                                 return -1;
391                 }
392         }
393
394         for (k = 0; k < m; ++k) {
395                 l = col_mate[k];
396                 if (l < 0 || p->cost[k][l] != row_dec[k] - col_inc[l])
397                         return -2;
398         }
399
400         for (k = l = 0; l < n; ++l) {
401                 if (col_inc[l])
402                         k++;
403         }
404
405         if (k > m)
406                 return -3;
407         /* End double check the solution 23 */
408
409         /* End Hungarian algorithm 18 */
410
411         /* collect the assigned values */
412         for (i = 0; i < m; ++i) {
413                 if (cost_threshold > 0 && p->cost[i][col_mate[i]] >= cost_threshold)
414                         assignment[i] = -1; /* remove matching having cost > threshold */
415                 else
416                         assignment[i] = col_mate[i];
417         }
418
419         /* In case of normal matching: remove impossible ones */
420         if (p->match_type == HUNGARIAN_MATCH_NORMAL) {
421                 for (i = 0; i < m; ++i) {
422                         if (bitset_is_set(p->missing_left, i) || bitset_is_set(p->missing_right, col_mate[i]))
423                                 assignment[i] = -1;
424                 }
425         }
426
427         for (k = 0; k < m; ++k) {
428                 for (l = 0; l < n; ++l) {
429                         p->cost[k][l] = p->cost[k][l] - row_dec[k] + col_inc[l];
430                 }
431         }
432
433         for (i = 0; i < m; ++i)
434                 cost += row_dec[i];
435
436         for (i = 0; i < n; ++i)
437                 cost -= col_inc[i];
438
439         DBG((p->dbg, LEVEL_1, "Cost is %d\n", cost));
440
441         xfree(slack);
442         xfree(col_inc);
443         xfree(parent_row);
444         xfree(row_mate);
445         xfree(slack_row);
446         xfree(row_dec);
447         xfree(unchosen_row);
448         xfree(col_mate);
449
450         *final_cost = cost;
451
452         return 0;
453 }