cleanup: Add and use macro MAX().
[cparser] / wrappergen / write_jna.c
index b4ea642..aa9ed36 100644 (file)
@@ -1,27 +1,12 @@
 /*
  * This file is part of cparser.
- * Copyright (C) 2007-2009 Matthias Braun <matze@braunis.de>
- *
- * This program is free software; you can redistribute it and/or
- * modify it under the terms of the GNU General Public License
- * as published by the Free Software Foundation; either version 2
- * of the License, or (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program; if not, write to the Free Software
- * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
- * 02111-1307, USA.
+ * Copyright (C) 2012 Matthias Braun <matze@braunis.de>
  */
 #include <config.h>
 
-#include <errno.h>
 #include <string.h>
 
+#include "adt/strutil.h"
 #include "write_jna.h"
 #include "symbol_t.h"
 #include "ast_t.h"
 #include "type.h"
 #include "printer.h"
 #include "adt/error.h"
-#include <libfirm/adt/pset_new.h>
+#include "adt/xmalloc.h"
+#include "adt/pset_new.h"
+#include "separator_t.h"
+#include "symbol_table.h"
+
+typedef struct output_limit {
+       const char          *filename;
+       struct output_limit *next;
+} output_limit;
 
 static const scope_t *global_scope;
 static FILE          *out;
 static pset_new_t     avoid_symbols;
-
-static void write_type(type_t *type);
-
-static bool is_system_header(const char *fname)
-{
-       return strncmp(fname, "/usr/include", 12) == 0;
-}
+static output_limit  *output_limits;
+static const char    *libname;
 
 static const char *fix_builtin_names(const char *name)
 {
-       if (strcmp(name, "class") == 0) {
-               return "_class";
-       } else if(strcmp(name, "this") == 0) {
-               return "_this";
-       } else if(strcmp(name, "public") == 0) {
-               return "_public";
-       } else if(strcmp(name, "protected") == 0) {
-               return "_protected";
-       } else if(strcmp(name, "private") == 0) {
-               return "_private";
-       } else if(strcmp(name, "final") == 0) {
-               return "_final";
-       }
+#define FIX(x) if (streq(name, x)) return "_" x
+       FIX("class");
+       FIX("final");
+       FIX("private");
+       FIX("protected");
+       FIX("public");
+       FIX("this");
        /* TODO put all reserved names here */
+#undef FIX
        return name;
 }
 
@@ -69,7 +52,7 @@ static const char *get_atomic_type_string(const atomic_type_kind_t type)
        case ATOMIC_TYPE_CHAR:        return "byte";
        case ATOMIC_TYPE_SCHAR:       return "byte";
        case ATOMIC_TYPE_UCHAR:       return "byte";
-       case ATOMIC_TYPE_SHORT:       return "short";
+       case ATOMIC_TYPE_SHORT:       return "short";
        case ATOMIC_TYPE_USHORT:      return "short";
        case ATOMIC_TYPE_INT:         return "int";
        case ATOMIC_TYPE_UINT:        return "int";
@@ -79,7 +62,6 @@ static const char *get_atomic_type_string(const atomic_type_kind_t type)
        case ATOMIC_TYPE_ULONGLONG:   return "long";
        case ATOMIC_TYPE_FLOAT:       return "float";
        case ATOMIC_TYPE_DOUBLE:      return "double";
-       case ATOMIC_TYPE_LONG_DOUBLE: return "double";
        case ATOMIC_TYPE_BOOL:        return "boolean";
        default:                      panic("unsupported atomic type");
        }
@@ -99,7 +81,7 @@ static void write_pointer_type(const pointer_type_t *type)
        }
        if (is_type_pointer(points_to)) {
                /* hack... */
-               fputs("Pointer[]", out);
+               fputs("java.nio.Buffer", out);
                return;
        }
        fputs("Pointer", out);
@@ -213,16 +195,11 @@ static void write_type(type_t *type)
        case TYPE_ENUM:
                write_enum_type(&type->enumt);
                return;
-       case TYPE_BUILTIN:
-               write_type(type->builtin.real_type);
-               return;
        case TYPE_ERROR:
-       case TYPE_INVALID:
        case TYPE_TYPEOF:
        case TYPE_TYPEDEF:
-               panic("invalid type found");
+               panic("invalid type");
        case TYPE_ARRAY:
-       case TYPE_BITFIELD:
        case TYPE_REFERENCE:
        case TYPE_FUNCTION:
        case TYPE_COMPLEX:
@@ -266,11 +243,11 @@ static void write_unary_expression(const unary_expression_t *expression)
        case EXPR_UNARY_NOT:
                fputc('!', out);
                break;
