Run benchmarks in user mode
[riscv-tests.git] / benchmarks / common / syscalls.c
1 #include <stdint.h>
2 #include <string.h>
3 #include <stdarg.h>
4 #include <stdio.h>
5 #include <limits.h>
6 #include <machine/syscall.h>
7 #include "encoding.h"
8
// Syscall number used by setStats(); handle_trap services it locally via
// handle_stats instead of forwarding it to the host.
#define SYS_stats 1234
// Statement-context compile-time assertion. When cond is false,
// !!(long)(cond) is 0, which duplicates the `case 0` label, so the
// compiler rejects the code; when true it is a harmless `case 1`.
#define static_assert(cond) switch(0) { case 0: case !!(long)(cond): ; }
11
// Forward a syscall to the host (frontend) through the tohost/fromhost CSRs.
//
// The request {which, arg0, arg1, arg2} is stored as 64-bit words in a
// 64-byte-aligned buffer on this function's stack, and the buffer's address
// is written to tohost. fromhost is then polled until it reads non-zero, and
// word 0 of the same buffer is returned as the result.
// NOTE(review): assumes the host writes its return value into magic_mem[0]
// before signalling through fromhost — confirm against the frontend.
static long handle_frontend_syscall(long which, long arg0, long arg1, long arg2)
{
  // volatile: the compiler must not elide or reorder these accesses, because
  // the host reads/writes this buffer behind the compiler's back.
  volatile uint64_t magic_mem[8] __attribute__((aligned(64)));
  magic_mem[0] = which;
  magic_mem[1] = arg0;
  magic_mem[2] = arg1;
  magic_mem[3] = arg2;
  // Full memory barrier: the request must be visible before the host is told.
  __sync_synchronize();
  write_csr(tohost, (long)magic_mem);
  // Spin until fromhost becomes non-zero.
  // NOTE(review): swap_csr presumably writes 0 and returns the previous
  // value, i.e. this also acknowledges/clears fromhost — confirm in encoding.h.
  while (swap_csr(fromhost, 0) == 0);
  return magic_mem[0];
}
24
// In setStats, we might trap reading uarch-specific counters.
// The trap handler will skip over the instruction and write 0,
// but only if v0 is the destination register.
// That is why __tmp is pinned to v0 with a register-asm variable: the
// handler recognises only the `csrr v0, ...` encoding.
#define read_csr_safe(reg) ({ register long __tmp asm("v0"); \
  asm volatile ("csrr %0, " #reg : "=r"(__tmp)); \
  __tmp; })
31
// Number of sampled counters: cycle, instret and uarch0..uarch15 (2 + 16).
#define NUM_COUNTERS 18
// After setStats(1): raw counter values. After setStats(0): the delta since
// the preceding setStats(1).
static long counters[NUM_COUNTERS];
// Counter names; assigned only when stats are disabled (enable == 0).
static char* counter_names[NUM_COUNTERS];

