improve initializer handling
[cparser] / ast.c
1 #include <config.h>
2
3 #include "ast_t.h"
4 #include "type_t.h"
5
6 #include <assert.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <ctype.h>
10
11 #include "adt/error.h"
12
13 struct obstack ast_obstack;
14
15 static FILE *out;
16 static int   indent;
17
18 static void print_statement(const statement_t *statement);
19 static void print_declaration(const declaration_t *declaration);
20
21 static void print_indent(void)
22 {
23         for(int i = 0; i < indent; ++i)
24                 fprintf(out, "\t");
25 }
26
27 static void print_const(const const_t *cnst)
28 {
29         if(cnst->expression.datatype == NULL)
30                 return;
31
32         if(is_type_integer(cnst->expression.datatype)) {
33                 fprintf(out, "%d", cnst->v.int_value);
34         } else if(is_type_floating(cnst->expression.datatype)) {
35                 fprintf(out, "%Lf", cnst->v.float_value);
36         }
37 }
38
39 static void print_string_literal(const string_literal_t *string_literal)
40 {
41         fputc('"', out);
42         for(const char *c = string_literal->value; *c != '\0'; ++c) {
43                 switch(*c) {
44                 case '\"':  fputs("\\\"", out); break;
45                 case '\\':  fputs("\\\\", out); break;
46                 case '\a':  fputs("\\a", out); break;
47                 case '\b':  fputs("\\b", out); break;
48                 case '\f':  fputs("\\f", out); break;
49                 case '\n':  fputs("\\n", out); break;
50                 case '\r':  fputs("\\r", out); break;
51                 case '\t':  fputs("\\t", out); break;
52                 case '\v':  fputs("\\v", out); break;
53                 case '\?':  fputs("\\?", out); break;
54                 default:
55                         if(!isprint(*c)) {
56                                 fprintf(out, "\\x%x", *c);
57                                 break;
58                         }
59                         fputc(*c, out);
60                         break;
61                 }
62         }
63         fputc('"', out);
64 }
65
66 static void print_call_expression(const call_expression_t *call)
67 {
68         print_expression(call->method);
69         fprintf(out, "(");
70         call_argument_t *argument = call->arguments;
71         int              first    = 1;
72         while(argument != NULL) {
73                 if(!first) {
74                         fprintf(out, ", ");
75                 } else {
76                         first = 0;
77                 }
78                 print_expression(argument->expression);
79
80                 argument = argument->next;
81         }
82         fprintf(out, ")");
83 }
84
85 static void print_binary_expression(const binary_expression_t *binexpr)
86 {
87         fprintf(out, "(");
88         print_expression(binexpr->left);
89         fprintf(out, " ");
90         switch(binexpr->type) {
91         case BINEXPR_INVALID:            fputs("INVOP", out); break;
92         case BINEXPR_COMMA:              fputs(",", out);     break;
93         case BINEXPR_ASSIGN:             fputs("=", out);     break;
94         case BINEXPR_ADD:                fputs("+", out);     break;
95         case BINEXPR_SUB:                fputs("-", out);     break;
96         case BINEXPR_MUL:                fputs("*", out);     break;
97         case BINEXPR_MOD:                fputs("%", out);     break;
98         case BINEXPR_DIV:                fputs("/", out);     break;
99         case BINEXPR_BITWISE_OR:         fputs("|", out);     break;
100         case BINEXPR_BITWISE_AND:        fputs("&", out);     break;
101         case BINEXPR_BITWISE_XOR:        fputs("^", out);     break;
102         case BINEXPR_LOGICAL_OR:         fputs("||", out);    break;
103         case BINEXPR_LOGICAL_AND:        fputs("&&", out);    break;
104         case BINEXPR_NOTEQUAL:           fputs("!=", out);    break;
105         case BINEXPR_EQUAL:              fputs("==", out);    break;
106         case BINEXPR_LESS:               fputs("<", out);     break;
107         case BINEXPR_LESSEQUAL:          fputs("<=", out);    break;
108         case BINEXPR_GREATER:            fputs(">", out);     break;
109         case BINEXPR_GREATEREQUAL:       fputs(">=", out);    break;
110         case BINEXPR_SHIFTLEFT:          fputs("<<", out);    break;
111         case BINEXPR_SHIFTRIGHT:         fputs(">>", out);    break;
112
113         case BINEXPR_ADD_ASSIGN:         fputs("+=", out);    break;
114         case BINEXPR_SUB_ASSIGN:         fputs("-=", out);    break;
115         case BINEXPR_MUL_ASSIGN:         fputs("*=", out);    break;
116         case BINEXPR_MOD_ASSIGN:         fputs("%=", out);    break;
117         case BINEXPR_DIV_ASSIGN:         fputs("/=", out);    break;
118         case BINEXPR_BITWISE_OR_ASSIGN:  fputs("|=", out);    break;
119         case BINEXPR_BITWISE_AND_ASSIGN: fputs("&=", out);    break;
120         case BINEXPR_BITWISE_XOR_ASSIGN: fputs("^=", out);    break;
121         case BINEXPR_SHIFTLEFT_ASSIGN:   fputs("<<=", out);   break;
122         case BINEXPR_SHIFTRIGHT_ASSIGN:  fputs(">>=", out);   break;
123         }
124         fprintf(out, " ");
125         print_expression(binexpr->right);
126         fprintf(out, ")");
127 }
128
129 static void print_unary_expression(const unary_expression_t *unexpr)
130 {
131         switch(unexpr->type) {
132         case UNEXPR_NEGATE:           fputs("-", out);  break;
133         case UNEXPR_PLUS:             fputs("+", out);  break;
134         case UNEXPR_NOT:              fputs("!", out);  break;
135         case UNEXPR_BITWISE_NEGATE:   fputs("~", out);  break;
136         case UNEXPR_PREFIX_INCREMENT: fputs("++", out); break;
137         case UNEXPR_PREFIX_DECREMENT: fputs("--", out); break;
138         case UNEXPR_DEREFERENCE:      fputs("*", out);  break;
139         case UNEXPR_TAKE_ADDRESS:     fputs("&", out);  break;
140
141         case UNEXPR_POSTFIX_INCREMENT:
142                 fputs("(", out);
143                 print_expression(unexpr->value);
144                 fputs(")", out);
145                 fputs("++", out);
146                 return;
147         case UNEXPR_POSTFIX_DECREMENT:
148                 fputs("(", out);
149                 print_expression(unexpr->value);
150                 fputs(")", out);
151                 fputs("--", out);
152                 return;
153         case UNEXPR_CAST:
154                 fputs("(", out);
155                 print_type(unexpr->expression.datatype);
156                 fputs(")", out);
157                 break;
158         case UNEXPR_INVALID:
159                 fprintf(out, "unop%d", unexpr->type);
160                 break;
161         }
162         fputs("(", out);
163         print_expression(unexpr->value);
164         fputs(")", out);
165 }
166
167 static void print_reference_expression(const reference_expression_t *ref)
168 {
169         fprintf(out, "%s", ref->declaration->symbol->string);
170 }
171
172 static void print_array_expression(const array_access_expression_t *expression)
173 {
174         fputs("(", out);
175         print_expression(expression->array_ref);
176         fputs(")[", out);
177         print_expression(expression->index);
178         fputs("]", out);
179 }
180
181 static void print_sizeof_expression(const sizeof_expression_t *expression)
182 {
183         fputs("sizeof", out);
184         if(expression->size_expression != NULL) {
185                 fputc('(', out);
186                 print_expression(expression->size_expression);
187                 fputc(')', out);
188         } else {
189                 fputc('(', out);
190                 print_type(expression->type);
191                 fputc(')', out);
192         }
193 }
194
195 static void print_builtin_symbol(const builtin_symbol_expression_t *expression)
196 {
197         fputs(expression->symbol->string, out);
198 }
199
200 static void print_conditional(const conditional_expression_t *expression)
201 {
202         fputs("(", out);
203         print_expression(expression->condition);
204         fputs(" ? ", out);
205         print_expression(expression->true_expression);
206         fputs(" : ", out);
207         print_expression(expression->false_expression);
208         fputs(")", out);
209 }
210
211 static void print_va_arg(const va_arg_expression_t *expression)
212 {
213         fputs("__builtin_va_arg(", out);
214         print_expression(expression->arg);
215         fputs(", ", out);
216         print_type(expression->expression.datatype);
217         fputs(")", out);
218 }
219
220 void print_expression(const expression_t *expression)
221 {
222         switch(expression->type) {
223         case EXPR_INVALID:
224                 fprintf(out, "*invalid expression*");
225                 break;
226         case EXPR_CONST:
227                 print_const((const const_t*) expression);
228                 break;
229         case EXPR_FUNCTION:
230         case EXPR_PRETTY_FUNCTION:
231         case EXPR_STRING_LITERAL:
232                 print_string_literal((const string_literal_t*) expression);
233                 break;
234         case EXPR_CALL:
235                 print_call_expression((const call_expression_t*) expression);
236                 break;
237         case EXPR_BINARY:
238                 print_binary_expression((const binary_expression_t*) expression);
239                 break;
240         case EXPR_REFERENCE:
241                 print_reference_expression((const reference_expression_t*) expression);
242                 break;
243         case EXPR_ARRAY_ACCESS:
244                 print_array_expression((const array_access_expression_t*) expression);
245                 break;
246         case EXPR_UNARY:
247                 print_unary_expression((const unary_expression_t*) expression);
248                 break;
249         case EXPR_SIZEOF:
250                 print_sizeof_expression((const sizeof_expression_t*) expression);
251                 break;
252         case EXPR_BUILTIN_SYMBOL:
253                 print_builtin_symbol((const builtin_symbol_expression_t*) expression);
254                 break;
255         case EXPR_CONDITIONAL:
256                 print_conditional((const conditional_expression_t*) expression);
257                 break;
258         case EXPR_VA_ARG:
259                 print_va_arg((const va_arg_expression_t*) expression);
260                 break;
261
262         case EXPR_OFFSETOF:
263         case EXPR_STATEMENT:
264         case EXPR_SELECT:
265                 /* TODO */
266                 fprintf(out, "some expression of type %d", expression->type);
267                 break;
268         }
269 }
270
271 static void print_compound_statement(const compound_statement_t *block)
272 {
273         fputs("{\n", out);
274         indent++;
275
276         statement_t *statement = block->statements;
277         while(statement != NULL) {
278                 print_indent();
279                 print_statement(statement);
280
281                 statement = statement->next;
282         }
283         indent--;
284         print_indent();
285         fputs("}\n", out);
286 }
287
288 static void print_return_statement(const return_statement_t *statement)
289 {
290         fprintf(out, "return ");
291         if(statement->return_value != NULL)
292                 print_expression(statement->return_value);
293         fputs(";\n", out);
294 }
295
296 static void print_expression_statement(const expression_statement_t *statement)
297 {
298         print_expression(statement->expression);
299         fputs(";\n", out);
300 }
301
302 static void print_goto_statement(const goto_statement_t *statement)
303 {
304         fprintf(out, "goto ");
305         if(statement->label != NULL) {
306                 fprintf(out, "%s", statement->label->symbol->string);
307         } else {
308                 fprintf(out, "?%s", statement->label_symbol->string);
309         }
310         fputs(";\n", out);
311 }
312
313 static void print_label_statement(const label_statement_t *statement)
314 {
315         fprintf(out, "%s:\n", statement->symbol->string);
316 }
317
318 static void print_if_statement(const if_statement_t *statement)
319 {
320         fputs("if(", out);
321         print_expression(statement->condition);
322         fputs(") ", out);
323         if(statement->true_statement != NULL) {
324                 print_statement(statement->true_statement);
325         }
326
327         if(statement->false_statement != NULL) {
328                 print_indent();
329                 fputs("else ", out);
330                 print_statement(statement->false_statement);
331         }
332 }
333
334 static void print_switch_statement(const switch_statement_t *statement)
335 {
336         fputs("switch(", out);
337         print_expression(statement->expression);
338         fputs(") ", out);
339         print_statement(statement->body);
340 }
341
342 static void print_case_label(const case_label_statement_t *statement)
343 {
344         if(statement->expression == NULL) {
345                 fputs("default:\n", out);
346         } else {
347                 fputs("case ", out);
348                 print_expression(statement->expression);
349                 fputs(":\n", out);
350         }
351 }
352
353 static void print_declaration_statement(
354                 const declaration_statement_t *statement)
355 {
356         declaration_t *declaration = statement->declarations_begin;
357         for( ; declaration != statement->declarations_end->next;
358                declaration = declaration->next) {
359                 print_declaration(declaration);
360         }
361 }
362
363 static void print_while_statement(const while_statement_t *statement)
364 {
365         fputs("while(", out);
366         print_expression(statement->condition);
367         fputs(") ", out);
368         print_statement(statement->body);
369 }
370
371 static void print_do_while_statement(const do_while_statement_t *statement)
372 {
373         fputs("do ", out);
374         print_statement(statement->body);
375         print_indent();
376         fputs("while(", out);
377         print_expression(statement->condition);
378         fputs(");\n", out);
379 }
380
381 static void print_for_statement(const for_statement_t *statement)
382 {
383         fputs("for(", out);
384         if(statement->context.declarations != NULL) {
385                 assert(statement->initialisation == NULL);
386                 print_declaration(statement->context.declarations);
387                 if(statement->context.declarations->next != NULL) {
388                         panic("multiple declarations in for statement not supported yet");
389                 }
390         } else if(statement->initialisation) {
391                 print_expression(statement->initialisation);
392         }
393         fputs("; ", out);
394         if(statement->condition != NULL) {
395                 print_expression(statement->condition);
396         }
397         fputs("; ", out);
398         if(statement->step != NULL) {
399                 print_expression(statement->step);
400         }
401         fputs(")", out);
402         print_statement(statement->body);
403 }
404
405 void print_statement(const statement_t *statement)
406 {
407         switch(statement->type) {
408         case STATEMENT_COMPOUND:
409                 print_compound_statement((const compound_statement_t*) statement);
410                 break;
411         case STATEMENT_RETURN:
412                 print_return_statement((const return_statement_t*) statement);
413                 break;
414         case STATEMENT_EXPRESSION:
415                 print_expression_statement((const expression_statement_t*) statement);
416                 break;
417         case STATEMENT_LABEL:
418                 print_label_statement((const label_statement_t*) statement);
419                 break;
420         case STATEMENT_GOTO:
421                 print_goto_statement((const goto_statement_t*) statement);
422                 break;
423         case STATEMENT_CONTINUE:
424                 fputs("continue;\n", out);
425                 break;
426         case STATEMENT_BREAK:
427                 fputs("break;\n", out);
428                 break;
429         case STATEMENT_IF:
430                 print_if_statement((const if_statement_t*) statement);
431                 break;
432         case STATEMENT_SWITCH:
433                 print_switch_statement((const switch_statement_t*) statement);
434                 break;
435         case STATEMENT_CASE_LABEL:
436                 print_case_label((const case_label_statement_t*) statement);
437                 break;
438         case STATEMENT_DECLARATION:
439                 print_declaration_statement((const declaration_statement_t*) statement);
440                 break;
441         case STATEMENT_WHILE:
442                 print_while_statement((const while_statement_t*) statement);
443                 break;
444         case STATEMENT_DO_WHILE:
445                 print_do_while_statement((const do_while_statement_t*) statement);
446                 break;
447         case STATEMENT_FOR:
448                 print_for_statement((const for_statement_t*) statement);
449                 break;
450         case STATEMENT_INVALID:
451                 fprintf(out, "*invalid statement*");
452                 break;
453         }
454 }
455
456 static void print_storage_class(storage_class_t storage_class)
457 {
458         switch(storage_class) {
459         case STORAGE_CLASS_ENUM_ENTRY:
460         case STORAGE_CLASS_NONE:
461                 break;
462         case STORAGE_CLASS_TYPEDEF:  fputs("typedef ", out); break;
463         case STORAGE_CLASS_EXTERN:   fputs("extern ", out); break;
464         case STORAGE_CLASS_STATIC:   fputs("static ", out); break;
465         case STORAGE_CLASS_AUTO:     fputs("auto ", out); break;
466         case STORAGE_CLASS_REGISTER: fputs("register ", out); break;
467         }
468 }
469
470 void print_initializer(const initializer_t *initializer)
471 {
472         if(initializer->type == INITIALIZER_VALUE) {
473                 print_expression(initializer->v.value);
474                 return;
475         }
476
477         assert(initializer->type == INITIALIZER_LIST);
478         fputs("{ ", out);
479         initializer_t *iter = initializer->v.list;
480         for( ; iter != NULL; iter = iter->next) {
481                 print_initializer(iter);
482                 if(iter->next != NULL) {
483                         fputs(", ", out);
484                 }
485         }
486         fputs("}", out);
487 }
488
489 static void print_declaration(const declaration_t *declaration)
490 {
491         print_storage_class(declaration->storage_class);
492         print_type_ext(declaration->type, declaration->symbol,
493                        &declaration->context);
494         if(declaration->statement != NULL) {
495                 fputs("\n", out);
496                 print_statement(declaration->statement);
497         } else if(declaration->initializer != NULL) {
498                 fputs(" = ", out);
499                 print_initializer(declaration->initializer);
500                 fprintf(out, ";\n");
501         } else {
502                 fprintf(out, ";\n");
503         }
504 }
505
506 void print_ast(const translation_unit_t *unit)
507 {
508         declaration_t *declaration = unit->context.declarations;
509         while(declaration != NULL) {
510                 print_declaration(declaration);
511
512                 declaration = declaration->next;
513         }
514 }
515
516 void init_ast(void)
517 {
518         obstack_init(&ast_obstack);
519 }
520
521 void exit_ast(void)
522 {
523         obstack_free(&ast_obstack, NULL);
524 }
525
526 void ast_set_output(FILE *stream)
527 {
528         out = stream;
529         type_set_output(stream);
530 }
531
/*
 * Out-of-line definition of allocate_ast that forwards to _allocate_ast().
 * The parenthesised name keeps a function-like macro of the same name
 * (presumably defined in a header — confirm) from expanding here.
 */
void *(allocate_ast)(size_t size)
{
	void *const memory = _allocate_ast(size);
	return memory;
}