Merge remote-tracking branch 'origin/priv-1.10' into HEAD
[riscv-tests.git] / benchmarks / common / syscalls.c
1 // See LICENSE for license details.
2
3 #include <stdint.h>
4 #include <string.h>
5 #include <stdarg.h>
6 #include <stdio.h>
7 #include <limits.h>
8 #include <sys/signal.h>
9 #include "util.h"
10
// Host syscall number used for writing a buffer to a file descriptor
// (printstr/putchar pass fd 1).
#define SYS_write 64

// Drop any strcmp macro so strcmp can be defined below as an ordinary function.
#undef strcmp

// Host mailbox words used by syscall()/tohost_exit(); defined elsewhere
// (presumably the linker script or startup code — confirm).
extern volatile uint64_t tohost, fromhost;
17
18 static uintptr_t syscall(uintptr_t which, uint64_t arg0, uint64_t arg1, uint64_t arg2)
19 {
20 volatile uint64_t magic_mem[8] __attribute__((aligned(64)));
21 magic_mem[0] = which;
22 magic_mem[1] = arg0;
23 magic_mem[2] = arg1;
24 magic_mem[3] = arg2;
25 __sync_synchronize();
26
27 tohost = (uintptr_t)magic_mem;
28 while (fromhost == 0)
29 ;
30 fromhost = 0;
31
32 __sync_synchronize();
33 return magic_mem[0];
34 }
35
// Number of performance-counter slots filled in by setStats().
#define NUM_COUNTERS 2

