a168ebf6ae5c86c91eb0e7028089af6bb3008c7a
[riscv-tests.git] / benchmarks / common / syscalls.c
1 #include <stdint.h>
2 #include <string.h>
3 #include <stdarg.h>
4 #include <stdio.h>
5 #include <limits.h>
6 #include <machine/syscall.h>
7 #include "util.h"
8
// Benchmark-private syscall number used by setStats(). handle_trap()
// services it locally through handle_stats() instead of forwarding it to
// the host frontend.
#define SYS_stats 1234
10
// Forward a syscall to the host frontend and wait for it to complete.
//
// The syscall number and three arguments are stored in a 64-byte-aligned
// block on the stack. Its address is written to the tohost CSR, and the
// function spins until fromhost reads non-zero. It returns magic_mem[0]
// after that handshake. Presumably the frontend writes the syscall result
// there; confirm against the host side.
static long handle_frontend_syscall(long which, long arg0, long arg1, long arg2)
{
  volatile uint64_t magic_mem[8] __attribute__((aligned(64)));
  magic_mem[0] = which;
  magic_mem[1] = arg0;
  magic_mem[2] = arg1;
  magic_mem[3] = arg2;
  // Full barrier: the request words must be visible before the host is
  // told where to find them.
  __sync_synchronize();
  write_csr(tohost, (long)magic_mem);
  // swap_csr presumably returns the previous value and writes 0, so this
  // both waits for a non-zero completion signal and re-arms fromhost.
  while (swap_csr(fromhost, 0) == 0);
  return magic_mem[0];
}
23
// In setStats, we might trap reading uarch-specific counters.
// The trap handler will skip over the instruction and write 0,
// but only if v0 is the destination register.
// This macro therefore pins the csrr destination to v0 through a register
// variable, so a counter read that traps evaluates to 0 instead of
// killing the program.
#define read_csr_safe(reg) ({ register long __tmp asm("v0"); \
  asm volatile ("csrr %0, " #reg : "=r"(__tmp)); \
  __tmp; })
