setStats in benchmarks now should set and unset the stats register. Also, removed...
[riscv-tests.git] / benchmarks / common / syscalls.c
1 #include <stdint.h>
2 #include <string.h>
3 #include <stdarg.h>
4 #include <stdio.h>
5 #include <limits.h>
6 #include <machine/syscall.h>
7 #include "encoding.h"
8
// Syscall number handled locally by handle_trap (via handle_stats) rather
// than being forwarded to the host; used by setStats().
#define SYS_stats 1234
// Statement-form compile-time assertion usable before C11: when cond is
// false, !!(long)(cond) is 0 and the duplicate `case 0` label makes the
// switch fail to compile; when true, the labels are 0 and 1 and it compiles
// to nothing.
#define static_assert(cond) switch(0) { case 0: case !!(long)(cond): ; }
11
// Forward a syscall to the host: the request words (which, arg0, arg1, arg2)
// are stored in a 64-byte-aligned buffer, the buffer's address is written to
// the tohost CSR, and we spin until fromhost reads nonzero. The first buffer
// word is then returned as the result.
// NOTE(review): assumes the host overwrites magic_mem[0] with the syscall
// result before signalling through fromhost — confirm against the frontend.
static long handle_frontend_syscall(long which, long arg0, long arg1, long arg2)
{
  volatile uint64_t magic_mem[8] __attribute__((aligned(64)));
  magic_mem[0] = which;
  magic_mem[1] = arg0;
  magic_mem[2] = arg1;
  magic_mem[3] = arg2;
  // Full memory barrier: the request words must be written before the
  // buffer address is handed to the host.
  __sync_synchronize();
  write_csr(tohost, (long)magic_mem);
  // Presumably swap_csr returns the old fromhost value while writing 0
  // (encoding.h) — loop until the host posts a nonzero value.
  while (swap_csr(fromhost, 0) == 0);
  return magic_mem[0];
}
24
// In setStats, we might trap reading uarch-specific counters.
// The trap handler will skip over the instruction and write 0,
// but only if v0 is the destination register.
// read_csr_safe(reg) reads CSR `reg` into a local that is pinned to v0 by a
// register-asm binding, so the emitted csrr always targets v0 and a trapping
// read therefore evaluates to 0.
#define read_csr_safe(reg) ({ register long __tmp asm("v0"); \
  asm volatile ("csrr %0, " #reg : "=r"(__tmp)); \
  __tmp; })
31
// Number of counters sampled by handle_stats(): cycle, instret, uarch0..uarch15.
#define NUM_COUNTERS 18
// While sampling is on: the baseline value of each counter read at enable.
// After sampling stops: the delta (value at stop - baseline).
static long counters[NUM_COUNTERS];
// CSR name for each entry of counters[]; filled in only when sampling stops.
static char* counter_names[NUM_COUNTERS];