// Per-slot counter values and the CSR names they came from. setStats()
// writes them and _init() reports the non-zero ones. Both start zeroed.
static uintptr_t counters[NUM_COUNTERS] = {0};
static char* counter_names[NUM_COUNTERS] = {0};
39
40 void setStats(int enable)
41 {
42 int i = 0;
43 #define READ_CTR(name) do { \
44 while (i >= NUM_COUNTERS) ; \
45 uintptr_t csr = read_csr(name); \
46 if (!enable) { csr -= counters[i]; counter_names[i] = #name; } \
47 counters[i++] = csr; \
48 } while (0)
49
50 READ_CTR(mcycle);
51 READ_CTR(minstret);
52
53 #undef READ_CTR
54 }
55
56 void __attribute__((noreturn)) tohost_exit(uintptr_t code)
57 {
58 tohost = (code << 1) | 1;
59 while (1);
60 }
61
// Default (weak) trap handler: programs that expect traps override it.
// Any trap reaching this fallback terminates the run with code 1337.
// The arguments are ignored; tohost_exit() never returns.
uintptr_t __attribute__((weak)) handle_trap(uintptr_t cause, uintptr_t epc, uintptr_t regs[32])
{
  (void)cause;
  (void)epc;
  (void)regs;
  tohost_exit(1337);
}
66
// C library exit(): pass the status straight to the host via tohost_exit(),
// which never returns. Negative codes are converted to uintptr_t (wrapping).
void exit(int code)
{
  uintptr_t status = (uintptr_t)code;
  tohost_exit(status);
}
71
// C library abort(): exit with the conventional "killed by SIGABRT" status,
// 128 + SIGABRT.
//
// Declared with a (void) prototype to match the standard signature; an empty
// () is an old-style, non-prototype declarator before C23.
void abort(void)
{
  exit(128 + SIGABRT);
}
76
77 void printstr(const char* s)
78 {
79 syscall(SYS_write, 1, (uintptr_t)s, strlen(s));
80 }
81
// Per-core entry hook, called from _init() before main().
//
// Multi-threaded programs override this function. In the default
// (single-threaded) case only core 0 returns; every other core is parked
// forever. nc is unused here.
void __attribute__((weak)) thread_entry(int cid, int nc)
{
  (void)nc;

  if (cid != 0) {
    // Use a constant-condition loop on purpose. C11 6.8.5p6 allows the
    // compiler to assume that `while (cid != 0);` (non-constant condition,
    // no side effects) terminates and to delete it, which would let
    // secondary cores run main().
    for (;;)
      ;
  }
}
88
// Fallback main(), used only when the program does not provide one.
// Single-threaded programs override this function. The fallback prints a
// reminder and returns -1, which _init() passes to exit().
int __attribute__((weak)) main(int argc, char** argv)
{
  (void)argc;
  (void)argv;

  printstr("Implement main(), foo!\n");
  return -1;
}
95
// Initialize the calling thread's TLS area, which starts at the tp register.
//
// Copies tdata_size bytes of the TLS initialization image (_tls_data) to tp,
// then zeroes the following tbss_size bytes. The sizes come from the
// _tdata_begin/_tdata_end/_tbss_end symbols, presumably provided by the
// linker script (confirm).
static void init_tls(void)
{
  void* thread_pointer;
  // Read tp through an asm output operand. GCC only supports local register
  // variables (`register ... asm("tp")`) when they are used as Extended asm
  // operands, so reading one directly is not guaranteed to yield tp.
  __asm__ volatile ("mv %0, tp" : "=r"(thread_pointer));

  extern char _tls_data;
  extern __thread char _tdata_begin, _tdata_end, _tbss_end;

  size_t tdata_size = &_tdata_end - &_tdata_begin;
  memcpy(thread_pointer, &_tls_data, tdata_size);

  size_t tbss_size = &_tbss_end - &_tdata_end;
  // Standard C has no arithmetic on void*, so step through char*.
  memset((char*)thread_pointer + tdata_size, 0, tbss_size);
}
106
107 void _init(int cid, int nc)
108 {
109 init_tls();
110 thread_entry(cid, nc);
111
112 // only single-threaded programs should ever get here.
113 int ret = main(0, 0);
114
115 char buf[NUM_COUNTERS * 32] __attribute__((aligned(64)));
116 char* pbuf = buf;
117 for (int i = 0; i < NUM_COUNTERS; i++)
118 if (counters[i])
119 pbuf += sprintf(pbuf, "%s = %d\n", counter_names[i], counters[i]);
120 if (pbuf != buf)
121 printstr(buf);
122
123 exit(ret);
124 }
125
126 #undef putchar
127 int putchar(int ch)
128 {
129 static __thread char buf[64] __attribute__((aligned(64)));
130 static __thread int buflen = 0;
131
132 buf[buflen++] = ch;
133
134 if (ch == '\n' || buflen == sizeof(buf))
135 {
136 syscall(SYS_write, 1, (uintptr_t)buf, buflen);
137 buflen = 0;
138 }
139
140 return 0;
141 }
142
// Print x as exactly 16 lowercase hex digits (zero-padded, no "0x" prefix)
// using printstr().
void printhex(uint64_t x)
{
  char str[17];
  str[16] = '\0';

  // Fill from the least significant nibble (rightmost character) leftwards.
  for (int pos = 15; pos >= 0; pos--) {
    unsigned nibble = x & 0xF;
    str[pos] = nibble < 10 ? '0' + nibble : 'a' + (nibble - 10);
    x >>= 4;
  }

  printstr(str);
}
156
// Emit num in the given base through putch(ch, putdat), padded to at least
// `width` characters.
//
// Digits above 9 are lowercase letters; callers in this file use bases 8, 10
// and 16. padc selects the padding:
//   ' ' or '0' -> pad on the left with that character;
//   '-'        -> vprintfmt's left-justify flag: pad on the right with spaces.
// A width at or below the digit count means no padding.
static inline void printnum(void (*putch)(int, void**), void **putdat,
                            unsigned long long num, unsigned base, int width, int padc)
{
  // Worst case is base 2: one digit per bit.
  unsigned digs[sizeof(num)*CHAR_BIT];
  int pos = 0;

  // Collect digits least-significant first; always emit at least one ("0").
  do {
    digs[pos++] = num % base;
    num /= base;
  } while (num != 0);

  // '-' means "left-justify", not "pad with dashes" (previously "%-5d" of 42
  // printed "---42").
  int left_justify = (padc == '-');

  if (!left_justify)
    for (; width > pos; width--)
      putch(padc, putdat);

  for (int i = pos - 1; i >= 0; i--)
    putch(digs[i] + (digs[i] >= 10 ? 'a' - 10 : '0'), putdat);

  if (left_justify)
    for (; width > pos; width--)
      putch(' ', putdat);
}
177
// Fetch the next unsigned integer argument, widened to unsigned long long.
// lflag counts 'l' modifiers: 0 -> unsigned int, 1 -> unsigned long,
// 2 or more -> unsigned long long. Any other non-zero value reads an
// unsigned long.
static unsigned long long getuint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, unsigned int);
  if (lflag < 2)
    return va_arg(*ap, unsigned long);
  return va_arg(*ap, unsigned long long);
}
187
// Fetch the next signed integer argument, widened to long long.
// lflag counts 'l' modifiers: 0 -> int, 1 -> long, 2 or more -> long long.
// Any other non-zero value reads a long.
static long long getint(va_list *ap, int lflag)
{
  if (lflag == 0)
    return va_arg(*ap, int);
  if (lflag < 2)
    return va_arg(*ap, long);
  return va_arg(*ap, long long);
}
197
// '%p' is printed by fetching the argument as an unsigned long (lflag = 1),
// which is only correct when pointers and longs have the same width.
// (File-scope C11 _Static_assert: the former one-argument static_assert
// placed right after a case label is not valid standard C11.)
_Static_assert(sizeof(long) == sizeof(void*),
               "%p formatting requires sizeof(long) == sizeof(void*)");

