Use streq() instead of strcmp() == 0.
[cparser] / wrappergen / write_jna.c
1 /*
2  * This file is part of cparser.
3  * Copyright (C) 2007-2009 Matthias Braun <matze@braunis.de>
4  *
5  * This program is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU General Public License
7  * as published by the Free Software Foundation; either version 2
8  * of the License, or (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
18  * 02111-1307, USA.
19  */
20 #include <config.h>
21
22 #include <errno.h>
23 #include <string.h>
24
25 #include "adt/strutil.h"
26 #include "write_jna.h"
27 #include "symbol_t.h"
28 #include "ast_t.h"
29 #include "type_t.h"
30 #include "entity_t.h"
31 #include "type.h"
32 #include "printer.h"
33 #include "adt/error.h"
34 #include "adt/xmalloc.h"
35 #include <libfirm/adt/pset_new.h>
36
typedef struct output_limit output_limit;

/**
 * Node of a singly linked list of input file names. When the list is
 * non-empty, write_jna_decls() only emits functions whose source file
 * name equals one of the stored names.
 */
struct output_limit {
	const char   *filename; /* compared with streq() against input names */
	output_limit *next;     /* following node, NULL at the end of the list */
};
41
42 static const scope_t *global_scope;
43 static FILE          *out;
44 static pset_new_t     avoid_symbols;
45 static output_limit  *output_limits;
46 static const char    *libname;
47
48 static bool is_system_header(const char *fname)
49 {
50         if (strstart(fname, "/usr/include"))
51                 return true;
52         if (fname == builtin_source_position.input_name)
53                 return true;
54         return false;
55 }
56
/**
 * Replaces identifiers that are reserved in Java by a variant with a
 * leading underscore, so that they can be used as Java parameter names.
 *
 * Only words that are legal identifiers in C need to be listed: C keywords
 * such as "int" or "while" can never reach this function as a name.
 *
 * @param name  identifier as spelled in the C source
 * @return a pointer to a static replacement string ("_" followed by the
 *         name) if @p name is a Java reserved word, otherwise @p name itself
 */