// Handler for the SYS_stats syscall (dispatched from handle_trap).
//   enable != 0: set bit 0 of the stats CSR, then record every counter as a
//                baseline in counters[].
//   enable == 0: read every counter, store (current - baseline) and the
//                counter's name, then clear bit 0 of the stats CSR.
// Reads go through read_csr_safe, so a counter read that traps yields 0.
// Always returns 0.
static int handle_stats(int enable)
{
  //use csrs to set stats register
  if(enable) {
    asm volatile (R"(
addi v0, x0, 1
csrrs v0, stats, v0
)" : : : "v0");
  }
  int i = 0;
  // READ_CTR(name): read CSR `name` into counters[i] (as a delta when
  // disabling, also recording its name) and advance i. The while loop hangs
  // forever rather than index past the arrays if more counters are read than
  // NUM_COUNTERS allows.
#define READ_CTR(name) do { \
    while (i >= NUM_COUNTERS) ; \
    long csr = read_csr_safe(name); \
    if (!enable) { csr -= counters[i]; counter_names[i] = #name; } \
    counters[i++] = csr; \
  } while (0)
  READ_CTR(cycle); READ_CTR(instret);
  READ_CTR(uarch0); READ_CTR(uarch1); READ_CTR(uarch2); READ_CTR(uarch3);
  READ_CTR(uarch4); READ_CTR(uarch5); READ_CTR(uarch6); READ_CTR(uarch7);
  READ_CTR(uarch8); READ_CTR(uarch9); READ_CTR(uarch10); READ_CTR(uarch11);
  READ_CTR(uarch12); READ_CTR(uarch13); READ_CTR(uarch14); READ_CTR(uarch15);
#undef READ_CTR
  // Clear bit 0 of stats only after the final counter values were read.
  if(!enable) {
    asm volatile (R"(
addi v0, x0, 1
csrrc v0, stats, v0
)" : : : "v0");
  }
  return 0;
}
65
66 static void tohost_exit(int code)
67 {
68 write_csr(tohost, (code << 1) | 1);
69 while (1);
70 }
71
// Trap handler: `cause` is the trap cause, `epc` the faulting PC, and `regs`
// the saved integer registers (NOTE(review): regs[i] is assumed to hold xi,
// with regs[16] = v0 and regs[18..20] = a0..a2, matching how syscall() passes
// its operands — confirm against the trap entry code).
// Whenever it returns, it returns epc + 4 (resume after the trapping
// instruction) and has written the result to regs[16] (v0).
//   - illegal instruction matching `csrr v0, stats` (see read_csr_safe):
//     skipped, v0 set to 0;
//   - any non-syscall cause: exit with status 1337;
//   - SYS_exit: exit with status a0 (does not return);
//   - SYS_stats: handled locally by handle_stats(a0);
//   - any other syscall: forwarded to the host with a0..a2.
long handle_trap(long cause, long epc, long regs[32])
{
  // Load the 32-bit encoding of `csrr v0, stats` into csr_insn. The
  // instruction at label 1 is jumped over and never executed; it is only
  // there so that its encoding can be read as data.
  int csr_insn;
  asm volatile ("lw %0, 1f; j 2f; 1: csrr v0, stats; 2:" : "=r"(csr_insn));
  long sys_ret = 0;

  // NOTE(review): this is a bit-superset test, not an exact opcode compare:
  // any faulting instruction whose encoding contains every set bit of
  // csr_insn is treated as a skippable counter read. Presumably the counter
  // CSR numbers are laid out so this matches exactly the intended reads.
  if (cause == CAUSE_ILLEGAL_INSTRUCTION &&
      (*(int*)epc & csr_insn) == csr_insn)
    ;
  else if (cause != CAUSE_SYSCALL)
    tohost_exit(1337);
  else if (regs[16] == SYS_exit)
    tohost_exit(regs[18]);
  else if (regs[16] == SYS_stats)
    sys_ret = handle_stats(regs[18]);
  else
    sys_ret = handle_frontend_syscall(regs[16], regs[18], regs[19], regs[20]);

  // Result (0 for a skipped CSR read) goes back to v0.
  regs[16] = sys_ret;
  return epc+4;
}
93
// Issue a system call with `scall`: the number is passed in v0 and up to
// three arguments in a0..a2. The trap handler (handle_trap) writes the result
// back into v0, which is returned.
// NOTE(review): s0 is declared clobbered; the reason is not visible in this
// file — presumably the trap path may modify it. Confirm before removing.
static long syscall(long num, long arg0, long arg1, long arg2)
{
  register long v0 asm("v0") = num;
  register long a0 asm("a0") = arg0;
  register long a1 asm("a1") = arg1;
  register long a2 asm("a2") = arg2;
  asm volatile ("scall" : "+r"(v0) : "r"(a0), "r"(a1), "r"(a2) : "s0");
  return v0;
}
103
104 void exit(int code)
105 {
106 syscall(SYS_exit, code, 0, 0);
107 }
108
109 void setStats(int enable)
110 {
111 syscall(SYS_stats, enable, 0, 0);
112 }
113
114 void printstr(const char* s)
115 {
116 syscall(SYS_write, 1, (long)s, strlen(s));
117 }
118
// Per-core hook run before main(). Multi-threaded programs override this
// weak default. For single-threaded programs only core 0 (cid == 0) may
// proceed; every other core is parked here forever. `nc` (core count) is
// unused by this default.
void __attribute__((weak)) thread_entry(int cid, int nc)
{
  (void)nc;
  // Spin on a constant condition: a loop like `while (cid != 0);` has a
  // non-constant controlling expression and no side effects, so C11 6.8.5p6
  // permits the compiler to assume it terminates and let the core through.
  if (cid != 0)
    while (1);
}
125
/*
 * Placeholder entry point, used only when a single-threaded benchmark does
 * not supply its own main(): prints a reminder and reports failure (-1).
 */
