nir/range-analysis: Rudimentary value range analysis pass
[mesa.git] / src / compiler / nir / nir_range_analysis.c
1 /*
2 * Copyright © 2018 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23 #include <math.h>
24 #include <float.h>
25 #include "nir.h"
26 #include "nir_range_analysis.h"
27 #include "util/hash_table.h"
28
29 /**
30 * Analyzes a sequence of operations to determine some aspects of the range of
31 * the result.
32 */
33
34 static bool
35 is_not_zero(enum ssa_ranges r)
36 {
37 return r == gt_zero || r == lt_zero || r == ne_zero;
38 }
39
40 static void *
41 pack_data(const struct ssa_result_range r)
42 {
43 return (void *)(uintptr_t)(r.range | r.is_integral << 8);
44 }
45
46 static struct ssa_result_range
47 unpack_data(const void *p)
48 {
49 const uintptr_t v = (uintptr_t) p;
50
51 return (struct ssa_result_range){v & 0xff, (v & 0x0ff00) != 0};
52 }
53
54 static struct ssa_result_range
55 analyze_constant(const struct nir_alu_instr *instr, unsigned src)
56 {
57 uint8_t swizzle[4] = { 0, 1, 2, 3 };
58
59 /* If the source is an explicitly sized source, then we need to reset
60 * both the number of components and the swizzle.
61 */
62 const unsigned num_components = nir_ssa_alu_instr_src_components(instr, src);
63
64 for (unsigned i = 0; i < num_components; ++i)
65 swizzle[i] = instr->src[src].swizzle[i];
66
67 const nir_load_const_instr *const load =
68 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
69
70 struct ssa_result_range r = { unknown, false };
71
72 switch (nir_op_infos[instr->op].input_types[src]) {
73 case nir_type_float: {
74 double min_value = DBL_MAX;
75 double max_value = -DBL_MAX;
76 bool any_zero = false;
77 bool all_zero = true;
78
79 r.is_integral = true;
80
81 for (unsigned i = 0; i < num_components; ++i) {
82 const double v = nir_const_value_as_float(load->value[swizzle[i]],
83 load->def.bit_size);
84
85 if (floor(v) != v)
86 r.is_integral = false;
87
88 any_zero = any_zero || (v == 0.0);
89 all_zero = all_zero && (v == 0.0);
90 min_value = MIN2(min_value, v);
91 max_value = MAX2(max_value, v);
92 }
93
94 assert(any_zero >= all_zero);
95 assert(isnan(max_value) || max_value >= min_value);
96
97 if (all_zero)
98 r.range = eq_zero;
99 else if (min_value > 0.0)
100 r.range = gt_zero;
101 else if (min_value == 0.0)
102 r.range = ge_zero;
103 else if (max_value < 0.0)
104 r.range = lt_zero;
105 else if (max_value == 0.0)
106 r.range = le_zero;
107 else if (!any_zero)
108 r.range = ne_zero;
109 else
110 r.range = unknown;
111
112 return r;
113 }
114
115 case nir_type_int:
116 case nir_type_bool: {
117 int64_t min_value = INT_MAX;
118 int64_t max_value = INT_MIN;
119 bool any_zero = false;
120 bool all_zero = true;
121
122 for (unsigned i = 0; i < num_components; ++i) {
123 const int64_t v = nir_const_value_as_int(load->value[swizzle[i]],
124 load->def.bit_size);
125
126 any_zero = any_zero || (v == 0);
127 all_zero = all_zero && (v == 0);
128 min_value = MIN2(min_value, v);
129 max_value = MAX2(max_value, v);
130 }
131
132 assert(any_zero >= all_zero);
133 assert(max_value >= min_value);
134
135 if (all_zero)
136 r.range = eq_zero;
137 else if (min_value > 0)
138 r.range = gt_zero;
139 else if (min_value == 0)
140 r.range = ge_zero;
141 else if (max_value < 0)
142 r.range = lt_zero;
143 else if (max_value == 0)
144 r.range = le_zero;
145 else if (!any_zero)
146 r.range = ne_zero;
147 else
148 r.range = unknown;
149
150 return r;
151 }
152
153 case nir_type_uint: {
154 bool any_zero = false;
155 bool all_zero = true;
156
157 for (unsigned i = 0; i < num_components; ++i) {
158 const uint64_t v = nir_const_value_as_uint(load->value[swizzle[i]],
159 load->def.bit_size);
160
161 any_zero = any_zero || (v == 0);
162 all_zero = all_zero && (v == 0);
163 }
164
165 assert(any_zero >= all_zero);
166
167 if (all_zero)
168 r.range = eq_zero;
169 else if (any_zero)
170 r.range = ge_zero;
171 else
172 r.range = gt_zero;
173
174 return r;
175 }
176
177 default:
178 unreachable("Invalid alu source type");
179 }
180 }
181
/* Debug-build sanity checks for the square range tables used in
 * analyze_expression.  Both macros expand to nothing when NDEBUG is defined.
 */