static const char *fix_builtin_names(const char *name)
{
	/* Each entry is the replacement; the reserved word itself starts at
	 * offset 1, which avoids spelling every word twice. */
	static const char *const replacements[] = {
		"_abstract",   "_assert",     "_boolean",      "_byte",
		"_catch",      "_class",      "_extends",      "_false",
		"_final",      "_finally",    "_implements",   "_import",
		"_instanceof", "_interface",  "_native",       "_new",
		"_null",       "_package",    "_private",      "_protected",
		"_public",     "_strictfp",   "_super",        "_synchronized",
		"_this",       "_throw",      "_throws",       "_transient",
		"_true",       "_try",
	};
	const size_t n_replacements
		= sizeof(replacements) / sizeof(replacements[0]);

	for (size_t i = 0; i < n_replacements; ++i) {
		if (streq(name, replacements[i] + 1))
			return replacements[i];
	}
	return name;
}
75
76 static const char *get_atomic_type_string(const atomic_type_kind_t type)
77 {
78         switch(type) {
79         case ATOMIC_TYPE_VOID:        return "void";
80         case ATOMIC_TYPE_CHAR:        return "byte";
81         case ATOMIC_TYPE_SCHAR:       return "byte";
82         case ATOMIC_TYPE_UCHAR:       return "byte";
83         case ATOMIC_TYPE_SHORT:       return "short";
84         case ATOMIC_TYPE_USHORT:      return "short";
85         case ATOMIC_TYPE_INT:         return "int";
86         case ATOMIC_TYPE_UINT:        return "int";
87         case ATOMIC_TYPE_LONG:        return "com.sun.jna.NativeLong";
88         case ATOMIC_TYPE_ULONG:       return "com.sun.jna.NativeLong";
89         case ATOMIC_TYPE_LONGLONG:    return "long";
90         case ATOMIC_TYPE_ULONGLONG:   return "long";
91         case ATOMIC_TYPE_FLOAT:       return "float";
92         case ATOMIC_TYPE_DOUBLE:      return "double";
93         case ATOMIC_TYPE_BOOL:        return "boolean";
94         default:                      panic("unsupported atomic type");
95         }
96 }
97
98 static void write_atomic_type(const atomic_type_t *type)
99 {
100         fputs(get_atomic_type_string(type->akind), out);
101 }
102
103 static void write_pointer_type(const pointer_type_t *type)
104 {
105         type_t *points_to = skip_typeref(type->points_to);
106         if (is_type_atomic(points_to, ATOMIC_TYPE_CHAR)) {
107                 fputs("String", out);
108                 return;
109         }
110         if (is_type_pointer(points_to)) {
111                 /* hack... */
112                 fputs("java.nio.Buffer", out);
113                 return;
114         }
115         fputs("Pointer", out);
116 }
117
118 static entity_t *find_typedef(const type_t *type)
119 {
120         /* first: search for a matching typedef in the global type... */
121         entity_t *entity = global_scope->entities;
122         for ( ; entity != NULL; entity = entity->base.next) {
123                 if (entity->kind != ENTITY_TYPEDEF)
124                         continue;
125                 if (entity->typedefe.type == type)
126                         break;
127         }
128
129         return entity;
130 }
131
132 static entity_t *find_enum_typedef(const enum_t *enume)
133 {
134         /* first: search for a matching typedef in the global type... */
135         entity_t *entity = global_scope->entities;
136         for ( ; entity != NULL; entity = entity->base.next) {
137                 if (entity->kind != ENTITY_TYPEDEF)
138                         continue;
139                 type_t *type = entity->typedefe.type;
140                 if (type->kind != TYPE_ENUM)
141                         continue;
142
143                 enum_t *e_entity = type->enumt.enume;
144                 if (e_entity == enume)
145                         break;
146         }
147
148         return entity;
149 }
150
151 static void write_compound_type(const compound_type_t *type)
152 {
153         entity_t *entity = find_typedef((const type_t*) type);
154         if(entity != NULL) {
155                 fputs(entity->base.symbol->string, out);
156                 return;
157         }
158
159         /* does the struct have a name? */
160         symbol_t *symbol = type->compound->base.symbol;
161         if(symbol != NULL) {
162                 /* TODO: make sure we create a struct for it... */
163                 fputs(symbol->string, out);
164                 return;
165         }
166         /* TODO: create a struct and use its name here... */
167         fputs("/* TODO anonymous struct */byte", out);
168 }
169
170 static void write_enum_name(const enum_type_t *type)
171 {
172         entity_t *entity = find_typedef((const type_t*) type);
173         if (entity != NULL) {
174                 fputs(entity->base.symbol->string, out);
175                 return;
176         }
177
178         /* does the enum have a name? */
179         symbol_t *symbol = type->enume->base.symbol;
180         if (symbol != NULL) {
181                 /* TODO: make sure we create an enum for it... */
182                 fputs(symbol->string, out);
183                 return;
184         }
185
186         /* now we have a problem as we don't know how we'll call the anonymous
187          * enum */
188         panic("can't reference entries from anonymous enums yet");
189 }
190
191 static void write_enum_type(const enum_type_t *type)
192 {
193         entity_t *entity = find_typedef((const type_t*) type);
194         if (entity != NULL) {
195                 fprintf(out, "/* %s */int", entity->base.symbol->string);
196                 return;
197         }
198
199         /* does the enum have a name? */
200         symbol_t *symbol = type->enume->base.symbol;
201         if (symbol != NULL) {
202                 /* TODO: make sure we create an enum for it... */
203                 fprintf(out, "/* %s */int", symbol->string);
204                 return;
205         }
206         fprintf(out, "/* anonymous enum */int");
207 }
208
209 static void write_type(type_t *type)
210 {
211         type = skip_typeref(type);
212         switch(type->kind) {
213         case TYPE_ATOMIC:
214                 write_atomic_type(&type->atomic);
215                 return;
216         case TYPE_POINTER:
217                 write_pointer_type(&type->pointer);
218                 return;
219         case TYPE_COMPOUND_UNION:
220         case TYPE_COMPOUND_STRUCT:
221                 write_compound_type(&type->compound);
222                 return;
223         case TYPE_ENUM:
224                 write_enum_type(&type->enumt);
225                 return;
226         case TYPE_ERROR:
227         case TYPE_TYPEOF:
228         case TYPE_TYPEDEF:
229                 panic("invalid type found");
230         case TYPE_ARRAY:
231         case TYPE_REFERENCE:
232         case TYPE_FUNCTION:
233         case TYPE_COMPLEX:
234         case TYPE_IMAGINARY:
235                 fprintf(out, "/* TODO type */Pointer");
236                 break;
237         }
238 }
239
#if 0
/* Disabled: writes one compound member as "\t<name> : <type>\n". */
static void write_compound_entry(const entity_t *entity)
{
	const char *member_name = entity->base.symbol->string;

	fprintf(out, "\t%s : ", member_name);
	write_type(entity->declaration.type);
	fputc('\n', out);
}