int __attribute__((weak)) main(int argc, char** argv)
{
  const int missing_main_status = -1;

  (void)argc;
  (void)argv;
  printstr("Implement main(), foo!\n");
  return missing_main_status;
}
132
133 void _init(int cid, int nc)
134 {
135 thread_entry(cid, nc);
136
137 // only single-threaded programs should ever get here.
138 int ret = main(0, 0);
139
140 char buf[NUM_COUNTERS * 32] __attribute__((aligned(64)));
141 char* pbuf = buf;
142 for (int i = 0; i < NUM_COUNTERS; i++)
143 if (counters[i])
144 pbuf += sprintf(pbuf, "%s = %d\n", counter_names[i], counters[i]);
145 if (pbuf != buf)
146 printstr(buf);
147
148 exit(ret);
149 }
150
151 #undef putchar
152 int putchar(int ch)
153 {
154 static char buf[64] __attribute__((aligned(64)));
155 static int buflen = 0;
156
157 buf[buflen++] = ch;
158
159 if (ch == '\n' || buflen == sizeof(buf))
160 {
161 syscall(SYS_write, 1, (long)buf, buflen);
162 buflen = 0;
163 }
164
165 return 0;
166 }
167
/*
 * Print x as exactly 16 lowercase hexadecimal digits (zero-padded, no "0x"
 * prefix and no newline) via printstr().
 */
void printhex(uint64_t x)
{
  static const char hexdigits[] = "0123456789abcdef";
  char str[17];

  str[16] = '\0';
  // Fill from the least-significant nibble at the right end leftwards.
  for (int pos = 15; pos >= 0; pos--, x >>= 4)
    str[pos] = hexdigits[x & 0xF];

  printstr(str);
}
181
/*
 * Emit `num` in `base` through putch(ch, putdat), using lowercase letters
 * for digits >= 10 (callers in this file pass 8, 10 or 16; base must be >= 2).
 *
 * If the number has fewer than `width` digits, it is padded to `width`
 * characters (a negative width means no minimum):
 *   - padc == '-' (vprintfmt's left-justify flag): digits first, then
 *     trailing spaces;
 *   - otherwise: leading `padc` characters (' ' or '0'), then the digits.
 */
static inline void printnum(void (*putch)(int, void**), void **putdat,
                            unsigned long long num, unsigned base, int width, int padc)
{
  // Worst case (base 2) needs one slot per bit of num.
  unsigned digs[sizeof(num) * CHAR_BIT];
  int pos = 0;

  // Collect digits least-significant first; always at least one digit.
  do {
    digs[pos++] = num % base;
    num /= base;
  } while (num != 0);

  int pad = width - pos;
  int left_justify = (padc == '-');

  // Right-justified: pad before the digits. Previously '-' itself was
  // emitted here as the pad character, so "%-5d" printed "---42".
  if (!left_justify)
    for (; pad > 0; pad--)
      putch(padc, putdat);

  while (pos-- > 0)
    putch(digs[pos] + (digs[pos] >= 10 ? 'a' - 10 : '0'), putdat);

  // Left-justified: pad after the digits with spaces.
  if (left_justify)
    for (; pad > 0; pad--)
      putch(' ', putdat);
}
202
/*
 * Fetch the next variadic argument as an unsigned value whose C type is
 * chosen by the number of 'l' modifiers seen: none -> unsigned int,
 * one -> unsigned long, two or more -> unsigned long long.
 */
static unsigned long long getuint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, unsigned int);
  if (lflag < 2)
    return va_arg(*ap, unsigned long);
  return va_arg(*ap, unsigned long long);
}

/*
 * Signed counterpart of getuint(): fetch an int, long or long long
 * depending on how many 'l' modifiers were seen.
 */
static long long getint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, int);
  if (lflag < 2)
    return va_arg(*ap, long);
  return va_arg(*ap, long long);
}
222
/*
 * Formatter shared by printf() and sprintf(). It walks fmt and passes each
 * literal character or converted argument to putch(ch, putdat).
 *
 * Supported conversions: %c %s %d %u %o %x %p %%.
 * Flags: '-' (left-justify), '0' (zero pad) and '#' (parsed but unused).
 * A width and a '.'-precision may be given as digits or '*'.
 * Length modifiers: 'l' and 'll'.
 * An unrecognized conversion prints '%' and resumes scanning just after it.
 */