// Minimal printf engine shared by printf() and sprintf().
//
// Walks fmt, emitting each output character through putch(ch, putdat).
// Supported syntax:
//   flags:       '-' (left-justify), '0' (zero-pad), '#' (parsed, ignored)
//   width:       decimal digits or '*'
//   precision:   '.' followed by digits or '*'; only used by %s
//   length:      'l' (long) and 'll' (long long)
//   conversions: %c %s %d %u %o %p %x %%
// An unrecognized conversion prints '%' and resumes scanning right after it.
static void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt, va_list ap)
{
  const char* p;
  const char* last_fmt;
  int ch;
  unsigned long long num;
  int base, lflag, width, precision, altflag;
  char padc;
  va_list args;

  // Work on a local copy. Where va_list is an array type (e.g. x86-64), &ap
  // on the parameter is not a va_list*, so it could not be handed to
  // getint()/getuint().
  va_copy(args, ap);

  while (1) {
    while ((ch = *(unsigned char *) fmt) != '%') {
      if (ch == '\0') {
        va_end(args);
        return;
      }
      fmt++;
      putch(ch, putdat);
    }
    fmt++;

    // Process a %-escape sequence
    last_fmt = fmt;
    padc = ' ';
    width = -1;
    precision = -1;
    lflag = 0;
    altflag = 0;
  reswitch:
    switch (ch = *(unsigned char *) fmt++) {

    // flag to pad on the right
    case '-':
      padc = '-';
      goto reswitch;

    // flag to pad with 0's instead of spaces
    case '0':
      padc = '0';
      goto reswitch;

    // width field
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      for (precision = 0; ; ++fmt) {
        precision = precision * 10 + ch - '0';
        ch = *fmt;
        if (ch < '0' || ch > '9')
          break;
      }
      goto process_precision;

    case '*':
      precision = va_arg(args, int);
      goto process_precision;

    case '.':
      if (width < 0)
        width = 0;
      goto reswitch;

    case '#':
      altflag = 1;
      goto reswitch;

    process_precision:
      // A number seen before any '.' is the field width; after '.', it is
      // the precision.
      if (width < 0)
        width = precision, precision = -1;
      goto reswitch;

    // long flag (doubled for long long)
    case 'l':
      lflag++;
      goto reswitch;

    // character
    case 'c':
      putch(va_arg(args, int), putdat);
      break;

    // string
    case 's':
      if ((p = va_arg(args, char *)) == NULL)
        p = "(null)";
      if (width > 0 && padc != '-')
        for (width -= (int)strnlen(p, precision); width > 0; width--)
          putch(padc, putdat);
      for (; (ch = *p) != '\0' && (precision < 0 || --precision >= 0); width--) {
        putch(ch, putdat);
        p++;
      }
      for (; width > 0; width--)
        putch(' ', putdat);
      break;

    // (signed) decimal
    case 'd': {
      long long sval = getint(&args, lflag);
      num = (unsigned long long) sval;
      base = 10;
      if (sval < 0) {
        // Take the magnitude in unsigned arithmetic: negating LLONG_MIN as
        // a signed value overflows (undefined behavior).
        num = 0ULL - num;
        if (padc == ' ') {
          // Right-justified with spaces: the padding goes before the sign,
          // and the sign counts toward the field width.
          int ndigits = 1;
          for (unsigned long long t = num; t >= 10; t /= 10)
            ndigits++;
          for (; width > ndigits + 1; width--)
            putch(' ', putdat);
        }
        putch('-', putdat);
        width--;  // the sign occupies one column of the field
      }
      goto signed_number;
    }

    // unsigned decimal
    case 'u':
      base = 10;
      goto unsigned_number;

    // (unsigned) octal
    case 'o':
      // should do something with padding so it's always 3 octits
      base = 8;
      goto unsigned_number;

    // pointer
    case 'p':
      lflag = 1;
      putch('0', putdat);
      putch('x', putdat);
      /* fallthrough */

    // (unsigned) hexadecimal
    case 'x':
      base = 16;
    unsigned_number:
      num = getuint(&args, lflag);
    signed_number:
      printnum(putch, putdat, num, base, width, padc);
      break;

    // escaped '%' character
    case '%':
      putch(ch, putdat);
      break;

    // unrecognized escape sequence - just print it literally
    default:
      putch('%', putdat);
      fmt = last_fmt;
      break;
    }
  }
}
348
// vprintfmt callback for printf().
// *data points at an int counter: write ch via putchar() and count it.
static void printf_putch(int ch, void** data)
{
  int* count = (int*)*data;
  putchar(ch);
  ++*count;
}