// Handler for the SYS_stats syscall.
// enable != 0: store the current value of every counter in counters[].
// enable == 0: overwrite each entry with (current value - stored value) and
//              record the counter's name in counter_names[].
// Reads go through read_csr_safe, so a uarch counter whose read traps is
// seen as 0 (the trap handler skips the csrr and writes 0 to v0).
// Always returns 0, which becomes the syscall's result.
static int handle_stats(int enable)
{
  int i = 0;
  // Read counter `name` into slot i and advance i. The leading while loop is
  // a hang-on-overflow guard should more counters be read than there are slots.
#define READ_CTR(name) do { \
    while (i >= NUM_COUNTERS) ; \
    long csr = read_csr_safe(name); \
    if (!enable) { csr -= counters[i]; counter_names[i] = #name; } \
    counters[i++] = csr; \
  } while (0)
  READ_CTR(cycle); READ_CTR(instret);
  READ_CTR(uarch0); READ_CTR(uarch1); READ_CTR(uarch2); READ_CTR(uarch3);
  READ_CTR(uarch4); READ_CTR(uarch5); READ_CTR(uarch6); READ_CTR(uarch7);
  READ_CTR(uarch8); READ_CTR(uarch9); READ_CTR(uarch10); READ_CTR(uarch11);
  READ_CTR(uarch12); READ_CTR(uarch13); READ_CTR(uarch14); READ_CTR(uarch15);
#undef READ_CTR
  return 0;
}
52
53 static void tohost_exit(int code)
54 {
55 write_csr(tohost, (code << 1) | 1);
56 while (1);
57 }
58
// C-level trap handler.
// NOTE(review): presumably invoked from the assembly trap vector with the
// trap cause, the trapping PC and the saved integer registers (regs[n] = xn),
// which are presumably restored from regs[] on return — confirm against the
// entry code.
// Returns epc + 4, i.e. execution resumes after the trapping instruction.
//
// Dispatch:
//  - illegal-instruction trap on a `csrr v0, uarchN` (see read_csr_safe):
//    the instruction is skipped and regs[16] (presumably v0) is set to 0;
//  - any other non-syscall trap: exit with code 1337;
//  - SYS_exit: exit with regs[18] as the code;
//  - SYS_stats: handle_stats(regs[18]);
//  - any other syscall: forwarded to the host with regs[18..20] as arguments.
// The syscall number is taken from regs[16] and the result stored back there.
long handle_trap(long cause, long epc, long regs[32])
{
  // Load the 32-bit encoding of `csrr v0, uarch0` straight from the code
  // stream: `j 2f` jumps over the instruction, so it is read as data but
  // never executed.
  int csr_insn;
  asm volatile ("lw %0, 1f; j 2f; 1: csrr v0, uarch0; 2:" : "=r"(csr_insn));
  long sys_ret = 0;

  // The trapping instruction matches when it has every bit of csr_insn set.
  // NOTE(review): this presumably also accepts uarch1..uarch15, whose CSR
  // numbers would be bit-supersets of uarch0's — confirm the CSR numbering.
  if (cause == CAUSE_ILLEGAL_INSTRUCTION &&
      (*(int*)epc & csr_insn) == csr_insn)
    ;
  else if (cause != CAUSE_SYSCALL)
    tohost_exit(1337);
  else if (regs[16] == SYS_exit)
    tohost_exit(regs[18]);
  else if (regs[16] == SYS_stats)
    sys_ret = handle_stats(regs[18]);
  else
    sys_ret = handle_frontend_syscall(regs[16], regs[18], regs[19], regs[20]);

  regs[16] = sys_ret;
  return epc+4;
}
80
81 static long syscall(long num, long arg0, long arg1, long arg2)
82 {
83 register long v0 asm("v0") = num;
84 register long a0 asm("a0") = arg0;
85 register long a1 asm("a1") = arg1;
86 register long a2 asm("a2") = arg2;
87 asm volatile ("scall" : "+r"(v0) : "r"(a0), "r"(a1), "r"(a2) : "s0");
88 return v0;
89 }
90
91 void exit(int code)
92 {
93 syscall(SYS_exit, code, 0, 0);
94 }
95
96 void setStats(int enable)
97 {
98 syscall(SYS_stats, enable, 0, 0);
99 }
100
101 void printstr(const char* s)
102 {
103 syscall(SYS_write, 1, (long)s, strlen(s));
104 }
105
// Default per-core hook called by _init before main().
// multi-threaded programs override this (weak) function.
// For the case of single-threaded programs, only core 0 proceeds: it returns
// and goes on to run main(), while every other core parks here forever.
//
// The parking loop uses a constant controlling expression on purpose. A loop
// like `while (cid != 0);` has a non-constant condition and no side effects,
// so C11 6.8.5p6 lets the compiler assume it terminates and delete it. That
// would let secondary cores run main().
void __attribute__((weak)) thread_entry(int cid, int nc)
{
  (void)nc;  // core count is unused in the single-threaded default
  if (cid != 0)
    while (1);
}
112
// Weak fallback main(). Single-threaded benchmarks define their own main(),
// which takes precedence. If this one runs, no benchmark was provided: print
// a reminder and hand -1 back to _init as the exit status.
int __attribute__((weak)) main(int argc, char** argv)
{
  const char* reminder = "Implement main(), foo!\n";
  printstr(reminder);
  return -1;
}
119
120 void _init(int cid, int nc)
121 {
122 thread_entry(cid, nc);
123
124 // only single-threaded programs should ever get here.
125 int ret = main(0, 0);
126
127 char buf[NUM_COUNTERS * 32] __attribute__((aligned(64)));
128 char* pbuf = buf;
129 for (int i = 0; i < NUM_COUNTERS; i++)
130 if (counters[i])
131 pbuf += sprintf(pbuf, "%s = %d\n", counter_names[i], counters[i]);
132 if (pbuf != buf)
133 printstr(buf);
134
135 exit(ret);
136 }
137
138 #undef putchar
139 int putchar(int ch)
140 {
141 static char buf[64] __attribute__((aligned(64)));
142 static int buflen = 0;
143
144 buf[buflen++] = ch;
145
146 if (ch == '\n' || buflen == sizeof(buf))
147 {
148 syscall(SYS_write, 1, (long)buf, buflen);
149 buflen = 0;
150 }
151
152 return 0;
153 }
154
// Print x as exactly 16 lowercase hexadecimal digits (zero-padded, no "0x"
// prefix, no newline).
void printhex(uint64_t x)
{
  static const char hexdigits[] = "0123456789abcdef";
  char str[17];

  str[16] = 0;
  // Fill from the least-significant nibble at the right end leftwards.
  for (int pos = 15; pos >= 0; pos--, x >>= 4)
    str[pos] = hexdigits[x & 0xF];

  printstr(str);
}
168
// Emit `num` in `base` (base >= 2; digit values >= 10 are printed as
// lowercase letters) through putch(ch, putdat), padded to `width` characters.
//
// padc is the pad character for right-justified output (e.g. ' ' or '0').
// padc == '-' is the printf '-' flag: left-justify, i.e. digits first, then
// spaces. Previously '-' was used as a literal pad character, so "%-5d" of 42
// printed "---42" instead of "42   ".
// A width not larger than the digit count (including -1) means no padding.
static inline void printnum(void (*putch)(int, void**), void **putdat,
                    unsigned long long num, unsigned base, int width, int padc)
{
  unsigned digs[sizeof(num)*CHAR_BIT];  // enough digits even for base 2
  int ndigits = 0;

  // Collect digits least-significant first; always at least one (for 0).
  while (1)
  {
    digs[ndigits++] = num % base;
    if (num < base)
      break;
    num /= base;
  }

  int left_justify = (padc == '-');
  int pad = width > ndigits ? width - ndigits : 0;

  if (!left_justify)
    for (; pad > 0; pad--)
      putch(padc, putdat);

  for (int pos = ndigits; pos-- > 0; )
    putch(digs[pos] + (digs[pos] >= 10 ? 'a' - 10 : '0'), putdat);

  // Only non-zero here when left-justifying: pad on the right with spaces.
  for (; pad > 0; pad--)
    putch(' ', putdat);
}
189
// Fetch the next unsigned integer argument from *ap. The argument's type is
// selected by the number of 'l' length modifiers seen: 0 -> unsigned int,
// 2 or more -> unsigned long long, anything else -> unsigned long.
// The value is returned widened to unsigned long long.
static unsigned long long getuint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, unsigned int);
  if (lflag >= 2)
    return va_arg(*ap, unsigned long long);
  return va_arg(*ap, unsigned long);
}
199
// Fetch the next signed integer argument from *ap. The argument's type is
// selected by the number of 'l' length modifiers seen: 0 -> int,
// 2 or more -> long long, anything else -> long.
// The value is returned widened to long long.
static long long getint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, int);
  if (lflag >= 2)
    return va_arg(*ap, long long);
  return va_arg(*ap, long);
}
209
// Minimal printf-style formatter. Literal characters from fmt and expanded
// %-conversions are emitted one at a time through putch(ch, putdat).
//
// Supported:
//   flags:      '-' (left-justify), '0' (zero pad), '#' (parsed, ignored)
//   width/prec: decimal digits or '*'; digits before '.' set the width,
//               digits after it the precision
//   length:     'l' (long), 'll' (long long)
//   conversions: c s d u o p x %
// An unrecognised conversion is printed literally, starting from the '%'.
// The variadic arguments are read from a va_copy of ap.
static void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt, va_list ap)
{
  const char* p;
  const char* last_fmt;
  int ch;
  unsigned long long num;
  int base, lflag, width, precision, altflag;
  char padc;
  va_list args;

  // getint/getuint take a `va_list *`. `&ap` of a va_list *parameter* is not
  // of that type on ABIs where va_list is an array (e.g. x86-64), so work
  // on a local copy, whose address is always a proper `va_list *`.
  va_copy(args, ap);

  while (1) {
    while ((ch = *(unsigned char *) fmt) != '%') {
      if (ch == '\0') {
        va_end(args);
        return;
      }
      fmt++;
      putch(ch, putdat);
    }
    fmt++;

    // Process a %-escape sequence
    last_fmt = fmt;
    padc = ' ';
    width = -1;
    precision = -1;
    lflag = 0;
    altflag = 0;
  reswitch:
    switch (ch = *(unsigned char *) fmt++) {

    // flag to pad on the right
    case '-':
      padc = '-';
      goto reswitch;

    // flag to pad with 0's instead of spaces
    case '0':
      padc = '0';
      goto reswitch;

    // width field
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      for (precision = 0; ; ++fmt) {
        precision = precision * 10 + ch - '0';
        ch = *fmt;
        if (ch < '0' || ch > '9')
          break;
      }
      goto process_precision;

    case '*':
      precision = va_arg(args, int);
      goto process_precision;

    case '.':
      if (width < 0)
        width = 0;
      goto reswitch;

    case '#':
      altflag = 1;
      goto reswitch;

    // A number parsed before any '.' is the width, not the precision.
    process_precision:
      if (width < 0)
        width = precision, precision = -1;
      goto reswitch;

    // long flag (doubled for long long)
    case 'l':
      lflag++;
      goto reswitch;

    // character
    case 'c':
      putch(va_arg(args, int), putdat);
      break;

    // string
    case 's':
      if ((p = va_arg(args, char *)) == NULL)
        p = "(null)";
      if (width > 0 && padc != '-')
        for (width -= strnlen(p, precision); width > 0; width--)
          putch(padc, putdat);
      for (; (ch = *p) != '\0' && (precision < 0 || --precision >= 0); width--) {
        putch(ch, putdat);
        p++;
      }
      for (; width > 0; width--)
        putch(' ', putdat);
      break;

    // (signed) decimal
    case 'd':
      num = getint(&args, lflag);
      if ((long long) num < 0) {
        putch('-', putdat);
        // Negate in unsigned arithmetic: -(long long)num overflows (UB)
        // for LLONG_MIN, while unsigned negation yields its magnitude.
        num = 0 - num;
      }
      base = 10;
      goto signed_number;

    // unsigned decimal
    case 'u':
      base = 10;
      goto unsigned_number;

    // (unsigned) octal
    case 'o':
      // should do something with padding so it's always 3 octits
      base = 8;
      goto unsigned_number;

    // pointer
    case 'p':
      static_assert(sizeof(long) == sizeof(void*));
      lflag = 1;
      putch('0', putdat);
      putch('x', putdat);
      /* fall through to 'x' */

    // (unsigned) hexadecimal
    case 'x':
      base = 16;
    unsigned_number:
      num = getuint(&args, lflag);
    signed_number:
      printnum(putch, putdat, num, base, width, padc);
      break;

    // escaped '%' character
    case '%':
      putch(ch, putdat);
      break;

    // unrecognized escape sequence - just print it literally
    default:
      putch('%', putdat);
      fmt = last_fmt;
      break;
    }
  }
}
360
// putch callback for printf: forward ch to putchar() and count it.
// `data` really points at printf's `long` counter; vprintfmt already assumes
// sizeof(long) == sizeof(void*), so the slot has the right size.
static void printf_putch(int ch, void** data)
{
  long* count = (long*)data;
  putchar(ch);
  (*count)++;
}

