Enable LR/SC tests, even for uniprocessors
[riscv-tests.git] / benchmarks / common / syscalls.c
1 // See LICENSE for license details.
2
3 #include <stdint.h>
4 #include <string.h>
5 #include <stdarg.h>
6 #include <stdio.h>
7 #include <limits.h>
8 #include "util.h"
9
// System call numbers passed in a7. handle_trap services SYS_exit and
// SYS_stats itself and forwards every other number (e.g. SYS_write) to the
// host through handle_frontend_syscall.
#define SYS_write 64
#define SYS_exit 93
#define SYS_stats 1234

// initialized in crt.S
// NOTE(review): not read anywhere in this file; presumably consumed by other
// benchmark code — confirm.
int have_vec;

// Mailbox words shared with the host. handle_frontend_syscall stores a
// request pointer in tohost and polls fromhost until the host answers;
// tohost_exit stores (code << 1) | 1 in tohost. volatile forces every access
// to actually be performed.
// NOTE(review): the 64-byte alignment presumably keeps each word in its own
// cache line / matches what the host expects — confirm.
volatile uint64_t tohost __attribute__((aligned(64)));
volatile uint64_t fromhost __attribute__((aligned(64)));
19
20 static long handle_frontend_syscall(long which, long arg0, long arg1, long arg2)
21 {
22 volatile uint64_t magic_mem[8] __attribute__((aligned(64)));
23 magic_mem[0] = which;
24 magic_mem[1] = arg0;
25 magic_mem[2] = arg1;
26 magic_mem[3] = arg2;
27 __sync_synchronize();
28
29 tohost = (uintptr_t)magic_mem;
30 while (fromhost == 0)
31 ;
32 fromhost = 0;
33
34 __sync_synchronize();
35 return magic_mem[0];
36 }
37
// read_csr_safe(reg): GNU statement expression that evaluates to the value of
// CSR `reg` as a long. `reg` is stringized into the csrr instruction text, so
// it must be a CSR name or number token.
// In setStats, we might trap reading uarch-specific counters.
// The trap handler will skip over the instruction and write 0,
// but only if a0 is the destination register.
// Binding __tmp to a0 is what makes the generated csrr target a0.
#define read_csr_safe(reg) ({ register long __tmp asm("a0"); \
  asm volatile ("csrr %0, " #reg : "=r"(__tmp)); \
  __tmp; })