-       case EXPR_UNARY_CAST_IMPLICIT:
+       case EXPR_UNARY_CAST:
                write_expression(expression->value);
                return;
        default:
-               panic("unimeplemented unary expression found");
+               panic("unimplemented unary expression");
        }
        write_expression(expression->value);
 }
@@ -279,6 +256,7 @@ static void write_binary_expression(const binary_expression_t *expression)
 {
        fputs("(", out);
        write_expression(expression->left);
+       fputc(' ', out);
        switch(expression->base.kind) {
        case EXPR_BINARY_BITWISE_OR:  fputs("|", out); break;
        case EXPR_BINARY_BITWISE_AND: fputs("&", out); break;
@@ -292,32 +270,37 @@ static void write_binary_expression(const binary_expression_t *expression)
        default:
                panic("unimplemented binexpr");
        }
+       fputc(' ', out);
        write_expression(expression->right);
        fputs(")", out);
 }
 
+static void write_integer(const literal_expression_t *literal)
+{
+       for (const char *c = literal->value.begin; c != literal->suffix; ++c) {
+               fputc(*c, out);
+       }
+}
+
 static void write_expression(const expression_t *expression)
 {
        /* TODO */
        switch(expression->kind) {
        case EXPR_LITERAL_INTEGER:
-       case EXPR_LITERAL_INTEGER_OCTAL:
-               fprintf(out, "%s", expression->literal.value.begin);
-               break;
-       case EXPR_LITERAL_INTEGER_HEXADECIMAL:
-               fprintf(out, "0x%s", expression->literal.value.begin);
+               write_integer(&expression->literal);
                break;
-       case EXPR_REFERENCE_ENUM_VALUE: {
+
+       case EXPR_ENUM_CONSTANT: {
                /* UHOH... hacking */
                entity_t *entity = expression->reference.entity;
                write_enum_name(& entity->enum_value.enum_type->enumt);
                fprintf(out, ".%s.val", entity->base.symbol->string);
                break;
        }
-       EXPR_UNARY_CASES
+       case EXPR_UNARY_CASES:
                write_unary_expression(&expression->unary);
                break;
-       EXPR_BINARY_CASES
+       case EXPR_BINARY_CASES:
                write_binary_expression(&expression->binary);
                break;
        default:
@@ -357,23 +340,29 @@ static void write_enum(const symbol_t *symbol, const enum_t *entity)
                }
        }
        fprintf(out, "\t\tpublic final int val;\n");
-       fprintf(out, "\t\tprivate static class C { static int next_val; }\n\n");
+       fprintf(out, "\n");
+       fprintf(out, "\t\tprivate static class C {\n");
+       fprintf(out, "\t\t\tstatic int next_val;\n");
+       fprintf(out, "\t\t}\n");
+       fprintf(out, "\n");
        fprintf(out, "\t\t%s(int val) {\n", name);
        fprintf(out, "\t\t\tthis.val = val;\n");
        fprintf(out, "\t\t\tC.next_val = val + 1;\n");
        fprintf(out, "\t\t}\n");
+       fprintf(out, "\n");
        fprintf(out, "\t\t%s() {\n", name);
        fprintf(out, "\t\t\tthis.val = C.next_val++;\n");
        fprintf(out, "\t\t}\n");
-       fprintf(out, "\t\t\n");
+       fprintf(out, "\n");
        fprintf(out, "\t\tpublic static %s getEnum(int val) {\n", name);
-       fprintf(out, "\t\t\tfor(%s entry : values()) {\n", name);
+       fprintf(out, "\t\t\tfor (%s entry : values()) {\n", name);
        fprintf(out, "\t\t\t\tif (val == entry.val)\n");
        fprintf(out, "\t\t\t\t\treturn entry;\n");
        fprintf(out, "\t\t\t}\n");
        fprintf(out, "\t\t\treturn null;\n");
        fprintf(out, "\t\t}\n");
        fprintf(out, "\t}\n");
+       fprintf(out, "\n");
 }
 
 #if 0
