nir: Rewrite nir_type_conversion_op
[mesa.git] / src / compiler / nir / nir.c
index a9fac96d1e4e5c85a3e51e7a4a8d2914c5ad9fa8..37fd9cb5c56bf2b560ad092fa1411da351ab63f4 100644 (file)
@@ -1967,87 +1967,116 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type dst)
    unsigned src_bitsize = nir_alu_type_get_type_size(src);
    unsigned dst_bitsize = nir_alu_type_get_type_size(dst);
 
-   if (src_base_type == dst_base_type) {
-      if (src_bitsize == dst_bitsize)
-         return (src_base_type == nir_type_float) ? nir_op_fmov : nir_op_imov;
-
-      assert(src_bitsize == 64 || dst_bitsize == 64);
-      if (src_base_type == nir_type_float)
-         /* TODO: implement support for float16 */
-         return (src_bitsize == 64) ? nir_op_d2f : nir_op_f2d;
-      else if (src_base_type == nir_type_uint)
-         return (src_bitsize == 64) ? nir_op_imov : nir_op_u2u64;
-      else if (src_base_type == nir_type_int)
-         return (src_bitsize == 64) ? nir_op_imov : nir_op_i2i64;
-      unreachable("Invalid conversion");
-   }
-
-   /* Different base type but same bit_size */
    if (src_bitsize == dst_bitsize) {
-      /* TODO: This does not include specific conversions between
-       * signed or unsigned integer types of bit size different than 32 yet.
-       */
-      assert(src_bitsize == 32);
       switch (src_base_type) {
-      case nir_type_uint:
-         return (dst_base_type == nir_type_float) ? nir_op_u2f : nir_op_imov;
       case nir_type_int:
-         return (dst_base_type == nir_type_float) ? nir_op_i2f : nir_op_imov;
-      case nir_type_bool:
-         return (dst_base_type == nir_type_float) ? nir_op_b2f : nir_op_b2i;
+      case nir_type_uint:
+         if (dst_base_type == nir_type_uint || dst_base_type == nir_type_int)
+            return nir_op_imov;
+         break;
       case nir_type_float:
-         switch (dst_base_type) {
-         case nir_type_uint:
-            return nir_op_f2u;
-         case nir_type_bool:
-            return nir_op_f2b;
-         default:
-            return nir_op_f2i;
-         };
+         if (dst_base_type == nir_type_float)
+            return nir_op_fmov;
+         break;
+      case nir_type_bool:
+         if (dst_base_type == nir_type_bool)
+            return nir_op_imov;
+         break;
       default:
          unreachable("Invalid conversion");
-      };
+      }
    }
 
-   /* Different bit_size and different base type */
-   /* TODO: Implement integer support for types with bit_size != 32 */
    switch (src_base_type) {
-   case nir_type_uint:
-      if (dst == nir_type_float64)
-         return nir_op_u2d;
-      else if (dst == nir_type_int64)
-         return nir_op_u2i64;
-      break;
    case nir_type_int:
-      if (dst == nir_type_float64)
-         return nir_op_i2d;
-      else if (dst == nir_type_uint64)
-         return nir_op_i2i64;
-      break;
-   case nir_type_bool:
-      assert(dst == nir_type_float64);
-      return nir_op_u2d;
-   case nir_type_float:
-      assert(src_bitsize == 32 || src_bitsize == 64);
-      if (src_bitsize != 64) {
-         assert(dst == nir_type_float64);
-         return nir_op_f2d;
+      switch (dst_base_type) {
+      case nir_type_int:
+         assert(src_bitsize != dst_bitsize);
+         return (dst_bitsize == 32) ? nir_op_i2i32 : nir_op_i2i64;
+      case nir_type_uint:
+         assert(src_bitsize != dst_bitsize);
+         return (dst_bitsize == 32) ? nir_op_i2u32 : nir_op_i2u64;
+      case nir_type_float:
+         switch (src_bitsize) {
+         case 32:
+            return (dst_bitsize == 32) ? nir_op_i2f : nir_op_i2d;
+         case 64:
+            return (dst_bitsize == 32) ? nir_op_i642f : nir_op_i642d;
+         default:
+            unreachable("Invalid conversion");
+         }
+      case nir_type_bool:
+         return (src_bitsize == 32) ? nir_op_i2b : nir_op_i642b;
+      default:
+         unreachable("Invalid conversion");
       }
-      assert(dst_bitsize == 32);
+
+   case nir_type_uint:
       switch (dst_base_type) {
+      case nir_type_int:
+         assert(src_bitsize != dst_bitsize);
+         return (dst_bitsize == 32) ? nir_op_u2i32 : nir_op_u2i64;
       case nir_type_uint:
-         return nir_op_d2u;
+         assert(src_bitsize != dst_bitsize);
+         return (dst_bitsize == 32) ? nir_op_u2u32 : nir_op_u2u64;
+      case nir_type_float:
+         switch (src_bitsize) {
+         case 32:
+            return (dst_bitsize == 32) ? nir_op_u2f : nir_op_u2d;
+         case 64:
+            return (dst_bitsize == 32) ? nir_op_u642f : nir_op_u642d;
+         default:
+            unreachable("Invalid conversion");
+         }
+      case nir_type_bool:
+         return (src_bitsize == 32) ? nir_op_i2b : nir_op_i642b;
+      default:
+         unreachable("Invalid conversion");
+      }
+
+   case nir_type_float:
+      switch (dst_base_type) {
       case nir_type_int:
-         return nir_op_d2i;
+         switch (src_bitsize) {
+         case 32:
+            return (dst_bitsize == 32) ? nir_op_f2i : nir_op_f2i64;
+         case 64:
+            return (dst_bitsize == 32) ? nir_op_d2i : nir_op_f2i64;
+         default:
+            unreachable("Invalid conversion");
+         }
+      case nir_type_uint:
+         switch (src_bitsize) {
+         case 32:
+            return (dst_bitsize == 32) ? nir_op_f2u : nir_op_f2u64;
+         case 64:
+            return (dst_bitsize == 32) ? nir_op_d2u : nir_op_f2u64;
+         default:
+            unreachable("Invalid conversion");
+         }
+      case nir_type_float:
+         assert(src_bitsize != dst_bitsize);
+         return (dst_bitsize == 32) ? nir_op_d2f : nir_op_f2d;
       case nir_type_bool:
-         return nir_op_d2b;
+         return (src_bitsize == 32) ? nir_op_f2b : nir_op_d2b;
+      default:
+         unreachable("Invalid conversion");
+      }
+
+   case nir_type_bool:
+      switch (dst_base_type) {
+      case nir_type_int:
+      case nir_type_uint:
+         return (dst_bitsize == 32) ? nir_op_b2i : nir_op_b2i64;
       case nir_type_float:
-         return nir_op_d2f;
+         /* GLSL just emits f2d(b2f(x)) for b2d */
+         assert(dst_bitsize == 32);
+         return nir_op_b2f;
       default:
          unreachable("Invalid conversion");
-      };
+      }
+
    default:
       unreachable("Invalid conversion");
-   };
-   unreachable("Invalid conversion");
+   }
 }