30
// Number of counters sampled by handle_stats(): cycle, instret, uarch0..uarch15.
#define NUM_COUNTERS 18
// Values recorded by handle_stats(): raw readings after setStats(1),
// deltas since that snapshot after setStats(0). _init() prints the
// non-zero entries.
static long counters[NUM_COUNTERS];
// Counter names. handle_stats() fills these in only on the disable path.
static char* counter_names[NUM_COUNTERS];
// Service SYS_stats.
// enable != 0: set bit 0 of the stats CSR, then snapshot every counter
//              into counters[].
// enable == 0: replace each snapshot with the delta since it was taken,
//              record the counter names, then clear bit 0 of stats.
// Always returns 0.
static int handle_stats(int enable)
{
  // Set bit 0 of the stats register (csrrs) before sampling.
  if(enable) {
    asm volatile (R"(
addi v0, x0, 1
csrrs v0, stats, v0
)" : : : "v0");
  }
  int i = 0;
  // Sample one counter into counters[i]. The spin stops (hangs) rather
  // than writing past the end of counters[] if more counters are listed
  // than NUM_COUNTERS. read_csr_safe yields 0 for counters whose read traps.
  #define READ_CTR(name) do { \
    while (i >= NUM_COUNTERS) ; \
    long csr = read_csr_safe(name); \
    if (!enable) { csr -= counters[i]; counter_names[i] = #name; } \
    counters[i++] = csr; \
  } while (0)
  READ_CTR(cycle); READ_CTR(instret);
  READ_CTR(uarch0); READ_CTR(uarch1); READ_CTR(uarch2); READ_CTR(uarch3);
  READ_CTR(uarch4); READ_CTR(uarch5); READ_CTR(uarch6); READ_CTR(uarch7);
  READ_CTR(uarch8); READ_CTR(uarch9); READ_CTR(uarch10); READ_CTR(uarch11);
  READ_CTR(uarch12); READ_CTR(uarch13); READ_CTR(uarch14); READ_CTR(uarch15);
  #undef READ_CTR
  // Clear bit 0 of the stats register (csrrc) only after sampling.
  if(!enable) {
    asm volatile (R"(
addi v0, x0, 1
csrrc v0, stats, v0
)" : : : "v0");
  }
  return 0;
}
64
65 static void tohost_exit(int code)
66 {
67 write_csr(tohost, (code << 1) | 1);
68 while (1);
69 }
70
// Trap handler. The returned value is presumably used by the trap entry
// code as the PC to resume at. regs[] is presumably the interrupted
// code's saved register file, restored on return; TODO confirm against
// the assembly trap entry.
//  - An illegal-instruction trap whose encoding has every bit of
//    `csrr v0, stats` set (a loose match for a CSR read into v0, as
//    emitted by read_csr_safe) is skipped, and regs[16] is set to 0.
//  - Any other non-syscall trap exits with code 1337.
//  - SYS_exit and SYS_stats are handled here. Every other syscall number
//    is forwarded to the host frontend.
// The syscall number is read from regs[16] and its arguments from
// regs[18..20], matching the v0/a0..a2 binding in syscall(). The result
// is stored back into regs[16].
long handle_trap(long cause, long epc, long regs[32])
{
  // Load the 32-bit encoding of `csrr v0, stats`. It sits at local label
  // 1 and is jumped over, so it is never executed here.
  int csr_insn;
  asm volatile ("lw %0, 1f; j 2f; 1: csrr v0, stats; 2:" : "=r"(csr_insn));
  long sys_ret = 0;

  if (cause == CAUSE_ILLEGAL_INSTRUCTION &&
      (*(int*)epc & csr_insn) == csr_insn)
    ;  // skip the CSR read: sys_ret stays 0 and is stored to regs[16] below
  else if (cause != CAUSE_SYSCALL)
    tohost_exit(1337);  // unexpected trap: fail with a fixed exit code
  else if (regs[16] == SYS_exit)
    tohost_exit(regs[18]);
  else if (regs[16] == SYS_stats)
    sys_ret = handle_stats(regs[18]);
  else
    sys_ret = handle_frontend_syscall(regs[16], regs[18], regs[19], regs[20]);

  regs[16] = sys_ret;
  // Resume just past the 4-byte instruction that trapped.
  return epc+4;
}
92
93 static long syscall(long num, long arg0, long arg1, long arg2)
94 {
95 register long v0 asm("v0") = num;
96 register long a0 asm("a0") = arg0;
97 register long a1 asm("a1") = arg1;
98 register long a2 asm("a2") = arg2;
99 asm volatile ("scall" : "+r"(v0) : "r"(a0), "r"(a1), "r"(a2) : "s0");
100 return v0;
101 }
102
103 void exit(int code)
104 {
105 syscall(SYS_exit, code, 0, 0);
106 }
107
108 void setStats(int enable)
109 {
110 syscall(SYS_stats, enable, 0, 0);
111 }
112
113 void printstr(const char* s)
114 {
115 syscall(SYS_write, 1, (long)s, strlen(s));
116 }
117
/*
 * Per-core hook that _init() calls before main().
 *
 * Multi-threaded programs override this weak definition. The default
 * returns only on core 0 (cid == 0); every other core spins here forever,
 * so single-threaded programs run main() exactly once.
 *
 * The spin deliberately uses a constant controlling expression. Under C11
 * 6.8.5p6 the compiler may assume that a loop with no side effects and a
 * non-constant condition, such as `while (cid != 0);`, terminates. It could
 * therefore delete that loop and let secondary cores fall through into
 * main().
 */
void __attribute__((weak)) thread_entry(int cid, int nc)
{
  (void)nc;  // unused by the default implementation
  if (cid != 0)
    for (;;)
      ;
}
124
/*
 * Fallback main() used when a benchmark does not provide one.
 * Single-threaded programs override this weak definition. Reaching it
 * prints a reminder and returns -1, which _init() hands to exit().
 */
int __attribute__((weak)) main(int argc, char** argv)
{
  (void)argc;
  (void)argv;
  printstr("Implement main(), foo!\n");
  return -1;
}
131
132 void _init(int cid, int nc)
133 {
134 thread_entry(cid, nc);
135
136 // only single-threaded programs should ever get here.
137 int ret = main(0, 0);
138
139 char buf[NUM_COUNTERS * 32] __attribute__((aligned(64)));
140 char* pbuf = buf;
141 for (int i = 0; i < NUM_COUNTERS; i++)
142 if (counters[i])
143 pbuf += sprintf(pbuf, "%s = %d\n", counter_names[i], counters[i]);
144 if (pbuf != buf)
145 printstr(buf);
146
147 exit(ret);
148 }
149
150 #undef putchar
151 int putchar(int ch)
152 {
153 static char buf[64] __attribute__((aligned(64)));
154 static int buflen = 0;
155
156 buf[buflen++] = ch;
157
158 if (ch == '\n' || buflen == sizeof(buf))
159 {
160 syscall(SYS_write, 1, (long)buf, buflen);
161 buflen = 0;
162 }
163
164 return 0;
165 }
166
/*
 * Print `x` as exactly 16 lower-case hexadecimal digits (zero-padded, no
 * "0x" prefix) through printstr().
 */
