8f497ee8833edfbf30bed14fcdd64f3cfbbd8a30
[cparser] / ast.c
1 #include <config.h>
2
3 #include "ast_t.h"
4 #include "type_t.h"
5
6 #include <assert.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9
10 #include "adt/error.h"
11
12 struct obstack ast_obstack;
13
14 static FILE *out;
15 static int   indent;
16
17 static void print_expression(const expression_t *expression);
18 static void print_statement(const statement_t *statement);
19
20 static
21 void print_const(const const_t *cnst)
22 {
23         fprintf(out, "%d", cnst->value);
24 }
25
26 static
27 void print_string_literal(const string_literal_t *string_literal)
28 {
29         /* TODO escape " and non-printable chars */
30         fprintf(out, "\"%s\"", string_literal->value);
31 }
32
33 static
34 void print_call_expression(const call_expression_t *call)
35 {
36         print_expression(call->method);
37         fprintf(out, "(");
38         call_argument_t *argument = call->arguments;
39         int              first    = 1;
40         while(argument != NULL) {
41                 if(!first) {
42                         fprintf(out, ", ");
43                 } else {
44                         first = 0;
45                 }
46                 print_expression(argument->expression);
47
48                 argument = argument->next;
49         }
50         fprintf(out, ")");
51 }
52
53 static
54 void print_binary_expression(const binary_expression_t *binexpr)
55 {
56         fprintf(out, "(");
57         print_expression(binexpr->left);
58         fprintf(out, " ");
59         switch(binexpr->type) {
60         case BINEXPR_INVALID:
61                 fprintf(out, "INVOP");
62                 break;
63         case BINEXPR_ASSIGN:
64                 fprintf(out, "<-");
65                 break;
66         case BINEXPR_ADD:
67                 fprintf(out, "+");
68                 break;
69         case BINEXPR_SUB:
70                 fprintf(out, "-");
71                 break;
72         case BINEXPR_NOTEQUAL:
73                 fprintf(out, "/=");
74                 break;
75         case BINEXPR_EQUAL:
76                 fprintf(out, "=");
77                 break;
78         case BINEXPR_LESS:
79                 fprintf(out, "<");
80                 break;
81         case BINEXPR_LESSEQUAL:
82                 fprintf(out, "<=");
83                 break;
84         case BINEXPR_GREATER:
85                 fprintf(out, ">");
86                 break;
87         case BINEXPR_GREATEREQUAL:
88                 fprintf(out, ">=");
89                 break;
90         default:
91                 /* TODO: add missing ops */
92                 fprintf(out, "op%d", binexpr->type);
93                 break;
94         }
95         fprintf(out, " ");
96         print_expression(binexpr->right);
97         fprintf(out, ")");
98 }
99
100 void print_expression(const expression_t *expression)
101 {
102         switch(expression->type) {
103         case EXPR_INVALID:
104                 fprintf(out, "*invalid expression*");
105                 break;
106         case EXPR_CONST:
107                 print_const((const const_t*) expression);
108                 break;
109         case EXPR_STRING_LITERAL:
110                 print_string_literal((const string_literal_t*) expression);
111                 break;
112         case EXPR_CALL:
113                 print_call_expression((const call_expression_t*) expression);
114                 break;
115         case EXPR_BINARY:
116                 print_binary_expression((const binary_expression_t*) expression);
117                 break;
118         case EXPR_REFERENCE:
119         case EXPR_UNARY:
120         case EXPR_SELECT:
121         case EXPR_ARRAY_ACCESS:
122         case EXPR_SIZEOF:
123                 /* TODO */
124                 fprintf(out, "some expression of type %d", expression->type);
125                 break;
126         }
127 }
128
129 static
130 void print_compound_statement(const compound_statement_t *block)
131 {
132         fputs("{\n", out);
133         indent++;
134
135         statement_t *statement = block->statements;
136         while(statement != NULL) {
137                 print_statement(statement);
138
139                 statement = statement->next;
140         }
141         indent--;
142         fputs("}\n", out);
143 }
144
145 static
146 void print_return_statement(const return_statement_t *statement)
147 {
148         fprintf(out, "return ");
149         if(statement->return_value != NULL)
150                 print_expression(statement->return_value);
151 }
152
153 static
154 void print_expression_statement(const expression_statement_t *statement)
155 {
156         print_expression(statement->expression);
157 }
158
159 static
160 void print_goto_statement(const goto_statement_t *statement)
161 {
162         fprintf(out, "goto ");
163         if(statement->label != NULL) {
164                 fprintf(out, "%s", statement->label->symbol->string);
165         } else {
166                 fprintf(out, "?%s", statement->label_symbol->string);
167         }
168 }
169
170 static
171 void print_label_statement(const label_statement_t *statement)
172 {
173         fprintf(out, ":%s", statement->symbol->string);
174 }
175
176 static
177 void print_if_statement(const if_statement_t *statement)
178 {
179         fprintf(out, "if ");
180         print_expression(statement->condition);
181         fprintf(out, ":\n");
182         if(statement->true_statement != NULL) {
183                 print_statement(statement->true_statement);
184         }
185
186         if(statement->false_statement != NULL) {
187                 fprintf(out, "else:\n");
188                 print_statement(statement->false_statement);
189         }
190 }
191
192 static
193 void print_declaration_statement(const declaration_statement_t *statement)
194 {
195         (void) statement;
196         fprintf(out, "*declaration statement*");
197 }
198
199 void print_statement(const statement_t *statement)
200 {
201         for(int i = 0; i < indent; ++i)
202                 fprintf(out, "\t");
203
204         switch(statement->type) {
205         case STATEMENT_COMPOUND:
206                 print_compound_statement((const compound_statement_t*) statement);
207                 break;
208         case STATEMENT_RETURN:
209                 print_return_statement((const return_statement_t*) statement);
210                 break;
211         case STATEMENT_EXPRESSION:
212                 print_expression_statement((const expression_statement_t*) statement);
213                 break;
214         case STATEMENT_LABEL:
215                 print_label_statement((const label_statement_t*) statement);
216                 break;
217         case STATEMENT_GOTO:
218                 print_goto_statement((const goto_statement_t*) statement);
219                 break;
220         case STATEMENT_IF:
221                 print_if_statement((const if_statement_t*) statement);
222                 break;
223         case STATEMENT_DECLARATION:
224                 print_declaration_statement((const declaration_statement_t*) statement);
225                 break;
226         case STATEMENT_INVALID:
227         default:
228                 fprintf(out, "*invalid statement*");
229                 break;
230
231         }
232         fprintf(out, "\n");
233 }
234
235 #if 0
236 static
237 void print_method_parameters(const method_parameter_t *parameters,
238                              const method_type_t *method_type)
239 {
240         fprintf(out, "(");
241
242         int                            first          = 1;
243         const method_parameter_t      *parameter      = parameters;
244         const method_parameter_type_t *parameter_type
245                 = method_type->parameter_types;
246         while(parameter != NULL && parameter_type != NULL) {
247                 if(!first) {
248                         fprintf(out, ", ");
249                 } else {
250                         first = 0;
251                 }
252
253                 print_type(parameter_type->type);
254                 fprintf(out, " %s", parameter->symbol->string);
255
256                 parameter      = parameter->next;
257                 parameter_type = parameter_type->next;
258         }
259         assert(parameter == NULL && parameter_type == NULL);
260
261         fprintf(out, ")");
262 }
263 #endif
264
265 static
266 void print_declaration(const declaration_t *declaration)
267 {
268         print_type(declaration->type, declaration->symbol);
269         fprintf(out, "\n");
270         if(declaration->statement != NULL) {
271                 print_statement(declaration->statement);
272         }
273 }
274
275 void print_ast(const translation_unit_t *unit)
276 {
277         declaration_t *declaration = unit->context.declarations;
278         while(declaration != NULL) {
279                 print_declaration(declaration);
280
281                 declaration = declaration->next;
282         }
283 }
284
285 void init_ast(void)
286 {
287         obstack_init(&ast_obstack);
288 }
289
290 void exit_ast(void)
291 {
292         obstack_free(&ast_obstack, NULL);
293 }
294
295 void ast_set_output(FILE *stream)
296 {
297         out = stream;
298         type_set_output(stream);
299 }
300
301 void* (allocate_ast) (size_t size)
302 {
303         return _allocate_ast(size);
304 }