filter trigraphs in advance and simplify lexer code because of that
[cparser] / ast.c
1 #include <config.h>
2
3 #include "ast_t.h"
4 #include "type_t.h"
5
6 #include <assert.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <ctype.h>
10
11 #include "adt/error.h"
12
13 struct obstack ast_obstack;
14
15 static FILE *out;
16 static int   indent;
17
18 static void print_statement(const statement_t *statement);
19 static void print_declaration(const declaration_t *declaration);
20
21 static void print_indent(void)
22 {
23         for(int i = 0; i < indent; ++i)
24                 fprintf(out, "\t");
25 }
26
27 static void print_const(const const_t *cnst)
28 {
29         fprintf(out, "%d", cnst->value);
30 }
31
32 static void print_string_literal(const string_literal_t *string_literal)
33 {
34         fputc('"', out);
35         for(const char *c = string_literal->value; *c != '\0'; ++c) {
36                 switch(*c) {
37                 case '\"':  fputs("\\\"", out); break;
38                 case '\\':  fputs("\\\\", out); break;
39                 case '\a':  fputs("\\a", out); break;
40                 case '\b':  fputs("\\b", out); break;
41                 case '\f':  fputs("\\f", out); break;
42                 case '\n':  fputs("\\n", out); break;
43                 case '\r':  fputs("\\r", out); break;
44                 case '\t':  fputs("\\t", out); break;
45                 case '\v':  fputs("\\v", out); break;
46                 case '\?':  fputs("\\?", out); break;
47                 default:
48                         if(!isprint(*c)) {
49                                 fprintf(out, "\\x%x", *c);
50                                 break;
51                         }
52                         fputc(*c, out);
53                         break;
54                 }
55         }
56         fputc('"', out);
57 }
58
59 static void print_call_expression(const call_expression_t *call)
60 {
61         print_expression(call->method);
62         fprintf(out, "(");
63         call_argument_t *argument = call->arguments;
64         int              first    = 1;
65         while(argument != NULL) {
66                 if(!first) {
67                         fprintf(out, ", ");
68                 } else {
69                         first = 0;
70                 }
71                 print_expression(argument->expression);
72
73                 argument = argument->next;
74         }
75         fprintf(out, ")");
76 }
77
78 static void print_binary_expression(const binary_expression_t *binexpr)
79 {
80         fprintf(out, "(");
81         print_expression(binexpr->left);
82         fprintf(out, " ");
83         switch(binexpr->type) {
84         case BINEXPR_INVALID:            fputs("INVOP", out); break;
85         case BINEXPR_COMMA:              fputs(",", out);     break;
86         case BINEXPR_ASSIGN:             fputs("=", out);     break;
87         case BINEXPR_ADD:                fputs("+", out);     break;
88         case BINEXPR_SUB:                fputs("-", out);     break;
89         case BINEXPR_MUL:                fputs("*", out);     break;
90         case BINEXPR_MOD:                fputs("%", out);     break;
91         case BINEXPR_DIV:                fputs("/", out);     break;
92         case BINEXPR_BITWISE_OR:         fputs("|", out);     break;
93         case BINEXPR_BITWISE_AND:        fputs("&", out);     break;
94         case BINEXPR_BITWISE_XOR:        fputs("^", out);     break;
95         case BINEXPR_LOGICAL_OR:         fputs("||", out);    break;
96         case BINEXPR_LOGICAL_AND:        fputs("&&", out);    break;
97         case BINEXPR_NOTEQUAL:           fputs("!=", out);    break;
98         case BINEXPR_EQUAL:              fputs("==", out);    break;
99         case BINEXPR_LESS:               fputs("<", out);     break;
100         case BINEXPR_LESSEQUAL:          fputs("<=", out);    break;
101         case BINEXPR_GREATER:            fputs(">", out);     break;
102         case BINEXPR_GREATEREQUAL:       fputs(">=", out);    break;
103         case BINEXPR_SHIFTLEFT:          fputs("<<", out);    break;
104         case BINEXPR_SHIFTRIGHT:         fputs(">>", out);    break;
105
106         case BINEXPR_ADD_ASSIGN:         fputs("+=", out);    break;
107         case BINEXPR_SUB_ASSIGN:         fputs("-=", out);    break;
108         case BINEXPR_MUL_ASSIGN:         fputs("*=", out);    break;
109         case BINEXPR_MOD_ASSIGN:         fputs("%=", out);    break;
110         case BINEXPR_DIV_ASSIGN:         fputs("/=", out);    break;
111         case BINEXPR_BITWISE_OR_ASSIGN:  fputs("|=", out);    break;
112         case BINEXPR_BITWISE_AND_ASSIGN: fputs("&=", out);    break;
113         case BINEXPR_BITWISE_XOR_ASSIGN: fputs("^=", out);    break;
114         case BINEXPR_SHIFTLEFT_ASSIGN:   fputs("<<=", out);   break;
115         case BINEXPR_SHIFTRIGHT_ASSIGN:  fputs(">>=", out);   break;
116         }
117         fprintf(out, " ");
118         print_expression(binexpr->right);
119         fprintf(out, ")");
120 }
121
122 static void print_unary_expression(const unary_expression_t *unexpr)
123 {
124         switch(unexpr->type) {
125         case UNEXPR_NEGATE:           fputs("-", out);  break;
126         case UNEXPR_PLUS:             fputs("+", out);  break;
127         case UNEXPR_NOT:              fputs("!", out);  break;
128         case UNEXPR_BITWISE_NEGATE:   fputs("~", out);  break;
129         case UNEXPR_PREFIX_INCREMENT: fputs("++", out); break;
130         case UNEXPR_PREFIX_DECREMENT: fputs("--", out); break;
131         case UNEXPR_DEREFERENCE:      fputs("*", out);  break;
132         case UNEXPR_TAKE_ADDRESS:     fputs("&", out);  break;
133
134         case UNEXPR_POSTFIX_INCREMENT:
135                 fputs("(", out);
136                 print_expression(unexpr->value);
137                 fputs(")", out);
138                 fputs("++", out);
139                 return;
140         case UNEXPR_POSTFIX_DECREMENT:
141                 fputs("(", out);
142                 print_expression(unexpr->value);
143                 fputs(")", out);
144                 fputs("--", out);
145                 return;
146         case UNEXPR_CAST:
147                 fputs("(", out);
148                 print_type(unexpr->expression.datatype);
149                 fputs(")", out);
150                 break;
151         case UNEXPR_INVALID:
152                 fprintf(out, "unop%d", unexpr->type);
153                 break;
154         }
155         fputs("(", out);
156         print_expression(unexpr->value);
157         fputs(")", out);
158 }
159
160 static void print_reference_expression(const reference_expression_t *ref)
161 {
162         fprintf(out, "%s", ref->declaration->symbol->string);
163 }
164
165 static void print_array_expression(const array_access_expression_t *expression)
166 {
167         fputs("(", out);
168         print_expression(expression->array_ref);
169         fputs(")[", out);
170         print_expression(expression->index);
171         fputs("]", out);
172 }
173
174 static void print_sizeof_expression(const sizeof_expression_t *expression)
175 {
176         fputs("sizeof", out);
177         if(expression->size_expression != NULL) {
178                 fputc('(', out);
179                 print_expression(expression->size_expression);
180                 fputc(')', out);
181         } else {
182                 fputc('(', out);
183                 print_type(expression->type);
184                 fputc(')', out);
185         }
186 }
187
188 static void print_builtin_symbol(const builtin_symbol_expression_t *expression)
189 {
190         fputs(expression->symbol->string, out);
191 }
192
193 void print_expression(const expression_t *expression)
194 {
195         switch(expression->type) {
196         case EXPR_INVALID:
197                 fprintf(out, "*invalid expression*");
198                 break;
199         case EXPR_CONST:
200                 print_const((const const_t*) expression);
201                 break;
202         case EXPR_FUNCTION:
203         case EXPR_PRETTY_FUNCTION:
204         case EXPR_STRING_LITERAL:
205                 print_string_literal((const string_literal_t*) expression);
206                 break;
207         case EXPR_CALL:
208                 print_call_expression((const call_expression_t*) expression);
209                 break;
210         case EXPR_BINARY:
211                 print_binary_expression((const binary_expression_t*) expression);
212                 break;
213         case EXPR_REFERENCE:
214                 print_reference_expression((const reference_expression_t*) expression);
215                 break;
216         case EXPR_ARRAY_ACCESS:
217                 print_array_expression((const array_access_expression_t*) expression);
218                 break;
219         case EXPR_UNARY:
220                 print_unary_expression((const unary_expression_t*) expression);
221                 break;
222         case EXPR_SIZEOF:
223                 print_sizeof_expression((const sizeof_expression_t*) expression);
224                 break;
225         case EXPR_BUILTIN_SYMBOL:
226                 print_builtin_symbol((const builtin_symbol_expression_t*) expression);
227                 break;
228
229         case EXPR_CONDITIONAL:
230         case EXPR_OFFSETOF:
231         case EXPR_STATEMENT:
232         case EXPR_SELECT:
233                 /* TODO */
234                 fprintf(out, "some expression of type %d", expression->type);
235                 break;
236         }
237 }
238
239 static void print_compound_statement(const compound_statement_t *block)
240 {
241         fputs("{\n", out);
242         indent++;
243
244         statement_t *statement = block->statements;
245         while(statement != NULL) {
246                 print_indent();
247                 print_statement(statement);
248
249                 statement = statement->next;
250         }
251         indent--;
252         print_indent();
253         fputs("}\n", out);
254 }
255
256 static void print_return_statement(const return_statement_t *statement)
257 {
258         fprintf(out, "return ");
259         if(statement->return_value != NULL)
260                 print_expression(statement->return_value);
261         fputs(";\n", out);
262 }
263
264 static void print_expression_statement(const expression_statement_t *statement)
265 {
266         print_expression(statement->expression);
267         fputs(";\n", out);
268 }
269
270 static void print_goto_statement(const goto_statement_t *statement)
271 {
272         fprintf(out, "goto ");
273         if(statement->label != NULL) {
274                 fprintf(out, "%s", statement->label->symbol->string);
275         } else {
276                 fprintf(out, "?%s", statement->label_symbol->string);
277         }
278         fputs(";\n", out);
279 }
280
281 static void print_label_statement(const label_statement_t *statement)
282 {
283         fprintf(out, "%s:\n", statement->symbol->string);
284 }
285
286 static void print_if_statement(const if_statement_t *statement)
287 {
288         fputs("if(", out);
289         print_expression(statement->condition);
290         fputs(") ", out);
291         if(statement->true_statement != NULL) {
292                 print_statement(statement->true_statement);
293         }
294
295         if(statement->false_statement != NULL) {
296                 print_indent();
297                 fputs("else ", out);
298                 print_statement(statement->false_statement);
299         }
300 }
301
302 static void print_switch_statement(const switch_statement_t *statement)
303 {
304         fputs("switch(", out);
305         print_expression(statement->expression);
306         fputs(") ", out);
307         print_statement(statement->body);
308 }
309
310 static void print_case_label(const case_label_statement_t *statement)
311 {
312         if(statement->expression == NULL) {
313                 fputs("default:\n", out);
314         } else {
315                 fputs("case ", out);
316                 print_expression(statement->expression);
317                 fputs(":\n", out);
318         }
319 }
320
321 static void print_declaration_statement(
322                 const declaration_statement_t *statement)
323 {
324         declaration_t *declaration = statement->declarations_begin;
325         for( ; declaration != statement->declarations_end->next;
326                declaration = declaration->next) {
327                 print_declaration(declaration);
328         }
329 }
330
331 static void print_while_statement(const while_statement_t *statement)
332 {
333         fputs("while(", out);
334         print_expression(statement->condition);
335         fputs(") ", out);
336         print_statement(statement->body);
337 }
338
339 static void print_do_while_statement(const do_while_statement_t *statement)
340 {
341         fputs("do ", out);
342         print_statement(statement->body);
343         print_indent();
344         fputs("while(", out);
345         print_expression(statement->condition);
346         fputs(");\n", out);
347 }
348
349 static void print_for_statemenet(const for_statement_t *statement)
350 {
351         fprintf(out, "for(TODO) ");
352         print_statement(statement->body);
353 }
354
355 void print_statement(const statement_t *statement)
356 {
357         switch(statement->type) {
358         case STATEMENT_COMPOUND:
359                 print_compound_statement((const compound_statement_t*) statement);
360                 break;
361         case STATEMENT_RETURN:
362                 print_return_statement((const return_statement_t*) statement);
363                 break;
364         case STATEMENT_EXPRESSION:
365                 print_expression_statement((const expression_statement_t*) statement);
366                 break;
367         case STATEMENT_LABEL:
368                 print_label_statement((const label_statement_t*) statement);
369                 break;
370         case STATEMENT_GOTO:
371                 print_goto_statement((const goto_statement_t*) statement);
372                 break;
373         case STATEMENT_CONTINUE:
374                 fputs("continue;\n", out);
375                 break;
376         case STATEMENT_BREAK:
377                 fputs("break;\n", out);
378                 break;
379         case STATEMENT_IF:
380                 print_if_statement((const if_statement_t*) statement);
381                 break;
382         case STATEMENT_SWITCH:
383                 print_switch_statement((const switch_statement_t*) statement);
384                 break;
385         case STATEMENT_CASE_LABEL:
386                 print_case_label((const case_label_statement_t*) statement);
387                 break;
388         case STATEMENT_DECLARATION:
389                 print_declaration_statement((const declaration_statement_t*) statement);
390                 break;
391         case STATEMENT_WHILE:
392                 print_while_statement((const while_statement_t*) statement);
393                 break;
394         case STATEMENT_DO_WHILE:
395                 print_do_while_statement((const do_while_statement_t*) statement);
396                 break;
397         case STATEMENT_FOR:
398                 print_for_statemenet((const for_statement_t*) statement);
399                 break;
400         case STATEMENT_INVALID:
401                 fprintf(out, "*invalid statement*");
402                 break;
403         }
404 }
405
406 static void print_storage_class(storage_class_t storage_class)
407 {
408         switch(storage_class) {
409         case STORAGE_CLASS_ENUM_ENTRY:
410         case STORAGE_CLASS_NONE:
411                 break;
412         case STORAGE_CLASS_TYPEDEF:  fputs("typedef ", out); break;
413         case STORAGE_CLASS_EXTERN:   fputs("extern ", out); break;
414         case STORAGE_CLASS_STATIC:   fputs("static ", out); break;
415         case STORAGE_CLASS_AUTO:     fputs("auto ", out); break;
416         case STORAGE_CLASS_REGISTER: fputs("register ", out); break;
417         }
418 }
419
420 static void print_declaration(const declaration_t *declaration)
421 {
422         print_storage_class(declaration->storage_class);
423         print_type_ext(declaration->type, declaration->symbol,
424                        &declaration->context);
425         if(declaration->statement != NULL) {
426                 fputs("\n", out);
427                 print_statement(declaration->statement);
428         } else if(declaration->initializer != NULL) {
429                 fputs(" = ", out);
430                 print_expression(declaration->initializer);
431                 fprintf(out, ";\n");
432         } else {
433                 fprintf(out, ";\n");
434         }
435 }
436
437 void print_ast(const translation_unit_t *unit)
438 {
439         declaration_t *declaration = unit->context.declarations;
440         while(declaration != NULL) {
441                 print_declaration(declaration);
442
443                 declaration = declaration->next;
444         }
445 }
446
447 void init_ast(void)
448 {
449         obstack_init(&ast_obstack);
450 }
451
452 void exit_ast(void)
453 {
454         obstack_free(&ast_obstack, NULL);
455 }
456
457 void ast_set_output(FILE *stream)
458 {
459         out = stream;
460         type_set_output(stream);
461 }
462
/**
 * Out-of-line version of allocate_ast: returns the result of
 * _allocate_ast(size).
 * The parenthesised name suppresses expansion of a function-like macro
 * with the same name (presumably defined in a header — verify), so a real
 * function symbol is still emitted.
 */
void *(allocate_ast)(size_t size)
{
	void *memory = _allocate_ast(size);
	return memory;
}