Convert statements and expressions to the new union style (not all casts have been removed yet).
[cparser] / ast.c
1 #include <config.h>
2
3 #include "ast_t.h"
4 #include "type_t.h"
5
6 #include <assert.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <ctype.h>
10
11 #include "adt/error.h"
12
13 struct obstack ast_obstack;
14
15 static FILE *out;
16 static int   indent;
17
18 static void print_statement(const statement_t *statement);
19
20 void change_indent(int delta)
21 {
22         indent += delta;
23         assert(indent >= 0);
24 }
25
26 void print_indent(void)
27 {
28         for(int i = 0; i < indent; ++i)
29                 fprintf(out, "\t");
30 }
31
32 static void print_const(const const_expression_t *cnst)
33 {
34         if(cnst->expression.datatype == NULL)
35                 return;
36
37         if(is_type_integer(cnst->expression.datatype)) {
38                 fprintf(out, "%lld", cnst->v.int_value);
39         } else if(is_type_floating(cnst->expression.datatype)) {
40                 fprintf(out, "%Lf", cnst->v.float_value);
41         }
42 }
43
44 static void print_string_literal(
45                 const string_literal_expression_t *string_literal)
46 {
47         fputc('"', out);
48         for(const char *c = string_literal->value; *c != '\0'; ++c) {
49                 switch(*c) {
50                 case '\"':  fputs("\\\"", out); break;
51                 case '\\':  fputs("\\\\", out); break;
52                 case '\a':  fputs("\\a", out); break;
53                 case '\b':  fputs("\\b", out); break;
54                 case '\f':  fputs("\\f", out); break;
55                 case '\n':  fputs("\\n", out); break;
56                 case '\r':  fputs("\\r", out); break;
57                 case '\t':  fputs("\\t", out); break;
58                 case '\v':  fputs("\\v", out); break;
59                 case '\?':  fputs("\\?", out); break;
60                 default:
61                         if(!isprint(*c)) {
62                                 fprintf(out, "\\x%x", *c);
63                                 break;
64                         }
65                         fputc(*c, out);
66                         break;
67                 }
68         }
69         fputc('"', out);
70 }
71
72 static void print_call_expression(const call_expression_t *call)
73 {
74         print_expression(call->function);
75         fprintf(out, "(");
76         call_argument_t *argument = call->arguments;
77         int              first    = 1;
78         while(argument != NULL) {
79                 if(!first) {
80                         fprintf(out, ", ");
81                 } else {
82                         first = 0;
83                 }
84                 print_expression(argument->expression);
85
86                 argument = argument->next;
87         }
88         fprintf(out, ")");
89 }
90
91 static void print_binary_expression(const binary_expression_t *binexpr)
92 {
93         fprintf(out, "(");
94         print_expression(binexpr->left);
95         fprintf(out, " ");
96         switch(binexpr->type) {
97         case BINEXPR_INVALID:            fputs("INVOP", out); break;
98         case BINEXPR_COMMA:              fputs(",", out);     break;
99         case BINEXPR_ASSIGN:             fputs("=", out);     break;
100         case BINEXPR_ADD:                fputs("+", out);     break;
101         case BINEXPR_SUB:                fputs("-", out);     break;
102         case BINEXPR_MUL:                fputs("*", out);     break;
103         case BINEXPR_MOD:                fputs("%", out);     break;
104         case BINEXPR_DIV:                fputs("/", out);     break;
105         case BINEXPR_BITWISE_OR:         fputs("|", out);     break;
106         case BINEXPR_BITWISE_AND:        fputs("&", out);     break;
107         case BINEXPR_BITWISE_XOR:        fputs("^", out);     break;
108         case BINEXPR_LOGICAL_OR:         fputs("||", out);    break;
109         case BINEXPR_LOGICAL_AND:        fputs("&&", out);    break;
110         case BINEXPR_NOTEQUAL:           fputs("!=", out);    break;
111         case BINEXPR_EQUAL:              fputs("==", out);    break;
112         case BINEXPR_LESS:               fputs("<", out);     break;
113         case BINEXPR_LESSEQUAL:          fputs("<=", out);    break;
114         case BINEXPR_GREATER:            fputs(">", out);     break;
115         case BINEXPR_GREATEREQUAL:       fputs(">=", out);    break;
116         case BINEXPR_SHIFTLEFT:          fputs("<<", out);    break;
117         case BINEXPR_SHIFTRIGHT:         fputs(">>", out);    break;
118
119         case BINEXPR_ADD_ASSIGN:         fputs("+=", out);    break;
120         case BINEXPR_SUB_ASSIGN:         fputs("-=", out);    break;
121         case BINEXPR_MUL_ASSIGN:         fputs("*=", out);    break;
122         case BINEXPR_MOD_ASSIGN:         fputs("%=", out);    break;
123         case BINEXPR_DIV_ASSIGN:         fputs("/=", out);    break;
124         case BINEXPR_BITWISE_OR_ASSIGN:  fputs("|=", out);    break;
125         case BINEXPR_BITWISE_AND_ASSIGN: fputs("&=", out);    break;
126         case BINEXPR_BITWISE_XOR_ASSIGN: fputs("^=", out);    break;
127         case BINEXPR_SHIFTLEFT_ASSIGN:   fputs("<<=", out);   break;
128         case BINEXPR_SHIFTRIGHT_ASSIGN:  fputs(">>=", out);   break;
129         }
130         fprintf(out, " ");
131         print_expression(binexpr->right);
132         fprintf(out, ")");
133 }
134
135 static void print_unary_expression(const unary_expression_t *unexpr)
136 {
137         switch(unexpr->type) {
138         case UNEXPR_NEGATE:           fputs("-", out);  break;
139         case UNEXPR_PLUS:             fputs("+", out);  break;
140         case UNEXPR_NOT:              fputs("!", out);  break;
141         case UNEXPR_BITWISE_NEGATE:   fputs("~", out);  break;
142         case UNEXPR_PREFIX_INCREMENT: fputs("++", out); break;
143         case UNEXPR_PREFIX_DECREMENT: fputs("--", out); break;
144         case UNEXPR_DEREFERENCE:      fputs("*", out);  break;
145         case UNEXPR_TAKE_ADDRESS:     fputs("&", out);  break;
146
147         case UNEXPR_POSTFIX_INCREMENT:
148                 fputs("(", out);
149                 print_expression(unexpr->value);
150                 fputs(")", out);
151                 fputs("++", out);
152                 return;
153         case UNEXPR_POSTFIX_DECREMENT:
154                 fputs("(", out);
155                 print_expression(unexpr->value);
156                 fputs(")", out);
157                 fputs("--", out);
158                 return;
159         case UNEXPR_CAST:
160                 fputs("(", out);
161                 print_type(unexpr->expression.datatype);
162                 fputs(")", out);
163                 break;
164         case UNEXPR_INVALID:
165                 fprintf(out, "unop%d", (int) unexpr->type);
166                 break;
167         }
168         fputs("(", out);
169         print_expression(unexpr->value);
170         fputs(")", out);
171 }
172
173 static void print_reference_expression(const reference_expression_t *ref)
174 {
175         fprintf(out, "%s", ref->declaration->symbol->string);
176 }
177
178 static void print_array_expression(const array_access_expression_t *expression)
179 {
180         if(!expression->flipped) {
181                 fputs("(", out);
182                 print_expression(expression->array_ref);
183                 fputs(")[", out);
184                 print_expression(expression->index);
185                 fputs("]", out);
186         } else {
187                 fputs("(", out);
188                 print_expression(expression->index);
189                 fputs(")[", out);
190                 print_expression(expression->array_ref);
191                 fputs("]", out);
192         }
193 }
194
195 static void print_sizeof_expression(const sizeof_expression_t *expression)
196 {
197         fputs("sizeof", out);
198         if(expression->size_expression != NULL) {
199                 fputc('(', out);
200                 print_expression(expression->size_expression);
201                 fputc(')', out);
202         } else {
203                 fputc('(', out);
204                 print_type(expression->type);
205                 fputc(')', out);
206         }
207 }
208
209 static void print_builtin_symbol(const builtin_symbol_expression_t *expression)
210 {
211         fputs(expression->symbol->string, out);
212 }
213
214 static void print_conditional(const conditional_expression_t *expression)
215 {
216         fputs("(", out);
217         print_expression(expression->condition);
218         fputs(" ? ", out);
219         print_expression(expression->true_expression);
220         fputs(" : ", out);
221         print_expression(expression->false_expression);
222         fputs(")", out);
223 }
224
225 static void print_va_arg(const va_arg_expression_t *expression)
226 {
227         fputs("__builtin_va_arg(", out);
228         print_expression(expression->arg);
229         fputs(", ", out);
230         print_type(expression->expression.datatype);
231         fputs(")", out);
232 }
233
234 static void print_select(const select_expression_t *expression)
235 {
236         print_expression(expression->compound);
237         if(expression->compound->base.datatype == NULL ||
238                         expression->compound->base.datatype->type == TYPE_POINTER) {
239                 fputs("->", out);
240         } else {
241                 fputc('.', out);
242         }
243         fputs(expression->symbol->string, out);
244 }
245
246 static void print_classify_type_expression(
247         const classify_type_expression_t *const expr)
248 {
249         fputs("__builtin_classify_type(", out);
250         print_expression(expr->type_expression);
251         fputc(')', out);
252 }
253
254 void print_expression(const expression_t *expression)
255 {
256         switch(expression->type) {
257         case EXPR_UNKNOWN:
258         case EXPR_INVALID:
259                 fprintf(out, "*invalid expression*");
260                 break;
261         case EXPR_CONST:
262                 print_const(&expression->conste);
263                 break;
264         case EXPR_FUNCTION:
265         case EXPR_PRETTY_FUNCTION:
266         case EXPR_STRING_LITERAL:
267                 print_string_literal(&expression->string_literal);
268                 break;
269         case EXPR_CALL:
270                 print_call_expression((const call_expression_t*) expression);
271                 break;
272         case EXPR_BINARY:
273                 print_binary_expression((const binary_expression_t*) expression);
274                 break;
275         case EXPR_REFERENCE:
276                 print_reference_expression((const reference_expression_t*) expression);
277                 break;
278         case EXPR_ARRAY_ACCESS:
279                 print_array_expression((const array_access_expression_t*) expression);
280                 break;
281         case EXPR_UNARY:
282                 print_unary_expression((const unary_expression_t*) expression);
283                 break;
284         case EXPR_SIZEOF:
285                 print_sizeof_expression((const sizeof_expression_t*) expression);
286                 break;
287         case EXPR_BUILTIN_SYMBOL:
288                 print_builtin_symbol((const builtin_symbol_expression_t*) expression);
289                 break;
290         case EXPR_CONDITIONAL:
291                 print_conditional((const conditional_expression_t*) expression);
292                 break;
293         case EXPR_VA_ARG:
294                 print_va_arg((const va_arg_expression_t*) expression);
295                 break;
296         case EXPR_SELECT:
297                 print_select((const select_expression_t*) expression);
298                 break;
299         case EXPR_CLASSIFY_TYPE:
300                 print_classify_type_expression((const classify_type_expression_t*)expression);
301                 break;
302
303         case EXPR_OFFSETOF:
304         case EXPR_STATEMENT:
305                 /* TODO */
306                 fprintf(out, "some expression of type %d", (int) expression->type);
307                 break;
308         }
309 }
310
311 static void print_compound_statement(const compound_statement_t *block)
312 {
313         fputs("{\n", out);
314         indent++;
315
316         statement_t *statement = block->statements;
317         while(statement != NULL) {
318                 print_indent();
319                 print_statement(statement);
320
321                 statement = statement->base.next;
322         }
323         indent--;
324         print_indent();
325         fputs("}\n", out);
326 }
327
328 static void print_return_statement(const return_statement_t *statement)
329 {
330         fprintf(out, "return ");
331         if(statement->return_value != NULL)
332                 print_expression(statement->return_value);
333         fputs(";\n", out);
334 }
335
336 static void print_expression_statement(const expression_statement_t *statement)
337 {
338         print_expression(statement->expression);
339         fputs(";\n", out);
340 }
341
342 static void print_goto_statement(const goto_statement_t *statement)
343 {
344         fprintf(out, "goto ");
345         fputs(statement->label->symbol->string, out);
346         fprintf(stderr, "(%p)", (void*) statement->label);
347         fputs(";\n", out);
348 }
349
350 static void print_label_statement(const label_statement_t *statement)
351 {
352         fprintf(stderr, "(%p)", (void*) statement->label);
353         fprintf(out, "%s:\n", statement->label->symbol->string);
354         if(statement->label_statement != NULL) {
355                 print_statement(statement->label_statement);
356         }
357 }
358
359 static void print_if_statement(const if_statement_t *statement)
360 {
361         fputs("if(", out);
362         print_expression(statement->condition);
363         fputs(") ", out);
364         if(statement->true_statement != NULL) {
365                 print_statement(statement->true_statement);
366         }
367
368         if(statement->false_statement != NULL) {
369                 print_indent();
370                 fputs("else ", out);
371                 print_statement(statement->false_statement);
372         }
373 }
374
375 static void print_switch_statement(const switch_statement_t *statement)
376 {
377         fputs("switch(", out);
378         print_expression(statement->expression);
379         fputs(") ", out);
380         print_statement(statement->body);
381 }
382
383 static void print_case_label(const case_label_statement_t *statement)
384 {
385         if(statement->expression == NULL) {
386                 fputs("default:\n", out);
387         } else {
388                 fputs("case ", out);
389                 print_expression(statement->expression);
390                 fputs(":\n", out);
391         }
392         print_statement(statement->label_statement);
393 }
394
395 static void print_declaration_statement(
396                 const declaration_statement_t *statement)
397 {
398         int first = 1;
399         declaration_t *declaration = statement->declarations_begin;
400         for( ; declaration != statement->declarations_end->next;
401                declaration = declaration->next) {
402                 if(!first) {
403                         print_indent();
404                 } else {
405                         first = 0;
406                 }
407                 print_declaration(declaration);
408                 fputc('\n', out);
409         }
410 }
411
412 static void print_while_statement(const while_statement_t *statement)
413 {
414         fputs("while(", out);
415         print_expression(statement->condition);
416         fputs(") ", out);
417         print_statement(statement->body);
418 }
419
420 static void print_do_while_statement(const do_while_statement_t *statement)
421 {
422         fputs("do ", out);
423         print_statement(statement->body);
424         print_indent();
425         fputs("while(", out);
426         print_expression(statement->condition);
427         fputs(");\n", out);
428 }
429
430 static void print_for_statement(const for_statement_t *statement)
431 {
432         fputs("for(", out);
433         if(statement->context.declarations != NULL) {
434                 assert(statement->initialisation == NULL);
435                 print_declaration(statement->context.declarations);
436                 if(statement->context.declarations->next != NULL) {
437                         panic("multiple declarations in for statement not supported yet");
438                 }
439                 fputc(' ', out);
440         } else {
441                 if(statement->initialisation) {
442                         print_expression(statement->initialisation);
443                 }
444                 fputs("; ", out);
445         }
446         if(statement->condition != NULL) {
447                 print_expression(statement->condition);
448         }
449         fputs("; ", out);
450         if(statement->step != NULL) {
451                 print_expression(statement->step);
452         }
453         fputs(")", out);
454         print_statement(statement->body);
455 }
456
457 void print_statement(const statement_t *statement)
458 {
459         switch(statement->type) {
460         case STATEMENT_COMPOUND:
461                 print_compound_statement((const compound_statement_t*) statement);
462                 break;
463         case STATEMENT_RETURN:
464                 print_return_statement((const return_statement_t*) statement);
465                 break;
466         case STATEMENT_EXPRESSION:
467                 print_expression_statement((const expression_statement_t*) statement);
468                 break;
469         case STATEMENT_LABEL:
470                 print_label_statement((const label_statement_t*) statement);
471                 break;
472         case STATEMENT_GOTO:
473                 print_goto_statement((const goto_statement_t*) statement);
474                 break;
475         case STATEMENT_CONTINUE:
476                 fputs("continue;\n", out);
477                 break;
478         case STATEMENT_BREAK:
479                 fputs("break;\n", out);
480                 break;
481         case STATEMENT_IF:
482                 print_if_statement((const if_statement_t*) statement);
483                 break;
484         case STATEMENT_SWITCH:
485                 print_switch_statement((const switch_statement_t*) statement);
486                 break;
487         case STATEMENT_CASE_LABEL:
488                 print_case_label((const case_label_statement_t*) statement);
489                 break;
490         case STATEMENT_DECLARATION:
491                 print_declaration_statement((const declaration_statement_t*) statement);
492                 break;
493         case STATEMENT_WHILE:
494                 print_while_statement((const while_statement_t*) statement);
495                 break;
496         case STATEMENT_DO_WHILE:
497                 print_do_while_statement((const do_while_statement_t*) statement);
498                 break;
499         case STATEMENT_FOR:
500                 print_for_statement((const for_statement_t*) statement);
501                 break;
502         case STATEMENT_INVALID:
503                 fprintf(out, "*invalid statement*");
504                 break;
505         }
506 }
507
508 static void print_storage_class(storage_class_t storage_class)
509 {
510         switch(storage_class) {
511         case STORAGE_CLASS_ENUM_ENTRY:
512         case STORAGE_CLASS_NONE:
513                 break;
514         case STORAGE_CLASS_TYPEDEF:  fputs("typedef ", out); break;
515         case STORAGE_CLASS_EXTERN:   fputs("extern ", out); break;
516         case STORAGE_CLASS_STATIC:   fputs("static ", out); break;
517         case STORAGE_CLASS_AUTO:     fputs("auto ", out); break;
518         case STORAGE_CLASS_REGISTER: fputs("register ", out); break;
519         }
520 }
521
522 void print_initializer(const initializer_t *initializer)
523 {
524         if(initializer->type == INITIALIZER_VALUE) {
525                 const initializer_value_t *value = &initializer->value;
526                 print_expression(value->value);
527                 return;
528         }
529
530         assert(initializer->type == INITIALIZER_LIST);
531         fputs("{ ", out);
532         const initializer_list_t *list = &initializer->list;
533
534         for(size_t i = 0 ; i < list->len; ++i) {
535                 if(i > 0) {
536                         fputs(", ", out);
537                 }
538                 print_initializer(list->initializers[i]);
539         }
540         fputs("}", out);
541 }
542
543 static void print_normal_declaration(const declaration_t *declaration)
544 {
545         print_storage_class((storage_class_t)declaration->storage_class);
546         print_type_ext(declaration->type, declaration->symbol,
547                        &declaration->context);
548         if(declaration->is_inline) {
549                 fputs("inline ", out);
550         }
551
552         if(declaration->type->type == TYPE_FUNCTION) {
553                 if(declaration->init.statement != NULL) {
554                         fputs("\n", out);
555                         print_statement(declaration->init.statement);
556                         return;
557                 }
558         } else if(declaration->init.initializer != NULL) {
559                 fputs(" = ", out);
560                 print_initializer(declaration->init.initializer);
561         }
562         fputc(';', out);
563 }
564
565 void print_declaration(const declaration_t *declaration)
566 {
567         if(declaration->namespc != NAMESPACE_NORMAL &&
568                         declaration->symbol == NULL)
569                 return;
570
571         switch(declaration->namespc) {
572         case NAMESPACE_NORMAL:
573                 print_normal_declaration(declaration);
574                 break;
575         case NAMESPACE_STRUCT:
576                 fputs("struct ", out);
577                 fputs(declaration->symbol->string, out);
578                 fputc(' ', out);
579                 print_compound_definition(declaration);
580                 fputc(';', out);
581                 break;
582         case NAMESPACE_UNION:
583                 fputs("union ", out);
584                 fputs(declaration->symbol->string, out);
585                 fputc(' ', out);
586                 print_compound_definition(declaration);
587                 fputc(';', out);
588                 break;
589         case NAMESPACE_ENUM:
590                 fputs("enum ", out);
591                 fputs(declaration->symbol->string, out);
592                 fputc(' ', out);
593                 print_enum_definition(declaration);
594                 fputc(';', out);
595                 break;
596         }
597 }
598
599 void print_ast(const translation_unit_t *unit)
600 {
601         inc_type_visited();
602         set_print_compound_entries(true);
603
604         declaration_t *declaration = unit->context.declarations;
605         for( ; declaration != NULL; declaration = declaration->next) {
606                 if(declaration->storage_class == STORAGE_CLASS_ENUM_ENTRY)
607                         continue;
608                 if(declaration->namespc != NAMESPACE_NORMAL &&
609                                 declaration->symbol == NULL)
610                         continue;
611
612                 print_indent();
613                 print_declaration(declaration);
614                 fputc('\n', out);
615         }
616 }
617
618 void init_ast(void)
619 {
620         obstack_init(&ast_obstack);
621 }
622
623 void exit_ast(void)
624 {
625         obstack_free(&ast_obstack, NULL);
626 }
627
628 void ast_set_output(FILE *stream)
629 {
630         out = stream;
631         type_set_output(stream);
632 }
633
/*
 * Out-of-line definition of allocate_ast() that forwards to
 * _allocate_ast().  The parentheses around the name keep any function-like
 * macro called allocate_ast from expanding here (NOTE(review): presumably
 * defined in a header — confirm).
 */
void *(allocate_ast)(size_t size)
{
	void *const memory = _allocate_ast(size);
	return memory;
}