/* Disabled: writes a "struct <name>:" or "union <name>:" header followed
 * by all members of the compound and a blank line. */
static void write_compound(const symbol_t *symbol, const compound_type_t *type)
{
	const char *keyword
		= type->base.kind == TYPE_COMPOUND_STRUCT ? "struct" : "union";

	fprintf(out, "%s %s:\n", keyword, symbol->string);

	for (const entity_t *member = type->compound->members.entities;
	     member != NULL; member = member->base.next) {
		write_compound_entry(member);
	}

	fputc('\n', out);
}
#endif
262
263 static void write_expression(const expression_t *expression);
264
265 static void write_unary_expression(const unary_expression_t *expression)
266 {
267         switch(expression->base.kind) {
268         case EXPR_UNARY_NEGATE:
269                 fputc('-', out);
270                 break;
271         case EXPR_UNARY_NOT:
272                 fputc('!', out);
273                 break;
274         case EXPR_UNARY_CAST:
275                 write_expression(expression->value);
276                 return;
277         default:
278                 panic("unimeplemented unary expression found");
279         }
280         write_expression(expression->value);
281 }
282
283 static void write_binary_expression(const binary_expression_t *expression)
284 {
285         fputs("(", out);
286         write_expression(expression->left);
287         fputc(' ', out);
288         switch(expression->base.kind) {
289         case EXPR_BINARY_BITWISE_OR:  fputs("|", out); break;
290         case EXPR_BINARY_BITWISE_AND: fputs("&", out); break;
291         case EXPR_BINARY_BITWISE_XOR: fputs("^", out); break;
292         case EXPR_BINARY_SHIFTLEFT:   fputs("<<", out); break;
293         case EXPR_BINARY_SHIFTRIGHT:  fputs(">>", out); break;
294         case EXPR_BINARY_ADD:         fputs("+", out); break;
295         case EXPR_BINARY_SUB:         fputs("-", out); break;
296         case EXPR_BINARY_MUL:         fputs("*", out); break;
297         case EXPR_BINARY_DIV:         fputs("/", out); break;
298         default:
299                 panic("unimplemented binexpr");
300         }
301         fputc(' ', out);
302         write_expression(expression->right);
303         fputs(")", out);
304 }
305
306 static void write_expression(const expression_t *expression)
307 {
308         /* TODO */
309         switch(expression->kind) {
310         case EXPR_LITERAL_INTEGER:
311         case EXPR_LITERAL_INTEGER_OCTAL:
312                 fprintf(out, "%s", expression->literal.value.begin);
313                 break;
314         case EXPR_LITERAL_INTEGER_HEXADECIMAL:
315                 fprintf(out, "0x%s", expression->literal.value.begin);
316                 break;
317         case EXPR_REFERENCE_ENUM_VALUE: {
318                 /* UHOH... hacking */
319                 entity_t *entity = expression->reference.entity;
320                 write_enum_name(& entity->enum_value.enum_type->enumt);
321                 fprintf(out, ".%s.val", entity->base.symbol->string);
322                 break;
323         }
324         EXPR_UNARY_CASES
325                 write_unary_expression(&expression->unary);
326                 break;
327         EXPR_BINARY_CASES
328                 write_binary_expression(&expression->binary);
329                 break;
330         default:
331                 panic("not implemented expression");
332         }
333 }
334
335 static void write_enum(const symbol_t *symbol, const enum_t *entity)
336 {
337         char buf[128];
338         const char *name;
339
340         if (symbol == NULL) {
341                 static int lastenum = 0;
342                 snprintf(buf, sizeof(buf), "AnonEnum%d", lastenum++);
343                 name = buf;
344         } else {
345                 name = symbol->string;
346         }
347
348         fprintf(out, "\tpublic static enum %s {\n", name);
349
350         entity_t *entry = entity->base.next;
351         for ( ; entry != NULL && entry->kind == ENTITY_ENUM_VALUE;
352                         entry = entry->base.next) {
353                 fprintf(out, "\t\t%s", entry->base.symbol->string);
354                 fprintf(out, "(");
355                 if(entry->enum_value.value != NULL) {
356                         write_expression(entry->enum_value.value);
357                 }
358                 fprintf(out, ")");
359                 if (entry->base.next != NULL
360                                 && entry->base.next->kind == ENTITY_ENUM_VALUE) {
361                         fputs(",\n", out);
362                 } else {
363                         fputs(";\n", out);
364                 }
365         }
366         fprintf(out, "\t\tpublic final int val;\n");
367         fprintf(out, "\n");
368         fprintf(out, "\t\tprivate static class C {\n");
369         fprintf(out, "\t\t\tstatic int next_val;\n");
370         fprintf(out, "\t\t}\n");
371         fprintf(out, "\n");
372         fprintf(out, "\t\t%s(int val) {\n", name);
373         fprintf(out, "\t\t\tthis.val = val;\n");
374         fprintf(out, "\t\t\tC.next_val = val + 1;\n");
375         fprintf(out, "\t\t}\n");
376         fprintf(out, "\n");
377         fprintf(out, "\t\t%s() {\n", name);
378         fprintf(out, "\t\t\tthis.val = C.next_val++;\n");
379         fprintf(out, "\t\t}\n");
380         fprintf(out, "\n");
381         fprintf(out, "\t\tpublic static %s getEnum(int val) {\n", name);
382         fprintf(out, "\t\t\tfor (%s entry : values()) {\n", name);
383         fprintf(out, "\t\t\t\tif (val == entry.val)\n");
384         fprintf(out, "\t\t\t\t\treturn entry;\n");
385         fprintf(out, "\t\t\t}\n");
386         fprintf(out, "\t\t\treturn null;\n");
387         fprintf(out, "\t\t}\n");
388         fprintf(out, "\t}\n");
389         fprintf(out, "\n");
390 }
391
#if 0
/* Disabled: writes a variable as "var <name> : <type>\n". */
static void write_variable(const entity_t *entity)
{
	const char *variable_name = entity->base.symbol->string;

	fprintf(out, "var %s : ", variable_name);
	write_type(entity->declaration.type);
	fputc('\n', out);
}
#endif
400
401 static void write_function(const entity_t *entity)
402 {
403         if (entity->function.statement != NULL) {
404                 fprintf(stderr, "Warning: can't convert function bodies (at %s)\n",
405                         entity->base.symbol->string);
406                 return;
407         }
408
409
410         const function_type_t *function_type
411                 = (const function_type_t*) entity->declaration.type;
412
413         fputc('\n', out);
414         fprintf(out, "\tpublic static native ");
415         type_t *return_type = skip_typeref(function_type->return_type);
416         write_type(return_type);
417         fprintf(out, " %s(", entity->base.symbol->string);
418
419         entity_t *parameter = entity->function.parameters.entities;
420         int       first     = 1;
421         int       n         = 0;
422         for ( ; parameter != NULL; parameter = parameter->base.next) {
423                 assert(parameter->kind == ENTITY_PARAMETER);
424                 if(!first) {
425                         fprintf(out, ", ");
426                 } else {
427                         first = 0;
428                 }
429                 write_type(parameter->declaration.type);
430                 if(parameter->base.symbol != NULL) {
431                         fprintf(out, " %s", fix_builtin_names(parameter->base.symbol->string));
432                 } else {
433                         fprintf(out, " _%d", n++);
434                 }
435         }
436         if(function_type->variadic) {
437                 if(!first) {
438                         fprintf(out, ", ");
439                 } else {
440                         first = 0;
441                 }
442                 fputs("Object ... args", out);
443         }
444         fprintf(out, ");\n");
445 }
446
447 void jna_limit_output(const char *filename)
448 {
449         output_limit *limit = xmalloc(sizeof(limit[0]));
450         limit->filename = filename;
451
452         limit->next   = output_limits;
453         output_limits = limit;
454 }
455
456 void jna_set_libname(const char *new_libname)
457 {
458         libname = new_libname;
459 }
460
461 void write_jna_decls(FILE *output, const translation_unit_t *unit)
462 {
463         out          = output;
464         global_scope = &unit->scope;
465
466         pset_new_init(&avoid_symbols);
467
468         print_to_file(out);
469         fprintf(out, "/* WARNING: Automatically generated file */\n");
470         fputs("import com.sun.jna.Native;\n", out);
471         fputs("import com.sun.jna.Pointer;\n", out);
472         fputs("\n", out);
473
474         const char *register_libname = libname;
475         if (register_libname == NULL)
476                 register_libname = "library";
477
478         /* TODO: where to get the name from? */
479         fputs("public class binding {\n", out);
480         fputs("\tstatic {\n", out);
481         fprintf(out, "\t\tNative.register(\"%s\");\n", register_libname);
482         fputs("\t}\n", out);
483         fputs("\n", out);
484
485         /* read the avoid list */
486         FILE *avoid = fopen("avoid.config", "r");
487         if (avoid != NULL) {
488                 while (!feof(avoid)) {
489                         char buf[1024];
490                         char *res = fgets(buf, sizeof(buf), avoid);
491                         if (res == NULL)
492                                 break;
493                         if (buf[0] == 0)
494                                 continue;
495
496                         size_t len = strlen(buf);
497                         if (buf[len-1] == '\n')
498                                 buf[len-1] = 0;
499
500                         char *str = malloc(len+1);
501                         memcpy(str, buf, len+1);
502                         symbol_t *symbol = symbol_table_insert(str);
503                         pset_new_insert(&avoid_symbols, symbol);
504                 }
505                 fclose(avoid);
506         }
507
508         /* write structs,unions + enums */
509         entity_t *entity = unit->scope.entities;
510         for ( ; entity != NULL; entity = entity->base.next) {
511                 if (entity->kind == ENTITY_ENUM) {
512                         if (find_enum_typedef(&entity->enume) != NULL)
513                                 continue;
514                         write_enum(entity->base.symbol, &entity->enume);
515                 } else if (entity->kind == ENTITY_TYPEDEF) {
516                         type_t *type = entity->typedefe.type;
517                         if (type->kind == TYPE_ENUM) {
518                                 write_enum(entity->base.symbol, type->enumt.enume);
519                         }
520                 }
521
522 #if 0
523                 if(type->kind == TYPE_COMPOUND_STRUCT
524                                 || type->kind == TYPE_COMPOUND_UNION) {
525                         write_compound(entity->base.symbol, &type->compound);
526                 }
527 #endif
528         }
529
530         /* write functions */
531         entity = unit->scope.entities;
532         for ( ; entity != NULL; entity = entity->base.next) {
533                 if (entity->kind != ENTITY_FUNCTION)
534                         continue;
535                 const char *input_name = entity->base.source_position.input_name;
536                 if (is_system_header(input_name))
537                         continue;
538                 if (entity->function.elf_visibility != ELF_VISIBILITY_DEFAULT)
539                         continue;
540                 if (output_limits != NULL) {
541                         bool in_limits = false;
542                         for (output_limit *limit = output_limits; limit != NULL;
543                              limit = limit->next) {
544                             if (streq(limit->filename, input_name)) {
545                                         in_limits = true;
546                                         break;
547                                 }
548                         }
549                         if (!in_limits)
550                                 continue;
551                 }
552
553                 if (pset_new_contains(&avoid_symbols, entity->base.symbol))
554                         continue;
555                 write_function(entity);
556         }
557
558         fputs("}\n", out);
559
560         pset_new_destroy(&avoid_symbols);
561 }