@@ -387,7 +376,7 @@ static void write_variable(const entity_t *entity)
 
 static void write_function(const entity_t *entity)
 {
-       if (entity->function.statement != NULL) {
+       if (entity->function.body != NULL) {
                fprintf(stderr, "Warning: can't convert function bodies (at %s)\n",
                        entity->base.symbol->string);
                return;
@@ -397,21 +386,18 @@ static void write_function(const entity_t *entity)
        const function_type_t *function_type
                = (const function_type_t*) entity->declaration.type;
 
-       fputc('\t', out);
+       fputc('\n', out);
+       fprintf(out, "\tpublic static native ");
        type_t *return_type = skip_typeref(function_type->return_type);
        write_type(return_type);
        fprintf(out, " %s(", entity->base.symbol->string);
 
-       entity_t *parameter = entity->function.parameters.entities;
-       int       first     = 1;
-       int       n         = 0;
-       for( ; parameter != NULL; parameter = parameter->base.next) {
+       entity_t   *parameter = entity->function.parameters.entities;
+       separator_t sep       = { "", ", " };
+       int         n         = 0;
+       for ( ; parameter != NULL; parameter = parameter->base.next) {
                assert(parameter->kind == ENTITY_PARAMETER);
-               if(!first) {
-                       fprintf(out, ", ");
-               } else {
-                       first = 0;
-               }
+               fputs(sep_next(&sep), out);
                write_type(parameter->declaration.type);
                if(parameter->base.symbol != NULL) {
                        fprintf(out, " %s", fix_builtin_names(parameter->base.symbol->string));
@@ -420,16 +406,25 @@ static void write_function(const entity_t *entity)
                }
        }
        if(function_type->variadic) {
-               if(!first) {
-                       fprintf(out, ", ");
-               } else {
-                       first = 0;
-               }
+               fputs(sep_next(&sep), out);
                fputs("Object ... args", out);
        }
        fprintf(out, ");\n");
 }
 
+void jna_limit_output(const char *filename)
+{
+       output_limit *limit = xmalloc(sizeof(limit[0]));
+       limit->filename = filename;
+
+       limit->next   = output_limits;
+       output_limits = limit;
+}
+
+void jna_set_libname(const char *new_libname)
+{
+       libname = new_libname;
+}
 
 void write_jna_decls(FILE *output, const translation_unit_t *unit)
 {
@@ -440,17 +435,25 @@ void write_jna_decls(FILE *output, const translation_unit_t *unit)
 
        print_to_file(out);
        fprintf(out, "/* WARNING: Automatically generated file */\n");
-       fputs("import com.sun.jna.Library;\n", out);
+       fputs("import com.sun.jna.Native;\n", out);
        fputs("import com.sun.jna.Pointer;\n", out);
-       fputs("\n\n", out);
+       fputs("\n", out);
+
+       const char *register_libname = libname;
+       if (register_libname == NULL)
+               register_libname = "library";
 
        /* TODO: where to get the name from? */
-       fputs("public interface binding extends Library {\n", out);
+       fputs("public class binding {\n", out);
+       fputs("\tstatic {\n", out);
+       fprintf(out, "\t\tNative.register(\"%s\");\n", register_libname);
+       fputs("\t}\n", out);
+       fputs("\n", out);
 
        /* read the avoid list */
        FILE *avoid = fopen("avoid.config", "r");
        if (avoid != NULL) {
-               while (!feof(avoid)) {
+               for (;;) {
                        char buf[1024];
                        char *res = fgets(buf, sizeof(buf), avoid);
                        if (res == NULL)
@@ -472,7 +475,7 @@ void write_jna_decls(FILE *output, const translation_unit_t *unit)
 
        /* write structs,unions + enums */
        entity_t *entity = unit->scope.entities;
-       for( ; entity != NULL; entity = entity->base.next) {
+       for ( ; entity != NULL; entity = entity->base.next) {
                if (entity->kind == ENTITY_ENUM) {
                        if (find_enum_typedef(&entity->enume) != NULL)
                                continue;
@@ -485,8 +488,7 @@ void write_jna_decls(FILE *output, const translation_unit_t *unit)
                }
 
 #if 0
-               if(type->kind == TYPE_COMPOUND_STRUCT
-                               || type->kind == TYPE_COMPOUND_UNION) {
+               if (is_type_compound(type)) {
                        write_compound(entity->base.symbol, &type->compound);
                }
 #endif
@@ -494,11 +496,26 @@ void write_jna_decls(FILE *output, const translation_unit_t *unit)
 
        /* write functions */
        entity = unit->scope.entities;
-       for( ; entity != NULL; entity = entity->base.next) {
+       for ( ; entity != NULL; entity = entity->base.next) {
                if (entity->kind != ENTITY_FUNCTION)
                        continue;
-               if (is_system_header(entity->base.source_position.input_name))
+               if (entity->base.pos.is_system_header)
                        continue;
+               if (entity->function.elf_visibility != ELF_VISIBILITY_DEFAULT)
+                       continue;
+               if (output_limits != NULL) {
+                       bool              in_limits  = false;
+                       char const *const input_name = entity->base.pos.input_name;
+                       for (output_limit *limit = output_limits; limit != NULL;
+                            limit = limit->next) {
+                           if (streq(limit->filename, input_name)) {
+                                       in_limits = true;
+                                       break;
+                               }
+                       }
+                       if (!in_limits)
+                               continue;
+               }
 
                if (pset_new_contains(&avoid_symbols, entity->base.symbol))
                        continue;