44
45 #define NUM_COUNTERS 18
46 static long counters[NUM_COUNTERS];
47 static char* counter_names[NUM_COUNTERS];
48 static int handle_stats(int enable)
49 {
50 int i = 0;
51 #define READ_CTR(name) do { \
52 while (i >= NUM_COUNTERS) ; \
53 long csr = read_csr_safe(name); \
54 if (!enable) { csr -= counters[i]; counter_names[i] = #name; } \
55 counters[i++] = csr; \
56 } while (0)
57 READ_CTR(mcycle); READ_CTR(minstret);
58 READ_CTR(0xcc0); READ_CTR(0xcc1); READ_CTR(0xcc2); READ_CTR(0xcc3);
59 READ_CTR(0xcc4); READ_CTR(0xcc5); READ_CTR(0xcc6); READ_CTR(0xcc7);
60 READ_CTR(0xcc8); READ_CTR(0xcc9); READ_CTR(0xcca); READ_CTR(0xccb);
61 READ_CTR(0xccc); READ_CTR(0xccd); READ_CTR(0xcce); READ_CTR(0xccf);
62 #undef READ_CTR
63 return 0;
64 }
65
66 void tohost_exit(long code)
67 {
68 tohost = (code << 1) | 1;
69 while (1);
70 }
71
// Trap handler. Returns the PC to resume at: epc + 4, the instruction after
// the trapping one. This assumes a 4-byte trapping instruction, which holds
// for ecall and csrr.
// Passes a result back by overwriting the saved a0 (regs[10]).
//   cause - trap cause, compared against CAUSE_ILLEGAL_INSTRUCTION and
//           CAUSE_MACHINE_ECALL
//   epc   - address of the trapping instruction
//   regs  - saved registers. Presumably regs[i] holds x_i as saved by the
//           trap entry in crt.S (not visible here). regs[10..12] and
//           regs[17] are used as a0..a2 and a7, matching syscall() below.
long handle_trap(long cause, long epc, long regs[32])
{
  int* csr_insn;
  // Get the address of an inline template instruction "csrr a0, 0xcc0":
  // jal writes the address of the following instruction (the csrr) into
  // csr_insn and jumps to label 1, so the template itself never executes.
  asm ("jal %0, 1f; csrr a0, 0xcc0; 1:" : "=r"(csr_insn));
  long sys_ret = 0;

  // Illegal-instruction trap whose encoding has every bit of the template
  // set: treat it as an unimplemented counter read (see read_csr_safe) and
  // skip it. The saved a0 becomes 0 below.
  // NOTE: this is a bit-superset test, not an exact match, so it also
  // accepts some neighbouring CSR numbers and destination registers.
  if (cause == CAUSE_ILLEGAL_INSTRUCTION &&
      (*(int*)epc & *csr_insn) == *csr_insn)
    ;
  // Any other non-ecall trap aborts the run with exit code 1337.
  else if (cause != CAUSE_MACHINE_ECALL)
    tohost_exit(1337);
  // ecall: a7 (regs[17]) selects the call, a0..a2 carry the arguments.
  else if (regs[17] == SYS_exit)
    tohost_exit(regs[10]);
  else if (regs[17] == SYS_stats)
    sys_ret = handle_stats(regs[10]);
  else
    sys_ret = handle_frontend_syscall(regs[17], regs[10], regs[11], regs[12]);

  // Result is returned to the trapping code in a0.
  regs[10] = sys_ret;
  return epc+4;
}
93
94 static long syscall(long num, long arg0, long arg1, long arg2)
95 {
96 register long a7 asm("a7") = num;
97 register long a0 asm("a0") = arg0;
98 register long a1 asm("a1") = arg1;
99 register long a2 asm("a2") = arg2;
100 asm volatile ("scall" : "+r"(a0) : "r"(a1), "r"(a2), "r"(a7));
101 return a0;
102 }
103
104 void exit(int code)
105 {
106 syscall(SYS_exit, code, 0, 0);
107 while (1);
108 }
109
110 void setStats(int enable)
111 {
112 syscall(SYS_stats, enable, 0, 0);
113 }
114
115 void printstr(const char* s)
116 {
117 syscall(SYS_write, 1, (long)s, strlen(s));
118 }
119
// Default per-core entry hook, called by _init with the core id `cid` and
// core count `nc` (unused here).
// multi-threaded programs override this function.
// for the case of single-threaded programs, only let core 0 proceed.
void __attribute__((weak)) thread_entry(int cid, int nc)
{
  (void)nc;
  // Park every other core forever. The loop condition is the constant 1 on
  // purpose: C11 6.8.5p6 lets a compiler assume that a side-effect-free loop
  // with a non-constant condition (such as `while (cid != 0);`) terminates,
  // which could remove the parking loop and let extra cores run main().
  if (cid != 0)
    while (1)
      ;
}
126
// Weak fallback main: a benchmark that does not provide main() prints a
// reminder and returns -1, which _init passes on to exit().
// single-threaded programs override this function.
int __attribute__((weak)) main(int argc, char** argv)
{
  static const char reminder[] = "Implement main(), foo!\n";
  (void)argc;
  (void)argv;
  printstr(reminder);
  return -1;
}
133
134 static void init_tls()
135 {
136 register void* thread_pointer asm("tp");
137 extern char _tls_data;
138 extern __thread char _tdata_begin, _tdata_end, _tbss_end;
139 size_t tdata_size = &_tdata_end - &_tdata_begin;
140 memcpy(thread_pointer, &_tls_data, tdata_size);
141 size_t tbss_size = &_tbss_end - &_tdata_end;
142 memset(thread_pointer + tdata_size, 0, tbss_size);
143 }
144
145 void _init(int cid, int nc)
146 {
147 init_tls();
148 thread_entry(cid, nc);
149
150 // only single-threaded programs should ever get here.
151 int ret = main(0, 0);
152
153 char buf[NUM_COUNTERS * 32] __attribute__((aligned(64)));
154 char* pbuf = buf;
155 for (int i = 0; i < NUM_COUNTERS; i++)
156 if (counters[i])
157 pbuf += sprintf(pbuf, "%s = %d\n", counter_names[i], counters[i]);
158 if (pbuf != buf)
159 printstr(buf);
160
161 exit(ret);
162 }
163
164 #undef putchar
165 int putchar(int ch)
166 {
167 static __thread char buf[64] __attribute__((aligned(64)));
168 static __thread int buflen = 0;
169
170 buf[buflen++] = ch;
171
172 if (ch == '\n' || buflen == sizeof(buf))
173 {
174 syscall(SYS_write, 1, (long)buf, buflen);
175 buflen = 0;
176 }
177
178 return 0;
179 }
180
// Print `x` as exactly 16 lowercase hexadecimal digits (zero-padded, no
// "0x" prefix) using printstr.
void printhex(uint64_t x)
{
  static const char hex_digits[] = "0123456789abcdef";
  char text[17];

  text[16] = '\0';
  for (int pos = 15; pos >= 0; pos--) {
    text[pos] = hex_digits[x & 0xF];
    x >>= 4;
  }

  printstr(text);
}
194
// Emit `num` in `base` (digits 0-9 then a-z, so intended for bases 2..16)
// through putch(ch, putdat), padded to at least `width` characters.
//   padc == '-' : left-justify — digits first, then spaces (this matches
//                 how vprintfmt's %s conversion treats the '-' flag)
//   otherwise   : padc ('0' or ' ') is emitted before the digits
// A width of -1 (or any width <= the digit count) means no padding.
static inline void printnum(void (*putch)(int, void**), void **putdat,
                    unsigned long long num, unsigned base, int width, int padc)
{
  unsigned digs[sizeof(num)*CHAR_BIT];
  int pos = 0;

  // Collect digits least-significant first; always at least one (for 0).
  while (1)
  {
    digs[pos++] = num % base;
    if (num < base)
      break;
    num /= base;
  }

  int pad = width - pos;  // <= 0 means no padding is needed

  if (padc != '-')
    for (; pad > 0; pad--)
      putch(padc, putdat);

  while (pos-- > 0)
    putch(digs[pos] + (digs[pos] >= 10 ? 'a' - 10 : '0'), putdat);

  // Only non-zero for left-justified output: pad on the right with spaces.
  for (; pad > 0; pad--)
    putch(' ', putdat);
}
215
// Fetch the next unsigned integer vararg, sized by the count of 'l' length
// modifiers: 0 -> unsigned int, >= 2 -> unsigned long long, any other value
// -> unsigned long. The result is widened to unsigned long long.
static unsigned long long getuint(va_list *ap, int lflag)
{
  return lflag >= 2 ? va_arg(*ap, unsigned long long)
       : lflag      ? va_arg(*ap, unsigned long)
                    : va_arg(*ap, unsigned int);
}

