nir/range-analysis: Use types in the hash key
authorIan Romanick <ian.d.romanick@intel.com>
Tue, 13 Aug 2019 00:28:35 +0000 (17:28 -0700)
committerIan Romanick <ian.d.romanick@intel.com>
Wed, 25 Sep 2019 22:37:01 +0000 (15:37 -0700)
This allows the reslut of mov and bcsel to be separately interpreted as
float or int depending on the use.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
src/compiler/nir/nir_range_analysis.c

index 298d9946b567f28f230d81da163ea222f277a895..e1f3eb14bceeef93887a745aa4901f3958eaa239 100644 (file)
@@ -51,8 +51,41 @@ unpack_data(const void *p)
    return (struct ssa_result_range){v & 0xff, (v & 0x0ff00) != 0};
 }
 
+static void *
+pack_key(const struct nir_alu_instr *instr, nir_alu_type type)
+{
+   uintptr_t type_encoding;
+   uintptr_t ptr = (uintptr_t) instr;
+
+   /* The low 2 bits have to be zero or this whole scheme falls apart. */
+   assert((ptr & 0x3) == 0);
+
+   /* NIR is typeless in the sense that sequences of bits have whatever
+    * meaning is attached to them by the instruction that consumes them.
+    * However, the number of bits must match between producer and consumer.
+    * As a result, the number of bits does not need to be encoded here.
+    */
+   switch (nir_alu_type_get_base_type(type)) {
+   case nir_type_int:   type_encoding = 0; break;
+   case nir_type_uint:  type_encoding = 1; break;
+   case nir_type_bool:  type_encoding = 2; break;
+   case nir_type_float: type_encoding = 3; break;
+   default: unreachable("Invalid base type.");
+   }
+
+   return (void *)(ptr | type_encoding);
+}
+
+static nir_alu_type
+nir_alu_src_type(const nir_alu_instr *instr, unsigned src)
+{
+   return nir_alu_type_get_base_type(nir_op_infos[instr->op].input_types[src]) |
+          nir_src_bit_size(instr->src[src].src);
+}
+
 static struct ssa_result_range