void printhex(uint64_t x)
{
  static const char hexdigits[] = "0123456789abcdef";
  char text[17];

  text[16] = '\0';
  // Fill from the least-significant nibble at the right-hand end.
  for (int pos = 15; pos >= 0; pos--, x >>= 4)
    text[pos] = hexdigits[x & 0xF];

  printstr(text);
}
180
/*
 * Emit `num` in `base` through putch(ch, putdat) within a field of `width`
 * columns.
 *
 * padc is the fill character. ' ' or '0' right-aligns the number behind
 * that padding. '-' is vprintfmt's left-align flag: the digits come first
 * and the field is filled with trailing spaces. Previously '-' itself was
 * used as padding, so "%-5d" of 42 printed "---42" instead of "42   ".
 * A width at or below the digit count (including -1, "no width") adds no
 * padding. Digits above 9 are lower-case letters. base must be >= 2;
 * vprintfmt passes 8, 10 or 16.
 */
static inline void printnum(void (*putch)(int, void**), void **putdat,
                            unsigned long long num, unsigned base, int width, int padc)
{
  // One digit per bit is enough for any base >= 2.
  unsigned digs[sizeof(num)*CHAR_BIT];
  int pos = 0;

  // Collect digits, least significant first.
  while (1)
  {
    digs[pos++] = num % base;
    if (num < base)
      break;
    num /= base;
  }
  int ndigits = pos;

  if (padc != '-')
    for (; width > ndigits; width--)
      putch(padc, putdat);

  while (pos-- > 0)
    putch(digs[pos] + (digs[pos] >= 10 ? 'a' - 10 : '0'), putdat);

  if (padc == '-')
    for (; width > ndigits; width--)
      putch(' ', putdat);
}
201
/*
 * Fetch the next unsigned vararg, sized by the number of 'l' modifiers
 * seen: 0 -> unsigned int, 1 -> unsigned long, 2+ -> unsigned long long.
 */
static unsigned long long getuint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, unsigned int);
  if (lflag < 2)
    return va_arg(*ap, unsigned long);
  return va_arg(*ap, unsigned long long);
}
211
/*
 * Fetch the next signed vararg, sized by the number of 'l' modifiers
 * seen: 0 -> int, 1 -> long, 2+ -> long long.
 */
static long long getint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, int);
  if (lflag < 2)
    return va_arg(*ap, long);
  return va_arg(*ap, long long);
}
221
/*
 * Core formatter shared by printf() and sprintf(). Every output character
 * is passed to putch(ch, putdat).
 *
 * Supported: %c %s %d %u %o %x %p and %%; flags '-' (left-align), '0'
 * (zero-pad) and '#' (accepted, no effect); a field width given as digits
 * or '*'; a precision after '.'; 'l' / 'll' length modifiers. An
 * unrecognized conversion prints '%' and rescans the characters after it
 * as ordinary text.
 */