static void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt, va_list ap)
{
  register const char* p;
  const char* last_fmt;
  register int ch;
  unsigned long long num;
  int base, lflag, width, precision, altflag;
  char padc;

  while (1) {
    // Copy literal text up to the next '%'; stop at the terminator.
    while ((ch = *(unsigned char *) fmt) != '%') {
      if (ch == '\0')
        return;
      fmt++;
      putch(ch, putdat);
    }
    fmt++;

    // Process a %-escape sequence
    last_fmt = fmt;
    padc = ' ';
    width = -1;
    precision = -1;
    lflag = 0;
    altflag = 0;
  reswitch:
    switch (ch = *(unsigned char *) fmt++) {

    // flag to pad on the right
    case '-':
      padc = '-';
      goto reswitch;

    // flag to pad with 0's instead of spaces
    case '0':
      padc = '0';
      goto reswitch;

    // width field: digits are accumulated into `precision` first and moved
    // into `width` by process_precision unless a '.' already set the width.
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      for (precision = 0; ; ++fmt) {
        precision = precision * 10 + ch - '0';
        ch = *fmt;
        if (ch < '0' || ch > '9')
          break;
      }
      goto process_precision;

    case '*':
      precision = va_arg(ap, int);
      goto process_precision;

    case '.':
      if (width < 0)
        width = 0;
      goto reswitch;

    case '#':
      altflag = 1;
      goto reswitch;

    process_precision:
      if (width < 0)
        width = precision, precision = -1;
      goto reswitch;

    // long flag (doubled for long long)
    case 'l':
      lflag++;
      goto reswitch;

    // character
    case 'c':
      putch(va_arg(ap, int), putdat);
      break;

    // string
    case 's':
      if ((p = va_arg(ap, char *)) == NULL)
        p = "(null)";
      if (width > 0 && padc != '-')
        for (width -= strnlen(p, precision); width > 0; width--)
          putch(padc, putdat);
      for (; (ch = *p) != '\0' && (precision < 0 || --precision >= 0); width--) {
        putch(ch, putdat);
        p++;
      }
      for (; width > 0; width--)
        putch(' ', putdat);
      break;

    // (signed) decimal
    case 'd':
      num = getint(&ap, lflag);
      if ((long long) num < 0) {
        putch('-', putdat);
        // Negate in unsigned arithmetic: -(long long)num overflows (UB) for
        // LLONG_MIN, whereas unsigned negation wraps to the exact magnitude.
        num = -num;
      }
      base = 10;
      goto signed_number;

    // unsigned decimal
    case 'u':
      base = 10;
      goto unsigned_number;

    // (unsigned) octal
    case 'o':
      // should do something with padding so it's always 3 octits
      base = 8;
      goto unsigned_number;

    // pointer
    case 'p':
      static_assert(sizeof(long) == sizeof(void*));
      lflag = 1;
      putch('0', putdat);
      putch('x', putdat);
      /* fall through to 'x' */

    // (unsigned) hexadecimal
    case 'x':
      base = 16;
    unsigned_number:
      num = getuint(&ap, lflag);
    signed_number:
      printnum(putch, putdat, num, base, width, padc);
      break;

    // escaped '%' character
    case '%':
      putch(ch, putdat);
      break;

    // unrecognized escape sequence - just print it literally
    default:
      putch('%', putdat);
      fmt = last_fmt;
      break;
    }
  }
}
373
/*
 * vprintfmt sink for printf(): forward ch to putchar() and bump the counter
 * that *data points to.
 */
static void printf_putch(int ch, void** data)
{
  int* count = (int*)*data;
  (*count)++;
  putchar(ch);
}

/*
 * Format to the console via the buffered putchar(). Returns the number of
 * characters produced (previously always 0).
 */
int printf(const char* fmt, ...)
{
  va_list ap;
  int count = 0;
  void* ctx = &count;

  va_start(ap, fmt);
  // Use a correctly typed sink: calling putchar (int(int)) through a
  // void(*)(int, void**) pointer, as a plain cast would, is undefined behavior.
  vprintfmt(printf_putch, &ctx, fmt, ap);
  va_end(ap);

  return count;
}
384
/*
 * vprintfmt sink for sprintf(): *data is the caller's char* write cursor.
 * Store ch there and advance the cursor.
 */
static void sprintf_putch(int ch, void** data)
{
  char** pstr = (char**)data;
  **pstr = ch;
  (*pstr)++;
}

/*
 * Format into str (no bounds checking; the caller must size the buffer),
 * NUL-terminate it, and return the number of characters written excluding
 * the terminator. The sink is a file-scope function rather than a GNU
 * nested function, so this compiles with non-GCC compilers (e.g. Clang).
 */
int sprintf(char* str, const char* fmt, ...)
{
  va_list ap;
  char* str0 = str;
  va_start(ap, fmt);

  vprintfmt(sprintf_putch, (void**)&str, fmt, ap);
  *str = 0;

  va_end(ap);
  return str - str0;
}