96f4abc8182a574de105cbd54dfdb2f02c518df0
[cparser] / ast.c
1 #include <config.h>
2
3 #include "ast_t.h"
4 #include "type_t.h"
5
6 #include <assert.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <ctype.h>
10
11 #include "adt/error.h"
12
13 struct obstack ast_obstack;
14
15 static FILE *out;
16 static int   indent;
17
18 static void print_statement(const statement_t *statement);
19 static void print_declaration(const declaration_t *declaration);
20
21 static void print_indent(void)
22 {
23         for(int i = 0; i < indent; ++i)
24                 fprintf(out, "\t");
25 }
26
27 static void print_const(const const_t *cnst)
28 {
29         fprintf(out, "%d", cnst->value);
30 }
31
32 static void print_string_literal(const string_literal_t *string_literal)
33 {
34         fputc('"', out);
35         for(const char *c = string_literal->value; *c != '\0'; ++c) {
36                 switch(*c) {
37                 case '\"':  fputs("\\\"", out); break;
38                 case '\\':  fputs("\\\\", out); break;
39                 case '\a':  fputs("\\a", out); break;
40                 case '\b':  fputs("\\b", out); break;
41                 case '\f':  fputs("\\f", out); break;
42                 case '\n':  fputs("\\n", out); break;
43                 case '\r':  fputs("\\r", out); break;
44                 case '\t':  fputs("\\t", out); break;
45                 case '\v':  fputs("\\v", out); break;
46                 case '\?':  fputs("\\?", out); break;
47                 default:
48                         if(!isprint(*c)) {
49                                 fprintf(out, "\\x%x", *c);
50                                 break;
51                         }
52                         fputc(*c, out);
53                         break;
54                 }
55         }
56         fputc('"', out);
57 }
58
59 static void print_call_expression(const call_expression_t *call)
60 {
61         print_expression(call->method);
62         fprintf(out, "(");
63         call_argument_t *argument = call->arguments;
64         int              first    = 1;
65         while(argument != NULL) {
66                 if(!first) {
67                         fprintf(out, ", ");
68                 } else {
69                         first = 0;
70                 }
71                 print_expression(argument->expression);
72
73                 argument = argument->next;
74         }
75         fprintf(out, ")");
76 }
77
78 static void print_binary_expression(const binary_expression_t *binexpr)
79 {
80         fprintf(out, "(");
81         print_expression(binexpr->left);
82         fprintf(out, " ");
83         switch(binexpr->type) {
84         case BINEXPR_INVALID:            fputs("INVOP", out); break;
85         case BINEXPR_COMMA:              fputs(",", out);     break;
86         case BINEXPR_ASSIGN:             fputs("=", out);     break;
87         case BINEXPR_ADD:                fputs("+", out);     break;
88         case BINEXPR_SUB:                fputs("-", out);     break;
89         case BINEXPR_MUL:                fputs("*", out);     break;
90         case BINEXPR_MOD:                fputs("%", out);     break;
91         case BINEXPR_DIV:                fputs("/", out);     break;
92         case BINEXPR_BITWISE_OR:         fputs("|", out);     break;
93         case BINEXPR_BITWISE_AND:        fputs("&", out);     break;
94         case BINEXPR_BITWISE_XOR:        fputs("^", out);     break;
95         case BINEXPR_LOGICAL_OR:         fputs("||", out);    break;
96         case BINEXPR_LOGICAL_AND:        fputs("&&", out);    break;
97         case BINEXPR_NOTEQUAL:           fputs("!=", out);    break;
98         case BINEXPR_EQUAL:              fputs("==", out);    break;
99         case BINEXPR_LESS:               fputs("<", out);     break;
100         case BINEXPR_LESSEQUAL:          fputs("<=", out);    break;
101         case BINEXPR_GREATER:            fputs(">", out);     break;
102         case BINEXPR_GREATEREQUAL:       fputs(">=", out);    break;
103         case BINEXPR_SHIFTLEFT:          fputs("<<", out);    break;
104         case BINEXPR_SHIFTRIGHT:         fputs(">>", out);    break;
105
106         case BINEXPR_ADD_ASSIGN:         fputs("+=", out);    break;
107         case BINEXPR_SUB_ASSIGN:         fputs("-=", out);    break;
108         case BINEXPR_MUL_ASSIGN:         fputs("*=", out);    break;
109         case BINEXPR_MOD_ASSIGN:         fputs("%=", out);    break;
110         case BINEXPR_DIV_ASSIGN:         fputs("/=", out);    break;
111         case BINEXPR_BITWISE_OR_ASSIGN:  fputs("|=", out);    break;
112         case BINEXPR_BITWISE_AND_ASSIGN: fputs("&=", out);    break;
113         case BINEXPR_BITWISE_XOR_ASSIGN: fputs("^=", out);    break;
114         case BINEXPR_SHIFTLEFT_ASSIGN:   fputs("<<=", out);   break;
115         case BINEXPR_SHIFTRIGHT_ASSIGN:  fputs(">>=", out);   break;
116         }
117         fprintf(out, " ");
118         print_expression(binexpr->right);
119         fprintf(out, ")");
120 }
121
122 static void print_unary_expression(const unary_expression_t *unexpr)
123 {
124         switch(unexpr->type) {
125         case UNEXPR_NEGATE:           fputs("-", out);  break;
126         case UNEXPR_PLUS:             fputs("+", out);  break;
127         case UNEXPR_NOT:              fputs("!", out);  break;
128         case UNEXPR_BITWISE_NEGATE:   fputs("~", out);  break;
129         case UNEXPR_PREFIX_INCREMENT: fputs("++", out); break;
130         case UNEXPR_PREFIX_DECREMENT: fputs("--", out); break;
131         case UNEXPR_DEREFERENCE:      fputs("*", out);  break;
132         case UNEXPR_TAKE_ADDRESS:     fputs("&", out);  break;
133
134         case UNEXPR_POSTFIX_INCREMENT:
135                 fputs("(", out);
136                 print_expression(unexpr->value);
137                 fputs(")", out);
138                 fputs("++", out);
139                 return;
140         case UNEXPR_POSTFIX_DECREMENT:
141                 fputs("(", out);
142                 print_expression(unexpr->value);
143                 fputs(")", out);
144                 fputs("--", out);
145                 return;
146         case UNEXPR_CAST:
147                 fputs("(", out);
148                 print_type(unexpr->expression.datatype);
149                 fputs(")", out);
150                 break;
151         case UNEXPR_INVALID:
152                 fprintf(out, "unop%d", unexpr->type);
153                 break;
154         }
155         fputs("(", out);
156         print_expression(unexpr->value);
157         fputs(")", out);
158 }
159
160 static void print_reference_expression(const reference_expression_t *ref)
161 {
162         fprintf(out, "%s", ref->declaration->symbol->string);
163 }
164
165 static void print_array_expression(const array_access_expression_t *expression)
166 {
167         fputs("(", out);
168         print_expression(expression->array_ref);
169         fputs(")[", out);
170         print_expression(expression->index);
171         fputs("]", out);
172 }
173
174 static void print_sizeof_expression(const sizeof_expression_t *expression)
175 {
176         fputs("sizeof", out);
177         if(expression->size_expression != NULL) {
178                 fputc('(', out);
179                 print_expression(expression->size_expression);
180                 fputc(')', out);
181         } else {
182                 fputc('(', out);
183                 print_type(expression->type);
184                 fputc(')', out);
185         }
186 }
187
188 static void print_builtin_symbol(const builtin_symbol_expression_t *expression)
189 {
190         fputs(expression->symbol->string, out);
191 }
192
193 void print_expression(const expression_t *expression)
194 {
195         switch(expression->type) {
196         case EXPR_INVALID:
197                 fprintf(out, "*invalid expression*");
198                 break;
199         case EXPR_CONST:
200                 print_const((const const_t*) expression);
201                 break;
202         case EXPR_FUNCTION:
203         case EXPR_PRETTY_FUNCTION:
204         case EXPR_STRING_LITERAL:
205                 print_string_literal((const string_literal_t*) expression);
206                 break;
207         case EXPR_CALL:
208                 print_call_expression((const call_expression_t*) expression);
209                 break;
210         case EXPR_BINARY:
211                 print_binary_expression((const binary_expression_t*) expression);
212                 break;
213         case EXPR_REFERENCE:
214                 print_reference_expression((const reference_expression_t*) expression);
215                 break;
216         case EXPR_ARRAY_ACCESS:
217                 print_array_expression((const array_access_expression_t*) expression);
218                 break;
219         case EXPR_UNARY:
220                 print_unary_expression((const unary_expression_t*) expression);
221                 break;
222         case EXPR_SIZEOF:
223                 print_sizeof_expression((const sizeof_expression_t*) expression);
224                 break;
225         case EXPR_BUILTIN_SYMBOL:
226                 print_builtin_symbol((const builtin_symbol_expression_t*) expression);
227                 break;
228
229         case EXPR_OFFSETOF:
230         case EXPR_STATEMENT:
231         case EXPR_SELECT:
232                 /* TODO */
233                 fprintf(out, "some expression of type %d", expression->type);
234                 break;
235         }
236 }
237
238 static void print_compound_statement(const compound_statement_t *block)
239 {
240         fputs("{\n", out);
241         indent++;
242
243         statement_t *statement = block->statements;
244         while(statement != NULL) {
245                 print_indent();
246                 print_statement(statement);
247
248                 statement = statement->next;
249         }
250         indent--;
251         print_indent();
252         fputs("}\n", out);
253 }
254
255 static void print_return_statement(const return_statement_t *statement)
256 {
257         fprintf(out, "return ");
258         if(statement->return_value != NULL)
259                 print_expression(statement->return_value);
260         fputs(";\n", out);
261 }
262
263 static void print_expression_statement(const expression_statement_t *statement)
264 {
265         print_expression(statement->expression);
266         fputs(";\n", out);
267 }
268
269 static void print_goto_statement(const goto_statement_t *statement)
270 {
271         fprintf(out, "goto ");
272         if(statement->label != NULL) {
273                 fprintf(out, "%s", statement->label->symbol->string);
274         } else {
275                 fprintf(out, "?%s", statement->label_symbol->string);
276         }
277         fputs(";\n", out);
278 }
279
280 static void print_label_statement(const label_statement_t *statement)
281 {
282         fprintf(out, "%s:\n", statement->symbol->string);
283 }
284
285 static void print_if_statement(const if_statement_t *statement)
286 {
287         fputs("if(", out);
288         print_expression(statement->condition);
289         fputs(") ", out);
290         if(statement->true_statement != NULL) {
291                 print_statement(statement->true_statement);
292         }
293
294         if(statement->false_statement != NULL) {
295                 print_indent();
296                 fputs("else ", out);
297                 print_statement(statement->false_statement);
298         }
299 }
300
301 static void print_switch_statement(const switch_statement_t *statement)
302 {
303         fputs("switch(", out);
304         print_expression(statement->expression);
305         fputs(") ", out);
306         print_statement(statement->body);
307 }
308
309 static void print_case_label(const case_label_statement_t *statement)
310 {
311         if(statement->expression == NULL) {
312                 fputs("default:\n", out);
313         } else {
314                 fputs("case ", out);
315                 print_expression(statement->expression);
316                 fputs(":\n", out);
317         }
318 }
319
320 static void print_declaration_statement(
321                 const declaration_statement_t *statement)
322 {
323         declaration_t *declaration = statement->declarations_begin;
324         for( ; declaration != statement->declarations_end->next;
325                declaration = declaration->next) {
326                 print_declaration(declaration);
327         }
328 }
329
330 static void print_while_statement(const while_statement_t *statement)
331 {
332         fputs("while(", out);
333         print_expression(statement->condition);
334         fputs(") ", out);
335         print_statement(statement->body);
336 }
337
338 static void print_do_while_statement(const do_while_statement_t *statement)
339 {
340         fputs("do ", out);
341         print_statement(statement->body);
342         print_indent();
343         fputs("while(", out);
344         print_expression(statement->condition);
345         fputs(");\n", out);
346 }
347
348 static void print_for_statemenet(const for_statement_t *statement)
349 {
350         fprintf(out, "for(TODO) ");
351         print_statement(statement->body);
352 }
353
354 void print_statement(const statement_t *statement)
355 {
356         switch(statement->type) {
357         case STATEMENT_COMPOUND:
358                 print_compound_statement((const compound_statement_t*) statement);
359                 break;
360         case STATEMENT_RETURN:
361                 print_return_statement((const return_statement_t*) statement);
362                 break;
363         case STATEMENT_EXPRESSION:
364                 print_expression_statement((const expression_statement_t*) statement);
365                 break;
366         case STATEMENT_LABEL:
367                 print_label_statement((const label_statement_t*) statement);
368                 break;
369         case STATEMENT_GOTO:
370                 print_goto_statement((const goto_statement_t*) statement);
371                 break;
372         case STATEMENT_CONTINUE:
373                 fputs("continue;\n", out);
374                 break;
375         case STATEMENT_BREAK:
376                 fputs("break;\n", out);
377                 break;
378         case STATEMENT_IF:
379                 print_if_statement((const if_statement_t*) statement);
380                 break;
381         case STATEMENT_SWITCH:
382                 print_switch_statement((const switch_statement_t*) statement);
383                 break;
384         case STATEMENT_CASE_LABEL:
385                 print_case_label((const case_label_statement_t*) statement);
386                 break;
387         case STATEMENT_DECLARATION:
388                 print_declaration_statement((const declaration_statement_t*) statement);
389                 break;
390         case STATEMENT_WHILE:
391                 print_while_statement((const while_statement_t*) statement);
392                 break;
393         case STATEMENT_DO_WHILE:
394                 print_do_while_statement((const do_while_statement_t*) statement);
395                 break;
396         case STATEMENT_FOR:
397                 print_for_statemenet((const for_statement_t*) statement);
398                 break;
399         case STATEMENT_INVALID:
400                 fprintf(out, "*invalid statement*");
401                 break;
402         }
403 }
404
405 static void print_storage_class(storage_class_t storage_class)
406 {
407         switch(storage_class) {
408         case STORAGE_CLASS_ENUM_ENTRY:
409         case STORAGE_CLASS_NONE:
410                 break;
411         case STORAGE_CLASS_TYPEDEF:  fputs("typedef ", out); break;
412         case STORAGE_CLASS_EXTERN:   fputs("extern ", out); break;
413         case STORAGE_CLASS_STATIC:   fputs("static ", out); break;
414         case STORAGE_CLASS_AUTO:     fputs("auto ", out); break;
415         case STORAGE_CLASS_REGISTER: fputs("register ", out); break;
416         }
417 }
418
419 static void print_declaration(const declaration_t *declaration)
420 {
421         print_storage_class(declaration->storage_class);
422         print_type_ext(declaration->type, declaration->symbol,
423                        &declaration->context);
424         if(declaration->statement != NULL) {
425                 fputs("\n", out);
426                 print_statement(declaration->statement);
427         } else if(declaration->initializer != NULL) {
428                 fputs(" = ", out);
429                 print_expression(declaration->initializer);
430                 fprintf(out, ";\n");
431         } else {
432                 fprintf(out, ";\n");
433         }
434 }
435
436 void print_ast(const translation_unit_t *unit)
437 {
438         declaration_t *declaration = unit->context.declarations;
439         while(declaration != NULL) {
440                 print_declaration(declaration);
441
442                 declaration = declaration->next;
443         }
444 }
445
446 void init_ast(void)
447 {
448         obstack_init(&ast_obstack);
449 }
450
451 void exit_ast(void)
452 {
453         obstack_free(&ast_obstack, NULL);
454 }
455
456 void ast_set_output(FILE *stream)
457 {
458         out = stream;
459         type_set_output(stream);
460 }
461
/* Out-of-line definition of allocate_ast: returns `size` bytes obtained from
 * _allocate_ast.
 *
 * The parentheses around the name stop any function-like macro called
 * allocate_ast from expanding here. Presumably such a macro is declared in a
 * project header -- NOTE(review): confirm. They must be kept. */
void* (allocate_ast) (size_t size)
{
	void *const node = _allocate_ast(size);
	return node;
}