// Signed counterpart of getuint: int, long or long long by `lflag`, widened
// to long long.
static long long getint(va_list *ap, int lflag)
{
  return lflag >= 2 ? va_arg(*ap, long long)
       : lflag      ? va_arg(*ap, long)
                    : va_arg(*ap, int);
}
235
// Core formatter shared by printf and sprintf. It walks `fmt`, passing
// literal characters to putch(ch, putdat) and expanding %-escapes:
//   flags   '-' (left-justify), '0' (zero pad), '#' (parsed, unused)
//   width   decimal digits or '*'; '.' introduces a precision
//   length  'l' (long), 'll' (long long)
//   convs   c s d u o x p %
// An unrecognised escape prints '%' and scanning resumes right after it.
static void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt, va_list ap)
{
  register const char* p;
  const char* last_fmt;
  register int ch;
  unsigned long long num;
  int base, lflag, width, precision, altflag;
  char padc;
  va_list args;

  // Work on a local copy. getint/getuint take a va_list*. On ABIs where
  // va_list is an array type (e.g. x86-64), the address of a va_list
  // *parameter* does not have that type, but the address of a local does.
  va_copy(args, ap);

  while (1) {
    while ((ch = *(unsigned char *) fmt) != '%') {
      if (ch == '\0') {
        va_end(args);
        return;
      }
      fmt++;
      putch(ch, putdat);
    }
    fmt++;

    // Process a %-escape sequence
    last_fmt = fmt;
    padc = ' ';
    width = -1;
    precision = -1;
    lflag = 0;
    altflag = 0;
  reswitch:
    switch (ch = *(unsigned char *) fmt++) {

    // flag to pad on the right
    case '-':
      padc = '-';
      goto reswitch;

    // flag to pad with 0's instead of spaces
    case '0':
      padc = '0';
      goto reswitch;

    // width field
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      for (precision = 0; ; ++fmt) {
        precision = precision * 10 + ch - '0';
        ch = *fmt;
        if (ch < '0' || ch > '9')
          break;
      }
      goto process_precision;

    case '*':
      precision = va_arg(args, int);
      goto process_precision;

    case '.':
      if (width < 0)
        width = 0;
      goto reswitch;

    case '#':
      altflag = 1;
      goto reswitch;

    // A number seen before any '.' is the width, after '.' the precision.
    process_precision:
      if (width < 0)
        width = precision, precision = -1;
      goto reswitch;

    // long flag (doubled for long long)
    case 'l':
      lflag++;
      goto reswitch;

    // character
    case 'c':
      putch(va_arg(args, int), putdat);
      break;

    // string
    case 's':
      if ((p = va_arg(args, char *)) == NULL)
        p = "(null)";
      if (width > 0 && padc != '-')
        for (width -= strnlen(p, precision); width > 0; width--)
          putch(padc, putdat);
      for (; (ch = *p) != '\0' && (precision < 0 || --precision >= 0); width--) {
        putch(ch, putdat);
        p++;
      }
      for (; width > 0; width--)
        putch(' ', putdat);
      break;

    // (signed) decimal
    case 'd':
      num = getint(&args, lflag);
      if ((long long) num < 0) {
        putch('-', putdat);
        // Negate in unsigned arithmetic. -(long long)num overflows for
        // LLONG_MIN, while unsigned negation yields the exact magnitude.
        num = -num;
      }
      base = 10;
      goto signed_number;

    // unsigned decimal
    case 'u':
      base = 10;
      goto unsigned_number;

    // (unsigned) octal
    case 'o':
      // should do something with padding so it's always 3 octits
      base = 8;
      goto unsigned_number;

    // pointer
    case 'p':
      static_assert(sizeof(long) == sizeof(void*));
      lflag = 1;
      putch('0', putdat);
      putch('x', putdat);
      /* fall through to 'x' */

    // (unsigned) hexadecimal
    case 'x':
      base = 16;
    unsigned_number:
      num = getuint(&args, lflag);
    signed_number:
      printnum(putch, putdat, num, base, width, padc);
      break;

    // escaped '%' character
    case '%':
      putch(ch, putdat);
      break;

    // unrecognized escape sequence - just print it literally
    default:
      putch('%', putdat);
      fmt = last_fmt;
      break;
    }
  }
}
386
// vprintfmt callback for printf. It has exactly the callback type vprintfmt
// expects; calling putchar through a pointer cast to a different function
// type is undefined behaviour. `data` points at printf's running character
// count.
static void printf_putch(int ch, void **data)
{
  long *count = (long *)data;
  putchar(ch);
  (*count)++;
}