// Formatted output to stdout (buffered by putchar). Returns the number of
// characters produced.
//
// Previously putchar was passed to vprintfmt cast to an incompatible
// function-pointer type (calling through it is undefined behavior), and the
// return value was always 0.
int printf(const char* fmt, ...)
{
  va_list ap;
  int count = 0;
  void* ctx = &count;

  va_start(ap, fmt);
  vprintfmt(printf_putch, &ctx, fmt, ap);
  va_end(ap);

  return count;
}
359
// vprintfmt callback for sprintf().
// data is really a char** cursor: store ch there and advance the cursor.
// This is a file-scope helper rather than a GCC nested function, which is a
// non-standard extension that other compilers (e.g. clang) reject.
static void sprintf_putch(int ch, void** data)
{
  char** pstr = (char**)data;
  **pstr = ch;
  (*pstr)++;
}

// Formatted output into str, which is NUL-terminated. Returns the number of
// characters written, excluding the NUL. No bounds checking: the caller must
// size str for the output.
int sprintf(char* str, const char* fmt, ...)
{
  va_list ap;
  char* str0 = str;

  va_start(ap, fmt);
  vprintfmt(sprintf_putch, (void**)&str, fmt, ap);
  *str = 0;
  va_end(ap);

  return (int)(str - str0);
}
379
// Copy len bytes from src to dest (regions must not overlap); returns dest.
//
// When dest, src and len are all multiples of sizeof(uintptr_t), the copy is
// done a word at a time; otherwise byte by byte. End pointers are computed
// through char*, because arithmetic on void* is a GNU extension rather than
// standard C.
void* memcpy(void* dest, const void* src, size_t len)
{
  if ((((uintptr_t)dest | (uintptr_t)src | len) & (sizeof(uintptr_t)-1)) == 0) {
    const uintptr_t* s = src;
    uintptr_t* d = dest;
    uintptr_t* end = (uintptr_t*)((char*)dest + len);
    while (d < end)
      *d++ = *s++;
  } else {
    const char* s = src;
    char* d = dest;
    char* end = (char*)dest + len;
    while (d < end)
      *d++ = *s++;
  }
  return dest;
}
395
// Fill len bytes at dest with (unsigned char)byte; returns dest.
//
// When dest and len are multiples of sizeof(uintptr_t), whole words holding
// the byte replicated in every lane are stored; otherwise byte by byte. End
// pointers are computed through char*, because arithmetic on void* is a GNU
// extension rather than standard C.
void* memset(void* dest, int byte, size_t len)
{
  if ((((uintptr_t)dest | len) & (sizeof(uintptr_t)-1)) == 0) {
    uintptr_t word = byte & 0xFF;
    word |= word << 8;
    word |= word << 16;
    // Two 16-bit shifts instead of one 32-bit shift: with a 32-bit uintptr_t
    // this adds nothing, whereas `<< 32` would be undefined behavior.
    word |= word << 16 << 16;

    uintptr_t* d = dest;
    uintptr_t* end = (uintptr_t*)((char*)dest + len);
    while (d < end)
      *d++ = word;
  } else {
    char* d = dest;
    char* end = (char*)dest + len;
    while (d < end)
      *d++ = byte;
  }
  return dest;
}
414
// Number of characters in the NUL-terminated string s, not counting the
// terminator.
size_t strlen(const char *s)
{
  size_t n = 0;
  while (s[n] != '\0')
    n++;
  return n;
}
422
// Length of s, but never more than n: stops at the first NUL or after n
// characters, whichever comes first. Reads at most n bytes.
size_t strnlen(const char *s, size_t n)
{
  size_t len = 0;
  while (len < n && s[len] != '\0')
    len++;
  return len;
}
430
// Lexicographic comparison of two NUL-terminated strings, with characters
// compared as unsigned char. Returns the difference of the first mismatching
// pair (or of the two terminators): negative, zero, or positive.
int strcmp(const char* s1, const char* s2)
{
  const unsigned char* a = (const unsigned char*)s1;
  const unsigned char* b = (const unsigned char*)s2;

  for (;; a++, b++) {
    if (*a == 0 || *a != *b)
      return *a - *b;
  }
}
442
// Copy the NUL-terminated string src, including its terminator, into dest;
// returns dest. dest must be large enough, and the regions must not overlap.
char* strcpy(char* dest, const char* src)
{
  size_t i = 0;
  char c;

  do {
    c = src[i];
    dest[i] = c;
    i++;
  } while (c != '\0');

  return dest;
}
450
// Parse a decimal long from str, as the C library atol() does.
//
// Skips leading whitespace (' ', '\t', '\n', '\v', '\f', '\r'), accepts an
// optional '+' or '-', then reads digits up to the first non-digit. Returns
// 0 when no digits are present.
//
// Previously trailing non-digits were folded into the result ("12abc" did not
// give 12), and the signed accumulator could overflow (undefined behavior).
// The value is now accumulated and negated in unsigned arithmetic, so
// out-of-range input wraps (implementation-defined on the final conversion,
// modular on GCC) instead of invoking UB.
long atol(const char* str)
{
  unsigned long res = 0;
  int negative = 0;

  while (*str == ' ' || (*str >= '\t' && *str <= '\r'))
    str++;

  if (*str == '-' || *str == '+') {
    negative = *str == '-';
    str++;
  }

  while (*str >= '0' && *str <= '9')
    res = res * 10 + (unsigned long)(*str++ - '0');

  return negative ? (long)(0UL - res) : (long)res;
}