5fe99180fd6d6983bf44751a6828b52a82884a5f
[cparser] / ast.c
1 #include <config.h>
2
3 #include "ast_t.h"
4 #include "type_t.h"
5
6 #include <assert.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <ctype.h>
10
11 #include "adt/error.h"
12
13 struct obstack ast_obstack;
14
15 static FILE *out;
16 static int   indent;
17
18 static void print_statement(const statement_t *statement);
19 static void print_declaration(const declaration_t *declaration);
20
21 static void print_indent(void)
22 {
23         for(int i = 0; i < indent; ++i)
24                 fprintf(out, "\t");
25 }
26
27 static void print_const(const const_t *cnst)
28 {
29         if(cnst->expression.datatype == NULL)
30                 return;
31
32         if(is_type_integer(cnst->expression.datatype)) {
33                 fprintf(out, "%d", cnst->v.int_value);
34         } else if(is_type_floating(cnst->expression.datatype)) {
35                 fprintf(out, "%Lf", cnst->v.float_value);
36         }
37 }
38
39 static void print_string_literal(const string_literal_t *string_literal)
40 {
41         fputc('"', out);
42         for(const char *c = string_literal->value; *c != '\0'; ++c) {
43                 switch(*c) {
44                 case '\"':  fputs("\\\"", out); break;
45                 case '\\':  fputs("\\\\", out); break;
46                 case '\a':  fputs("\\a", out); break;
47                 case '\b':  fputs("\\b", out); break;
48                 case '\f':  fputs("\\f", out); break;
49                 case '\n':  fputs("\\n", out); break;
50                 case '\r':  fputs("\\r", out); break;
51                 case '\t':  fputs("\\t", out); break;
52                 case '\v':  fputs("\\v", out); break;
53                 case '\?':  fputs("\\?", out); break;
54                 default:
55                         if(!isprint(*c)) {
56                                 fprintf(out, "\\x%x", *c);
57                                 break;
58                         }
59                         fputc(*c, out);
60                         break;
61                 }
62         }
63         fputc('"', out);
64 }
65
66 static void print_call_expression(const call_expression_t *call)
67 {
68         print_expression(call->method);
69         fprintf(out, "(");
70         call_argument_t *argument = call->arguments;
71         int              first    = 1;
72         while(argument != NULL) {
73                 if(!first) {
74                         fprintf(out, ", ");
75                 } else {
76                         first = 0;
77                 }
78                 print_expression(argument->expression);
79
80                 argument = argument->next;
81         }
82         fprintf(out, ")");
83 }
84
85 static void print_binary_expression(const binary_expression_t *binexpr)
86 {
87         fprintf(out, "(");
88         print_expression(binexpr->left);
89         fprintf(out, " ");
90         switch(binexpr->type) {
91         case BINEXPR_INVALID:            fputs("INVOP", out); break;
92         case BINEXPR_COMMA:              fputs(",", out);     break;
93         case BINEXPR_ASSIGN:             fputs("=", out);     break;
94         case BINEXPR_ADD:                fputs("+", out);     break;
95         case BINEXPR_SUB:                fputs("-", out);     break;
96         case BINEXPR_MUL:                fputs("*", out);     break;
97         case BINEXPR_MOD:                fputs("%", out);     break;
98         case BINEXPR_DIV:                fputs("/", out);     break;
99         case BINEXPR_BITWISE_OR:         fputs("|", out);     break;
100         case BINEXPR_BITWISE_AND:        fputs("&", out);     break;
101         case BINEXPR_BITWISE_XOR:        fputs("^", out);     break;
102         case BINEXPR_LOGICAL_OR:         fputs("||", out);    break;
103         case BINEXPR_LOGICAL_AND:        fputs("&&", out);    break;
104         case BINEXPR_NOTEQUAL:           fputs("!=", out);    break;
105         case BINEXPR_EQUAL:              fputs("==", out);    break;
106         case BINEXPR_LESS:               fputs("<", out);     break;
107         case BINEXPR_LESSEQUAL:          fputs("<=", out);    break;
108         case BINEXPR_GREATER:            fputs(">", out);     break;
109         case BINEXPR_GREATEREQUAL:       fputs(">=", out);    break;
110         case BINEXPR_SHIFTLEFT:          fputs("<<", out);    break;
111         case BINEXPR_SHIFTRIGHT:         fputs(">>", out);    break;
112
113         case BINEXPR_ADD_ASSIGN:         fputs("+=", out);    break;
114         case BINEXPR_SUB_ASSIGN:         fputs("-=", out);    break;
115         case BINEXPR_MUL_ASSIGN:         fputs("*=", out);    break;
116         case BINEXPR_MOD_ASSIGN:         fputs("%=", out);    break;
117         case BINEXPR_DIV_ASSIGN:         fputs("/=", out);    break;
118         case BINEXPR_BITWISE_OR_ASSIGN:  fputs("|=", out);    break;
119         case BINEXPR_BITWISE_AND_ASSIGN: fputs("&=", out);    break;
120         case BINEXPR_BITWISE_XOR_ASSIGN: fputs("^=", out);    break;
121         case BINEXPR_SHIFTLEFT_ASSIGN:   fputs("<<=", out);   break;
122         case BINEXPR_SHIFTRIGHT_ASSIGN:  fputs(">>=", out);   break;
123         }
124         fprintf(out, " ");
125         print_expression(binexpr->right);
126         fprintf(out, ")");
127 }
128
129 static void print_unary_expression(const unary_expression_t *unexpr)
130 {
131         switch(unexpr->type) {
132         case UNEXPR_NEGATE:           fputs("-", out);  break;
133         case UNEXPR_PLUS:             fputs("+", out);  break;
134         case UNEXPR_NOT:              fputs("!", out);  break;
135         case UNEXPR_BITWISE_NEGATE:   fputs("~", out);  break;
136         case UNEXPR_PREFIX_INCREMENT: fputs("++", out); break;
137         case UNEXPR_PREFIX_DECREMENT: fputs("--", out); break;
138         case UNEXPR_DEREFERENCE:      fputs("*", out);  break;
139         case UNEXPR_TAKE_ADDRESS:     fputs("&", out);  break;
140
141         case UNEXPR_POSTFIX_INCREMENT:
142                 fputs("(", out);
143                 print_expression(unexpr->value);
144                 fputs(")", out);
145                 fputs("++", out);
146                 return;
147         case UNEXPR_POSTFIX_DECREMENT:
148                 fputs("(", out);
149                 print_expression(unexpr->value);
150                 fputs(")", out);
151                 fputs("--", out);
152                 return;
153         case UNEXPR_CAST:
154                 fputs("(", out);
155                 print_type(unexpr->expression.datatype);
156                 fputs(")", out);
157                 break;
158         case UNEXPR_INVALID:
159                 fprintf(out, "unop%d", unexpr->type);
160                 break;
161         }
162         fputs("(", out);
163         print_expression(unexpr->value);
164         fputs(")", out);
165 }
166
167 static void print_reference_expression(const reference_expression_t *ref)
168 {
169         fprintf(out, "%s", ref->declaration->symbol->string);
170 }
171
172 static void print_array_expression(const array_access_expression_t *expression)
173 {
174         fputs("(", out);
175         print_expression(expression->array_ref);
176         fputs(")[", out);
177         print_expression(expression->index);
178         fputs("]", out);
179 }
180
181 static void print_sizeof_expression(const sizeof_expression_t *expression)
182 {
183         fputs("sizeof", out);
184         if(expression->size_expression != NULL) {
185                 fputc('(', out);
186                 print_expression(expression->size_expression);
187                 fputc(')', out);
188         } else {
189                 fputc('(', out);
190                 print_type(expression->type);
191                 fputc(')', out);
192         }
193 }
194
195 static void print_builtin_symbol(const builtin_symbol_expression_t *expression)
196 {
197         fputs(expression->symbol->string, out);
198 }
199
200 static void print_conditional(const conditional_expression_t *expression)
201 {
202         fputs("(", out);
203         print_expression(expression->condition);
204         fputs(" ? ", out);
205         print_expression(expression->true_expression);
206         fputs(" : ", out);
207         print_expression(expression->false_expression);
208         fputs(")", out);
209 }
210
211 void print_expression(const expression_t *expression)
212 {
213         switch(expression->type) {
214         case EXPR_INVALID:
215                 fprintf(out, "*invalid expression*");
216                 break;
217         case EXPR_CONST:
218                 print_const((const const_t*) expression);
219                 break;
220         case EXPR_FUNCTION:
221         case EXPR_PRETTY_FUNCTION:
222         case EXPR_STRING_LITERAL:
223                 print_string_literal((const string_literal_t*) expression);
224                 break;
225         case EXPR_CALL:
226                 print_call_expression((const call_expression_t*) expression);
227                 break;
228         case EXPR_BINARY:
229                 print_binary_expression((const binary_expression_t*) expression);
230                 break;
231         case EXPR_REFERENCE:
232                 print_reference_expression((const reference_expression_t*) expression);
233                 break;
234         case EXPR_ARRAY_ACCESS:
235                 print_array_expression((const array_access_expression_t*) expression);
236                 break;
237         case EXPR_UNARY:
238                 print_unary_expression((const unary_expression_t*) expression);
239                 break;
240         case EXPR_SIZEOF:
241                 print_sizeof_expression((const sizeof_expression_t*) expression);
242                 break;
243         case EXPR_BUILTIN_SYMBOL:
244                 print_builtin_symbol((const builtin_symbol_expression_t*) expression);
245                 break;
246         case EXPR_CONDITIONAL:
247                 print_conditional((const conditional_expression_t*) expression);
248                 break;
249
250         case EXPR_OFFSETOF:
251         case EXPR_STATEMENT:
252         case EXPR_SELECT:
253                 /* TODO */
254                 fprintf(out, "some expression of type %d", expression->type);
255                 break;
256         }
257 }
258
259 static void print_compound_statement(const compound_statement_t *block)
260 {
261         fputs("{\n", out);
262         indent++;
263
264         statement_t *statement = block->statements;
265         while(statement != NULL) {
266                 print_indent();
267                 print_statement(statement);
268
269                 statement = statement->next;
270         }
271         indent--;
272         print_indent();
273         fputs("}\n", out);
274 }
275
276 static void print_return_statement(const return_statement_t *statement)
277 {
278         fprintf(out, "return ");
279         if(statement->return_value != NULL)
280                 print_expression(statement->return_value);
281         fputs(";\n", out);
282 }
283
284 static void print_expression_statement(const expression_statement_t *statement)
285 {
286         print_expression(statement->expression);
287         fputs(";\n", out);
288 }
289
290 static void print_goto_statement(const goto_statement_t *statement)
291 {
292         fprintf(out, "goto ");
293         if(statement->label != NULL) {
294                 fprintf(out, "%s", statement->label->symbol->string);
295         } else {
296                 fprintf(out, "?%s", statement->label_symbol->string);
297         }
298         fputs(";\n", out);
299 }
300
301 static void print_label_statement(const label_statement_t *statement)
302 {
303         fprintf(out, "%s:\n", statement->symbol->string);
304 }
305
306 static void print_if_statement(const if_statement_t *statement)
307 {
308         fputs("if(", out);
309         print_expression(statement->condition);
310         fputs(") ", out);
311         if(statement->true_statement != NULL) {
312                 print_statement(statement->true_statement);
313         }
314
315         if(statement->false_statement != NULL) {
316                 print_indent();
317                 fputs("else ", out);
318                 print_statement(statement->false_statement);
319         }
320 }
321
322 static void print_switch_statement(const switch_statement_t *statement)
323 {
324         fputs("switch(", out);
325         print_expression(statement->expression);
326         fputs(") ", out);
327         print_statement(statement->body);
328 }
329
330 static void print_case_label(const case_label_statement_t *statement)
331 {
332         if(statement->expression == NULL) {
333                 fputs("default:\n", out);
334         } else {
335                 fputs("case ", out);
336                 print_expression(statement->expression);
337                 fputs(":\n", out);
338         }
339 }
340
341 static void print_declaration_statement(
342                 const declaration_statement_t *statement)
343 {
344         declaration_t *declaration = statement->declarations_begin;
345         for( ; declaration != statement->declarations_end->next;
346                declaration = declaration->next) {
347                 print_declaration(declaration);
348         }
349 }
350
351 static void print_while_statement(const while_statement_t *statement)
352 {
353         fputs("while(", out);
354         print_expression(statement->condition);
355         fputs(") ", out);
356         print_statement(statement->body);
357 }
358
359 static void print_do_while_statement(const do_while_statement_t *statement)
360 {
361         fputs("do ", out);
362         print_statement(statement->body);
363         print_indent();
364         fputs("while(", out);
365         print_expression(statement->condition);
366         fputs(");\n", out);
367 }
368
369 static void print_for_statement(const for_statement_t *statement)
370 {
371         fputs("for(", out);
372         if(statement->context.declarations != NULL) {
373                 assert(statement->initialisation == NULL);
374                 print_declaration(statement->context.declarations);
375                 if(statement->context.declarations->next != NULL) {
376                         panic("multiple declarations in for statement not supported yet");
377                 }
378         } else if(statement->initialisation) {
379                 print_expression(statement->initialisation);
380         }
381         fputs("; ", out);
382         if(statement->condition != NULL) {
383                 print_expression(statement->condition);
384         }
385         fputs("; ", out);
386         if(statement->step != NULL) {
387                 print_expression(statement->step);
388         }
389         fputs(")", out);
390         print_statement(statement->body);
391 }
392
393 void print_statement(const statement_t *statement)
394 {
395         switch(statement->type) {
396         case STATEMENT_COMPOUND:
397                 print_compound_statement((const compound_statement_t*) statement);
398                 break;
399         case STATEMENT_RETURN:
400                 print_return_statement((const return_statement_t*) statement);
401                 break;
402         case STATEMENT_EXPRESSION:
403                 print_expression_statement((const expression_statement_t*) statement);
404                 break;
405         case STATEMENT_LABEL:
406                 print_label_statement((const label_statement_t*) statement);
407                 break;
408         case STATEMENT_GOTO:
409                 print_goto_statement((const goto_statement_t*) statement);
410                 break;
411         case STATEMENT_CONTINUE:
412                 fputs("continue;\n", out);
413                 break;
414         case STATEMENT_BREAK:
415                 fputs("break;\n", out);
416                 break;
417         case STATEMENT_IF:
418                 print_if_statement((const if_statement_t*) statement);
419                 break;
420         case STATEMENT_SWITCH:
421                 print_switch_statement((const switch_statement_t*) statement);
422                 break;
423         case STATEMENT_CASE_LABEL:
424                 print_case_label((const case_label_statement_t*) statement);
425                 break;
426         case STATEMENT_DECLARATION:
427                 print_declaration_statement((const declaration_statement_t*) statement);
428                 break;
429         case STATEMENT_WHILE:
430                 print_while_statement((const while_statement_t*) statement);
431                 break;
432         case STATEMENT_DO_WHILE:
433                 print_do_while_statement((const do_while_statement_t*) statement);
434                 break;
435         case STATEMENT_FOR:
436                 print_for_statement((const for_statement_t*) statement);
437                 break;
438         case STATEMENT_INVALID:
439                 fprintf(out, "*invalid statement*");
440                 break;
441         }
442 }
443
444 static void print_storage_class(storage_class_t storage_class)
445 {
446         switch(storage_class) {
447         case STORAGE_CLASS_ENUM_ENTRY:
448         case STORAGE_CLASS_NONE:
449                 break;
450         case STORAGE_CLASS_TYPEDEF:  fputs("typedef ", out); break;
451         case STORAGE_CLASS_EXTERN:   fputs("extern ", out); break;
452         case STORAGE_CLASS_STATIC:   fputs("static ", out); break;
453         case STORAGE_CLASS_AUTO:     fputs("auto ", out); break;
454         case STORAGE_CLASS_REGISTER: fputs("register ", out); break;
455         }
456 }
457
458 static void print_declaration(const declaration_t *declaration)
459 {
460         print_storage_class(declaration->storage_class);
461         print_type_ext(declaration->type, declaration->symbol,
462                        &declaration->context);
463         if(declaration->statement != NULL) {
464                 fputs("\n", out);
465                 print_statement(declaration->statement);
466         } else if(declaration->initializer != NULL) {
467                 fputs(" = ", out);
468                 print_expression(declaration->initializer);
469                 fprintf(out, ";\n");
470         } else {
471                 fprintf(out, ";\n");
472         }
473 }
474
475 void print_ast(const translation_unit_t *unit)
476 {
477         declaration_t *declaration = unit->context.declarations;
478         while(declaration != NULL) {
479                 print_declaration(declaration);
480
481                 declaration = declaration->next;
482         }
483 }
484
485 void init_ast(void)
486 {
487         obstack_init(&ast_obstack);
488 }
489
490 void exit_ast(void)
491 {
492         obstack_free(&ast_obstack, NULL);
493 }
494
495 void ast_set_output(FILE *stream)
496 {
497         out = stream;
498         type_set_output(stream);
499 }
500
/*
 * Function version of allocate_ast(); forwards to _allocate_ast().
 * The parenthesised name prevents a function-like macro called
 * allocate_ast from expanding here.  NOTE(review): presumably a header
 * defines such a macro or inline wrapper; confirm.
 */
void* (allocate_ast) (size_t size)
{
	void *memory = _allocate_ast(size);
	return memory;
}