// Formatted output to fd 1 through putchar().
// Returns the number of characters produced (previously always 0).
int printf(const char* fmt, ...)
{
  va_list ap;
  long count = 0;
  va_start(ap, fmt);

  // Pass a callback of exactly the type vprintfmt expects. The old code cast
  // putchar (int (*)(int)) to void (*)(int, void**); calling a function
  // through an incompatible pointer type is undefined behaviour.
  vprintfmt(printf_putch, (void**)&count, fmt, ap);

  va_end(ap);
  return (int)count;
}
371
// putch callback for sprintf: `data` points at a char* cursor into the output
// buffer. Store ch there and advance the cursor.
// Defined at file scope rather than as a GNU nested function. It captures
// nothing, so behaviour is unchanged, and it now builds with any C compiler.
static void sprintf_putch(int ch, void** data)
{
  char** pstr = (char**)data;
  **pstr = ch;
  (*pstr)++;
}

// Format into str like the standard sprintf. There is no bounds checking:
// str must have room for the whole output plus the terminating NUL.
// Returns the number of characters written, excluding the NUL.
int sprintf(char* str, const char* fmt, ...)
{
  va_list ap;
  char* str0 = str;
  va_start(ap, fmt);

  vprintfmt(sprintf_putch, (void**)&str, fmt, ap);
  *str = 0;

  va_end(ap);
  return str - str0;
}