-analyze_constant(const struct nir_alu_instr *instr, unsigned src)
+analyze_constant(const struct nir_alu_instr *instr, unsigned src,
+                 nir_alu_type use_type)
 {
    uint8_t swizzle[4] = { 0, 1, 2, 3 };
 
@@ -69,7 +102,7 @@ analyze_constant(const struct nir_alu_instr *instr, unsigned src)
 
    struct ssa_result_range r = { unknown, false };
 
-   switch (nir_op_infos[instr->op].input_types[src]) {
+   switch (nir_alu_type_get_base_type(use_type)) {
    case nir_type_float: {
       double min_value = DBL_MAX;
       double max_value = -DBL_MAX;
@@ -321,13 +354,13 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b)
  */
 static struct ssa_result_range
 analyze_expression(const nir_alu_instr *instr, unsigned src,
-                   struct hash_table *ht)
+                   struct hash_table *ht, nir_alu_type use_type)
 {
    if (!instr->src[src].src.is_ssa)
       return (struct ssa_result_range){unknown, false};
 
    if (nir_src_is_const(instr->src[src].src))
-      return analyze_constant(instr, src);
+      return analyze_constant(instr, src, use_type);
 
    if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
       return (struct ssa_result_range){unknown, false};
@@ -335,8 +368,6 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    const struct nir_alu_instr *const alu =
        nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
 
-   const nir_alu_type use_type = nir_op_infos[instr->op].input_types[src];
-
    /* Bail if the type of the instruction generating the value does not match
     * the type the value will be interpreted as.  int/uint/bool can be
     * reinterpreted trivially.  The most important cases are between float and
@@ -355,7 +386,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       }
    }
 
-   struct hash_entry *he = _mesa_hash_table_search(ht, alu);
+   struct hash_entry *he = _mesa_hash_table_search(ht, pack_key(alu, use_type));
    if (he != NULL)
       return unpack_data(he->data);
 
@@ -466,8 +497,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_bcsel: {
-      const struct ssa_result_range left = analyze_expression(alu, 1, ht);
-      const struct ssa_result_range right = analyze_expression(alu, 2, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range right =
+         analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
 
       /* If either source is a constant load that is not zero, punt.  The type
        * will always be uint regardless of the actual type.  We can't even
@@ -545,7 +578,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
 
    case nir_op_i2f32:
    case nir_op_u2f32:
-      r = analyze_expression(alu, 0, ht);
+      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       r.is_integral = true;
 
@@ -555,7 +588,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_fabs:
-      r = analyze_expression(alu, 0, ht);
+      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       switch (r.range) {
       case unknown:
@@ -577,8 +610,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_fadd: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
-      const struct ssa_result_range right = analyze_expression(alu, 1, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range right =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
 
       r.is_integral = left.is_integral && right.is_integral;
       r.range = fadd_table[left.range][right.range];
@@ -595,7 +630,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
          ge_zero, ge_zero, ge_zero, gt_zero, gt_zero, ge_zero, gt_zero
       };
 
-      r = analyze_expression(alu, 0, ht);
+      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(table);
       ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(table);
@@ -606,8 +641,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fmax: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
-      const struct ssa_result_range right = analyze_expression(alu, 1, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range right =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
 
       r.is_integral = left.is_integral && right.is_integral;
 
@@ -669,8 +706,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fmin: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
-      const struct ssa_result_range right = analyze_expression(alu, 1, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range right =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
 
       r.is_integral = left.is_integral && right.is_integral;
 
@@ -732,8 +771,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fmul: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
-      const struct ssa_result_range right = analyze_expression(alu, 1, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range right =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
 
       r.is_integral = left.is_integral && right.is_integral;
 
@@ -753,11 +794,15 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_frcp:
-      r = (struct ssa_result_range){analyze_expression(alu, 0, ht).range, false};
+      r = (struct ssa_result_range){
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
+         false
+      };
       break;
 
    case nir_op_mov: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       /* See commentary in nir_op_bcsel for the reasons this is necessary. */
       if (nir_src_is_const(alu->src[0].src) && left.range != eq_zero)
@@ -768,13 +813,13 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fneg:
-      r = analyze_expression(alu, 0, ht);
+      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       r.range = fneg_table[r.range];
       break;
 
    case nir_op_fsat:
-      r = analyze_expression(alu, 0, ht);
+      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       switch (r.range) {
       case le_zero:
@@ -799,7 +844,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_fsign:
-      r = (struct ssa_result_range){analyze_expression(alu, 0, ht).range, true};
+      r = (struct ssa_result_range){
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
+         true
+      };
       break;
 
    case nir_op_fsqrt:
@@ -808,7 +856,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_ffloor: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       r.is_integral = true;
 
@@ -823,7 +872,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fceil: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       r.is_integral = true;
 
@@ -838,7 +888,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_ftrunc: {
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
 
       r.is_integral = true;
 
@@ -919,8 +970,10 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
          /* eq_zero */ { ge_zero, gt_zero, gt_zero, eq_zero, ge_zero, ge_zero, gt_zero },
       };
 
-      const struct ssa_result_range left = analyze_expression(alu, 0, ht);
-      const struct ssa_result_range right = analyze_expression(alu, 1, ht);
+      const struct ssa_result_range left =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range right =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
 
       ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(table);
       ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(table);
@@ -932,9 +985,12 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_ffma: {
-      const struct ssa_result_range first = analyze_expression(alu, 0, ht);
-      const struct ssa_result_range second = analyze_expression(alu, 1, ht);
-      const struct ssa_result_range third = analyze_expression(alu, 2, ht);
+      const struct ssa_result_range first =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range second =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range third =
+         analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
 
       r.is_integral = first.is_integral && second.is_integral &&
                       third.is_integral;
@@ -957,9 +1013,12 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_flrp: {
-      const struct ssa_result_range first = analyze_expression(alu, 0, ht);
-      const struct ssa_result_range second = analyze_expression(alu, 1, ht);
-      const struct ssa_result_range third = analyze_expression(alu, 2, ht);
+      const struct ssa_result_range first =
+         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range second =
+         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range third =
+         analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
 
       r.is_integral = first.is_integral && second.is_integral &&
                       third.is_integral;
@@ -983,7 +1042,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    if (r.range == eq_zero)
       r.is_integral = true;
 
-   _mesa_hash_table_insert(ht, alu, pack_data(r));
+   _mesa_hash_table_insert(ht, pack_key(alu, use_type), pack_data(r));
    return r;
 }
 
@@ -994,7 +1053,8 @@ nir_analyze_range(const nir_alu_instr *instr, unsigned src)
 {
    struct hash_table *ht = _mesa_pointer_hash_table_create(NULL);
 
-   const struct ssa_result_range r = analyze_expression(instr, src, ht);
+   const struct ssa_result_range r =
+      analyze_expression(instr, src, ht, nir_alu_src_type(instr, src));
 
    _mesa_hash_table_destroy(ht, NULL);