static void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt, va_list ap)
{
  const char* p;
  const char* last_fmt;
  int ch;
  unsigned long long num;
  int base, lflag, width, precision;
  char padc;
  va_list args;

  // getint()/getuint() take a va_list*. Taking the address of a va_list
  // *parameter* is not portable: where va_list is an array type, the
  // parameter has decayed to a pointer. Work on a local copy instead.
  va_copy(args, ap);

  while (1) {
    while ((ch = *(unsigned char *) fmt) != '%') {
      if (ch == '\0') {
        va_end(args);
        return;
      }
      fmt++;
      putch(ch, putdat);
    }
    fmt++;

    // Process a %-escape sequence
    last_fmt = fmt;
    padc = ' ';
    width = -1;
    precision = -1;
    lflag = 0;
  reswitch:
    switch (ch = *(unsigned char *) fmt++) {

    // flag to pad on the right
    case '-':
      padc = '-';
      goto reswitch;

    // flag to pad with 0's instead of spaces
    case '0':
      padc = '0';
      goto reswitch;

    // width field
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      for (precision = 0; ; ++fmt) {
        precision = precision * 10 + ch - '0';
        ch = *fmt;
        if (ch < '0' || ch > '9')
          break;
      }
      goto process_precision;

    case '*':
      precision = va_arg(args, int);
      goto process_precision;

    case '.':
      if (width < 0)
        width = 0;
      goto reswitch;

    // alternate-form flag: accepted for compatibility, but has no effect
    case '#':
      goto reswitch;

    // A number seen before '.' becomes the width. After '.', width is
    // already >= 0, so the number stays the precision.
    process_precision:
      if (width < 0)
        width = precision, precision = -1;
      goto reswitch;

    // long flag (doubled for long long)
    case 'l':
      lflag++;
      goto reswitch;

    // character
    case 'c':
      putch(va_arg(args, int), putdat);
      break;

    // string
    case 's':
      if ((p = va_arg(args, char *)) == NULL)
        p = "(null)";
      if (width > 0 && padc != '-')
        for (width -= strnlen(p, precision); width > 0; width--)
          putch(padc, putdat);
      for (; (ch = *p) != '\0' && (precision < 0 || --precision >= 0); width--) {
        putch(ch, putdat);
        p++;
      }
      for (; width > 0; width--)
        putch(' ', putdat);
      break;

    // (signed) decimal
    case 'd':
      num = getint(&args, lflag);
      if ((long long) num < 0) {
        // Negate in unsigned arithmetic: -(long long)num is signed
        // overflow (undefined behavior) when the value is LLONG_MIN.
        num = 0 - num;
        // The sign occupies one column of the field.
        if (width > 0)
          width--;
        // With space padding the sign must follow the padding ("  -42").
        // With zero padding or left alignment it leads ("-0042", "-42  ").
        if (padc == ' ') {
          int ndigits = 1;
          for (unsigned long long rest = num; rest >= 10; rest /= 10)
            ndigits++;
          for (; width > ndigits; width--)
            putch(' ', putdat);
        }
        putch('-', putdat);
      }
      base = 10;
      goto signed_number;

    // unsigned decimal
    case 'u':
      base = 10;
      goto unsigned_number;

    // (unsigned) octal
    case 'o':
      // should do something with padding so it's always 3 octits
      base = 8;
      goto unsigned_number;

    // pointer
    case 'p':
      static_assert(sizeof(long) == sizeof(void*));
      lflag = 1;
      putch('0', putdat);
      putch('x', putdat);
      /* fall through to 'x' */

    // (unsigned) hexadecimal
    case 'x':
      base = 16;
    unsigned_number:
      num = getuint(&args, lflag);
    signed_number:
      printnum(putch, putdat, num, base, width, padc);
      break;

    // escaped '%' character
    case '%':
      putch(ch, putdat);
      break;

    // unrecognized escape sequence - just print it literally
    default:
      putch('%', putdat);
      fmt = last_fmt;
      break;
    }
  }
}
372
/*
 * vprintfmt callback for printf(). It writes the character with putchar()
 * and bumps the running count. *data is a void* slot holding the address
 * of that count, so no pointer type is reinterpreted. The previous code
 * instead called putchar through a cast to the callback type, which is an
 * incompatible function type.
 */
static void printf_putch(int ch, void** data)
{
  long* count = *data;
  putchar(ch);
  (*count)++;
}

/*
 * Format to the console through the line-buffered putchar(). Returns the
 * number of characters written, as the C standard specifies (this
 * previously always returned 0).
 */
int printf(const char* fmt, ...)
{
  va_list ap;
  long count = 0;
  void* count_slot = &count;

  va_start(ap, fmt);
  vprintfmt(printf_putch, &count_slot, fmt, ap);
  va_end(ap);

  return (int)count;
}
383
/*
 * vprintfmt callback for sprintf(). *data is a void* slot holding the
 * output cursor; each character is stored there and the cursor advanced.
 * This is defined at file scope because the previous GCC nested function
 * is a non-standard extension.
 */
static void sprintf_putch(int ch, void** data)
{
  char* cursor = *data;
  *cursor = (char)ch;
  *data = cursor + 1;
}

/*
 * Format into `str`, NUL-terminate it, and return the number of
 * characters written, excluding the terminator. As with the standard
 * sprintf there is no bounds check: the caller must provide enough room.
 */
int sprintf(char* str, const char* fmt, ...)
{
  va_list ap;
  void* cursor = str;

  va_start(ap, fmt);
  vprintfmt(sprintf_putch, &cursor, fmt, ap);
  va_end(ap);

  char* end = cursor;
  *end = '\0';
  return (int)(end - str);
}