// printf: format with vprintfmt and emit each character with putchar.
// Returns the number of characters emitted.
int printf(const char* fmt, ...)
{
  va_list ap;
  long count = 0;
  va_start(ap, fmt);

  vprintfmt(printf_putch, (void **)&count, fmt, ap);

  va_end(ap);
  return (int)count;
}
397
// vprintfmt callback for sprintf. `data` is a char** cursor into the
// destination buffer; store the character and advance the cursor.
// Defined at file scope: the original GNU nested function is not portable
// (Clang rejects nested function definitions).
static void sprintf_putch(int ch, void** data)
{
  char** pstr = (char**)data;
  **pstr = ch;
  (*pstr)++;
}

// sprintf: format into `str` (unbounded — the caller must size the buffer),
// NUL-terminate it, and return the number of characters written, excluding
// the terminator.
int sprintf(char* str, const char* fmt, ...)
{
  va_list ap;
  char* str0 = str;
  va_start(ap, fmt);

  vprintfmt(sprintf_putch, (void**)&str, fmt, ap);
  *str = 0;

  va_end(ap);
  return str - str0;
}
417
// Copy `len` bytes from `src` to `dest` front to back and return `dest`.
// Overlapping regions are not supported. When both addresses and the length
// are multiples of the word size, the copy is done one uintptr_t at a time;
// otherwise it is done byte by byte.
void* memcpy(void* dest, const void* src, size_t len)
{
  const size_t word_mask = sizeof(uintptr_t) - 1;
  const int word_aligned =
      ((((uintptr_t)dest | (uintptr_t)src | len) & word_mask) == 0);

  if (!word_aligned) {
    char *out = dest;
    const char *in = src;
    char *const stop = (char *)dest + len;
    for (; out < stop; out++, in++)
      *out = *in;
    return dest;
  }

  uintptr_t *wout = dest;
  const uintptr_t *win = src;
  uintptr_t *const wstop = (uintptr_t *)((char *)dest + len);
  for (; wout < wstop; wout++, win++)
    *wout = *win;
  return dest;
}
433
// Fill `len` bytes at `dest` with the low byte of `byte` and return `dest`.
// When `dest` and `len` are both word multiples, the fill is done one
// uintptr_t at a time, with the byte replicated into every byte lane.
void* memset(void* dest, int byte, size_t len)
{
  const size_t misalignment = ((uintptr_t)dest | len) & (sizeof(uintptr_t) - 1);

  if (misalignment != 0) {
    char *cursor = dest;
    char *const limit = (char *)dest + len;
    for (; cursor < limit; cursor++)
      *cursor = byte;
    return dest;
  }

  // ~0 / 0xFF is 0x0101...01 for any word width; multiplying it by the byte
  // copies the byte into every lane.
  const uintptr_t pattern = (~(uintptr_t)0 / 0xFF) * (uintptr_t)(byte & 0xFF);
  uintptr_t *wcursor = dest;
  uintptr_t *const wlimit = (uintptr_t *)((char *)dest + len);
  for (; wcursor < wlimit; wcursor++)
    *wcursor = pattern;
  return dest;
}
452
// Return the number of characters before the terminating NUL of `s`.
size_t strlen(const char *s)
{
  size_t n = 0;
  while (s[n] != '\0')
    n++;
  return n;
}
460
// Return the length of `s`, but never more than `n`. At most `n` characters
// are examined, so `s` need not be NUL-terminated within that bound.
size_t strnlen(const char *s, size_t n)
{
  size_t len = 0;
  while (len < n && s[len] != '\0')
    len++;
  return len;
}
468
// Lexicographically compare two NUL-terminated strings as unsigned chars.
// Returns the difference of the first differing bytes: 0 if equal, negative
// if s1 sorts first, positive otherwise.
int strcmp(const char* s1, const char* s2)
{
  const unsigned char *a = (const unsigned char *)s1;
  const unsigned char *b = (const unsigned char *)s2;

  while (*a != 0 && *a == *b) {
    a++;
    b++;
  }

  return *a - *b;
}
480
// Copy `src`, including its terminating NUL, into `dest` and return `dest`.
// The caller must ensure `dest` is large enough.
char* strcpy(char* dest, const char* src)
{
  size_t i = 0;
  char c;
  do {
    c = src[i];
    dest[i] = c;
    i++;
  } while (c != '\0');
  return dest;
}
488
// Parse an optionally signed decimal integer from `str`, like the C library
// atol:
//   - leading whitespace (' ', '\t', '\n', '\v', '\f', '\r') is skipped;
//   - one optional '+' or '-' is accepted;
//   - conversion stops at the first non-digit character.
// Returns 0 when no digits are present. Out-of-range input is not reported.
// Digits accumulate in an unsigned long, so no signed overflow occurs; the
// final conversion to long is implementation-defined (wraps on GCC).
long atol(const char* str)
{
  unsigned long res = 0;
  int sign = 0;

  while (*str == ' ' || (*str >= '\t' && *str <= '\r'))
    str++;

  if (*str == '-' || *str == '+') {
    sign = *str == '-';
    str++;
  }

  // Only decimal digits contribute; the first non-digit ends the number.
  while (*str >= '0' && *str <= '9') {
    res = res * 10 + (unsigned long)(*str - '0');
    str++;
  }

  // Negate in unsigned arithmetic so the most negative value is handled
  // without overflow.
  return sign ? (long)(0UL - res) : (long)res;
}