#ifndef NDEBUG
/* Assert that t[row][col] == t[col][row] for every pair of indices. */
#define ASSERT_TABLE_IS_COMMUTATIVE(t)                              \
   do {                                                             \
      for (unsigned row = 0; row < ARRAY_SIZE(t); row++) {          \
         for (unsigned col = 0; col < ARRAY_SIZE(t[0]); col++)      \
            assert(t[row][col] == t[col][row]);                     \
      }                                                             \
   } while (false)

/* Assert that every entry on the main diagonal equals its own index. */
#define ASSERT_TABLE_IS_DIAGONAL(t)                                 \
   do {                                                             \
      for (unsigned idx = 0; idx < ARRAY_SIZE(t); idx++)            \
         assert(t[idx][idx] == idx);                                \
   } while (false)
#else
#define ASSERT_TABLE_IS_COMMUTATIVE(t)
#define ASSERT_TABLE_IS_DIAGONAL(t)
#endif
200
201 /**
202 * Short-hand name for use in the tables in analyze_expression. If this name
203 * becomes a problem on some compiler, we can change it to _.
204 */
205 #define _______ unknown
206
207 /**
208 * Analyze an expression to determine the range of its result
209 *
210 * The end result of this analysis is a token that communicates something
211 * about the range of values. There's an implicit grammar that produces
212 * tokens from sequences of literal values, other tokens, and operations.
213 * This function implements this grammar as a recursive-descent parser. Some
214 * (but not all) of the grammar is listed in-line in the function.
215 */
216 static struct ssa_result_range
217 analyze_expression(const nir_alu_instr *instr, unsigned src,
218 struct hash_table *ht)
219 {
220 if (nir_src_is_const(instr->src[src].src))
221 return analyze_constant(instr, src);
222
223 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
224 return (struct ssa_result_range){unknown, false};
225
226 const struct nir_alu_instr *const alu =
227 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
228
229 struct hash_entry *he = _mesa_hash_table_search(ht, alu);
230 if (he != NULL)
231 return unpack_data(he->data);
232
233 struct ssa_result_range r = {unknown, false};
234
235 switch (alu->op) {
236 case nir_op_b2f32:
237 case nir_op_b2i32:
238 r = (struct ssa_result_range){ge_zero, alu->op == nir_op_b2f32};
239 break;
240
241 case nir_op_i2f32:
242 case nir_op_u2f32:
243 r = analyze_expression(alu, 0, ht);
244
245 r.is_integral = true;
246
247 if (r.range == unknown && alu->op == nir_op_u2f32)
248 r.range = ge_zero;
249
250 break;
251
252 case nir_op_fabs:
253 r = analyze_expression(alu, 0, ht);
254
255 switch (r.range) {
256 case unknown:
257 case le_zero:
258 case ge_zero:
259 r.range = ge_zero;
260 break;
261
262 case lt_zero:
263 case gt_zero:
264 case ne_zero:
265 r.range = gt_zero;
266 break;
267
268 case eq_zero:
269 break;
270 }
271
272 break;
273
274 case nir_op_fadd: {
275 const struct ssa_result_range left = analyze_expression(alu, 0, ht);
276 const struct ssa_result_range right = analyze_expression(alu, 1, ht);
277
278 r.is_integral = left.is_integral && right.is_integral;
279
280 /* ge_zero: ge_zero + ge_zero
281 *
282 * gt_zero: gt_zero + eq_zero
283 * | gt_zero + ge_zero
284 * | eq_zero + gt_zero # Addition is commutative
285 * | ge_zero + gt_zero # Addition is commutative
286 * | gt_zero + gt_zero
287 * ;
288 *
289 * le_zero: le_zero + le_zero
290 *
291 * lt_zero: lt_zero + eq_zero
292 * | lt_zero + le_zero
293 * | eq_zero + lt_zero # Addition is commutative
294 * | le_zero + lt_zero # Addition is commutative
295 * | lt_zero + lt_zero
296 * ;
297 *
298 * eq_zero: eq_zero + eq_zero
299 *
300 * All other cases are 'unknown'.
301 */
302 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
303 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
304 /* unknown */ { _______, _______, _______, _______, _______, _______, _______ },
305 /* lt_zero */ { _______, lt_zero, lt_zero, _______, _______, _______, lt_zero },
306 /* le_zero */ { _______, lt_zero, le_zero, _______, _______, _______, le_zero },
307 /* gt_zero */ { _______, _______, _______, gt_zero, gt_zero, _______, gt_zero },
308 /* ge_zero */ { _______, _______, _______, gt_zero, ge_zero, _______, ge_zero },
309 /* ne_zero */ { _______, _______, _______, _______, _______, ne_zero, ne_zero },
310 /* eq_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
311 };
312
313 ASSERT_TABLE_IS_COMMUTATIVE(table);
314 ASSERT_TABLE_IS_DIAGONAL(table);
315
316 r.range = table[left.range][right.range];
317 break;
318 }
319
320 case nir_op_fexp2:
321 r = (struct ssa_result_range){gt_zero, analyze_expression(alu, 0, ht).is_integral};
322 break;
323
324 case nir_op_fmax: {
325 const struct ssa_result_range left = analyze_expression(alu, 0, ht);
326 const struct ssa_result_range right = analyze_expression(alu, 1, ht);
327
328 r.is_integral = left.is_integral && right.is_integral;
329
330 /* gt_zero: fmax(gt_zero, *)
331 * | fmax(*, gt_zero) # Treat fmax as commutative
332 * ;
333 *
334 * ge_zero: fmax(ge_zero, ne_zero)
335 * | fmax(ge_zero, lt_zero)
336 * | fmax(ge_zero, le_zero)
337 * | fmax(ge_zero, eq_zero)
338 * | fmax(ne_zero, ge_zero) # Treat fmax as commutative
339 * | fmax(lt_zero, ge_zero) # Treat fmax as commutative
340 * | fmax(le_zero, ge_zero) # Treat fmax as commutative
341 * | fmax(eq_zero, ge_zero) # Treat fmax as commutative
342 * | fmax(ge_zero, ge_zero)
343 * ;
344 *
345 * le_zero: fmax(le_zero, lt_zero)
346 * | fmax(lt_zero, le_zero) # Treat fmax as commutative
347 * | fmax(le_zero, le_zero)
348 * ;
349 *
350 * lt_zero: fmax(lt_zero, lt_zero)
351 * ;
352 *
353 * ne_zero: fmax(ne_zero, lt_zero)
354 * | fmax(lt_zero, ne_zero) # Treat fmax as commutative
355 * | fmax(ne_zero, ne_zero)
356 * ;
357 *
358 * eq_zero: fmax(eq_zero, le_zero)
359 * | fmax(eq_zero, lt_zero)
360 * | fmax(le_zero, eq_zero) # Treat fmax as commutative
361 * | fmax(lt_zero, eq_zero) # Treat fmax as commutative
362 * | fmax(eq_zero, eq_zero)
363 * ;
364 *
365 * All other cases are 'unknown'.
366 */
367 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
368 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
369 /* unknown */ { _______, _______, _______, gt_zero, ge_zero, _______, _______ },
370 /* lt_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
371 /* le_zero */ { _______, le_zero, le_zero, gt_zero, ge_zero, _______, eq_zero },
372 /* gt_zero */ { gt_zero, gt_zero, gt_zero, gt_zero, gt_zero, gt_zero, gt_zero },
373 /* ge_zero */ { ge_zero, ge_zero, ge_zero, gt_zero, ge_zero, ge_zero, ge_zero },
374 /* ne_zero */ { _______, ne_zero, _______, gt_zero, ge_zero, ne_zero, _______ },
375 /* eq_zero */ { _______, eq_zero, eq_zero, gt_zero, ge_zero, _______, eq_zero }
376 };
377
378 /* Treat fmax as commutative. */
379 ASSERT_TABLE_IS_COMMUTATIVE(table);
380 ASSERT_TABLE_IS_DIAGONAL(table);
381
382 r.range = table[left.range][right.range];
383 break;
384 }
385
386 case nir_op_fmin: {
387 const struct ssa_result_range left = analyze_expression(alu, 0, ht);
388 const struct ssa_result_range right = analyze_expression(alu, 1, ht);
389
390 r.is_integral = left.is_integral && right.is_integral;
391
392 /* lt_zero: fmin(lt_zero, *)
393 * | fmin(*, lt_zero) # Treat fmin as commutative
394 * ;
395 *
396 * le_zero: fmin(le_zero, ne_zero)
397 * | fmin(le_zero, gt_zero)
398 * | fmin(le_zero, ge_zero)
399 * | fmin(le_zero, eq_zero)
400 * | fmin(ne_zero, le_zero) # Treat fmin as commutative
401 * | fmin(gt_zero, le_zero) # Treat fmin as commutative
402 * | fmin(ge_zero, le_zero) # Treat fmin as commutative
403 * | fmin(eq_zero, le_zero) # Treat fmin as commutative
404 * | fmin(le_zero, le_zero)
405 * ;
406 *
407 * ge_zero: fmin(ge_zero, gt_zero)
408 * | fmin(gt_zero, ge_zero) # Treat fmin as commutative
409 * | fmin(ge_zero, ge_zero)
410 * ;
411 *
412 * gt_zero: fmin(gt_zero, gt_zero)
413 * ;
414 *
415 * ne_zero: fmin(ne_zero, gt_zero)
416 * | fmin(gt_zero, ne_zero) # Treat fmin as commutative
417 * | fmin(ne_zero, ne_zero)
418 * ;
419 *
420 * eq_zero: fmin(eq_zero, ge_zero)
421 * | fmin(eq_zero, gt_zero)
422 * | fmin(ge_zero, eq_zero) # Treat fmin as commutative
423 * | fmin(gt_zero, eq_zero) # Treat fmin as commutative
424 * | fmin(eq_zero, eq_zero)
425 * ;
426 *
427 * All other cases are 'unknown'.
428 */
429 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
430 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
431 /* unknown */ { _______, lt_zero, le_zero, _______, _______, _______, _______ },
432 /* lt_zero */ { lt_zero, lt_zero, lt_zero, lt_zero, lt_zero, lt_zero, lt_zero },
433 /* le_zero */ { le_zero, lt_zero, le_zero, le_zero, le_zero, le_zero, le_zero },
434 /* gt_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
435 /* ge_zero */ { _______, lt_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
436 /* ne_zero */ { _______, lt_zero, le_zero, ne_zero, _______, ne_zero, _______ },
437 /* eq_zero */ { _______, lt_zero, le_zero, eq_zero, eq_zero, _______, eq_zero }
438 };
439
440 /* Treat fmin as commutative. */
441 ASSERT_TABLE_IS_COMMUTATIVE(table);
442 ASSERT_TABLE_IS_DIAGONAL(table);
443
444 r.range = table[left.range][right.range];
445 break;
446 }
447
448 case nir_op_fmul: {
449 const struct ssa_result_range left = analyze_expression(alu, 0, ht);
450 const struct ssa_result_range right = analyze_expression(alu, 1, ht);
451
452 r.is_integral = left.is_integral && right.is_integral;
453
454 /* ge_zero: ge_zero * ge_zero
455 * | ge_zero * gt_zero
456 * | ge_zero * eq_zero
457 * | le_zero * lt_zero
458 * | lt_zero * le_zero # Multiplication is commutative
459 * | le_zero * le_zero
460 * | gt_zero * ge_zero # Multiplication is commutative
461 * | eq_zero * ge_zero # Multiplication is commutative
462 * | a * a # Left source == right source
463 * ;
464 *
465 * gt_zero: gt_zero * gt_zero
466 * | lt_zero * lt_zero
467 * ;
468 *
469 * le_zero: ge_zero * le_zero
470 * | ge_zero * lt_zero
471 * | lt_zero * ge_zero # Multiplication is commutative
472 * | le_zero * ge_zero # Multiplication is commutative
473 * | le_zero * gt_zero
474 * ;
475 *
476 * lt_zero: lt_zero * gt_zero
477 * | gt_zero * lt_zero # Multiplication is commutative
478 * ;
479 *
480 * ne_zero: ne_zero * gt_zero
481 * | ne_zero * lt_zero
482 * | gt_zero * ne_zero # Multiplication is commutative
483 * | lt_zero * ne_zero # Multiplication is commutative
484 * | ne_zero * ne_zero
485 * ;
486 *
487 * eq_zero: eq_zero * <any>
488 * <any> * eq_zero # Multiplication is commutative
489 *
490 * All other cases are 'unknown'.
491 */
492 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
493 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
494 /* unknown */ { _______, _______, _______, _______, _______, _______, eq_zero },
495 /* lt_zero */ { _______, gt_zero, ge_zero, lt_zero, le_zero, ne_zero, eq_zero },
496 /* le_zero */ { _______, ge_zero, ge_zero, le_zero, le_zero, _______, eq_zero },
497 /* gt_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
498 /* ge_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
499 /* ne_zero */ { _______, ne_zero, _______, ne_zero, _______, ne_zero, eq_zero },
500 /* eq_zero */ { eq_zero, eq_zero, eq_zero, eq_zero, eq_zero, eq_zero, eq_zero }
501 };
502
503 ASSERT_TABLE_IS_COMMUTATIVE(table);
504
505 /* x * x => ge_zero */
506 if (left.range != eq_zero && nir_alu_srcs_equal(alu, alu, 0, 1)) {
507 /* x * x => ge_zero or gt_zero depending on the range of x. */
508 r.range = is_not_zero(left.range) ? gt_zero : ge_zero;
509 } else if (left.range != eq_zero && nir_alu_srcs_negative_equal(alu, alu, 0, 1)) {
510 /* -x * x => le_zero or lt_zero depending on the range of x. */
511 r.range = is_not_zero(left.range) ? lt_zero : le_zero;
512 } else
513 r.range = table[left.range][right.range];
514
515 break;
516 }
517
518 case nir_op_frcp:
519 r = (struct ssa_result_range){analyze_expression(alu, 0, ht).range, false};
520 break;
521
522 case nir_op_mov:
523 r = analyze_expression(alu, 0, ht);
524 break;
525
526 case nir_op_fneg:
527 r = analyze_expression(alu, 0, ht);
528
529 switch (r.range) {
530 case le_zero:
531 r.range = ge_zero;
532 break;
533
534 case ge_zero:
535 r.range = le_zero;
536 break;
537
538 case lt_zero:
539 r.range = gt_zero;
540 break;
541
542 case gt_zero:
543 r.range = lt_zero;
544 break;
545
546 case ne_zero:
547 case eq_zero:
548 case unknown:
549 /* Negation doesn't change anything about these ranges. */
550 break;
551 }
552
553 break;
554
555 case nir_op_fsat:
556 r = (struct ssa_result_range){ge_zero, analyze_expression(alu, 0, ht).is_integral};
557 break;
558
559 case nir_op_fsign:
560 r = (struct ssa_result_range){analyze_expression(alu, 0, ht).range, true};
561 break;
562
563 case nir_op_fsqrt:
564 case nir_op_frsq:
565 r = (struct ssa_result_range){ge_zero, false};
566 break;
567
568 case nir_op_ffloor: {
569 const struct ssa_result_range left = analyze_expression(alu, 0, ht);
570
571 r.is_integral = true;
572
573 if (left.is_integral || left.range == le_zero || left.range == lt_zero)
574 r.range = left.range;
575 else if (left.range == ge_zero || left.range == gt_zero)
576 r.range = ge_zero;
577 else if (left.range == ne_zero)
578 r.range = unknown;
579
580 break;
581 }
582
583 case nir_op_fceil: {
584 const struct ssa_result_range left = analyze_expression(alu, 0, ht);
585
586 r.is_integral = true;
587
588 if (left.is_integral || left.range == ge_zero || left.range == gt_zero)
589 r.range = left.range;
590 else if (left.range == le_zero || left.range == lt_zero)
591 r.range = le_zero;
592 else if (left.range == ne_zero)
593 r.range = unknown;
594
595 break;
596 }
597
598 case nir_op_ftrunc: {
599 const struct ssa_result_range left = analyze_expression(alu, 0, ht);
600
601 r.is_integral = true;
602
603 if (left.is_integral)
604 r.range = left.range;
605 else if (left.range == ge_zero || left.range == gt_zero)
606 r.range = ge_zero;
607 else if (left.range == le_zero || left.range == lt_zero)
608 r.range = le_zero;
609 else if (left.range == ne_zero)
610 r.range = unknown;
611
612 break;
613 }
614
615 case nir_op_flt:
616 case nir_op_fge:
617 case nir_op_feq:
618 case nir_op_fne:
619 case nir_op_ilt:
620 case nir_op_ige:
621 case nir_op_ieq:
622 case nir_op_ine:
623 case nir_op_ult:
624 case nir_op_uge:
625 /* Boolean results are 0 or -1. */
626 r = (struct ssa_result_range){le_zero, false};
627 break;
628
629 default:
630 r = (struct ssa_result_range){unknown, false};
631 break;
632 }
633
634 if (r.range == eq_zero)
635 r.is_integral = true;
636
637 _mesa_hash_table_insert(ht, alu, pack_data(r));
638 return r;
639 }
640
641 #undef _______
642
643 struct ssa_result_range
644 nir_analyze_range(const nir_alu_instr *instr, unsigned src)
645 {
646 struct hash_table *ht = _mesa_pointer_hash_table_create(NULL);
647
648 const struct ssa_result_range r = analyze_expression(instr, src, ht);
649
650 _mesa_hash_table_destroy(ht, NULL);
651
652 return r;
653 }