spirv: add support for doubles to OpSpecConstant
[mesa.git] / src / compiler / spirv / spirv_to_nir.c
index 07980aa2019d185f464195e7ca5f357b18d55314..b67189e07a68ca29ad2e599fb69b92ecd508a766 100644 (file)
 #include "nir/nir_constant_expressions.h"
 #include "spirv_info.h"
 
+struct spec_constant_value {
+   bool is_double;
+   union {
+      uint32_t data32;
+      uint64_t data64;
+   };
+};
+
 void
 _vtn_warn(const char *file, int line, const char *msg, ...)
 {
@@ -98,11 +106,12 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
    case GLSL_TYPE_UINT:
    case GLSL_TYPE_BOOL:
    case GLSL_TYPE_FLOAT:
-   case GLSL_TYPE_DOUBLE:
+   case GLSL_TYPE_DOUBLE: {
+      int bit_size = glsl_get_bit_size(type);
       if (glsl_type_is_vector_or_scalar(type)) {
          unsigned num_components = glsl_get_vector_elements(val->type);
          nir_load_const_instr *load =
-            nir_load_const_instr_create(b->shader, num_components, 32);
+            nir_load_const_instr_create(b->shader, num_components, bit_size);
 
          load->value = constant->values[0];
 
@@ -118,7 +127,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
             struct vtn_ssa_value *col_val = rzalloc(b, struct vtn_ssa_value);
             col_val->type = glsl_get_column_type(val->type);
             nir_load_const_instr *load =
-               nir_load_const_instr_create(b->shader, rows, 32);
+               nir_load_const_instr_create(b->shader, rows, bit_size);
 
             load->value = constant->values[i];
 
@@ -129,6 +138,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
          }
       }
       break;
+   }
 
    case GLSL_TYPE_ARRAY: {
       unsigned elems = glsl_get_length(val->type);
@@ -704,9 +714,11 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
       val->type->type = (signedness ? glsl_int_type() : glsl_uint_type());
       break;
    }
-   case SpvOpTypeFloat:
-      val->type->type = glsl_float_type();
+   case SpvOpTypeFloat: {
+      int bit_size = w[2];
+      val->type->type = bit_size == 64 ? glsl_double_type() : glsl_float_type();
       break;
+   }
 
    case SpvOpTypeVector: {
       struct vtn_type *base = vtn_value(b, w[2], vtn_value_type_type)->type;
@@ -930,7 +942,7 @@ vtn_null_constant(struct vtn_builder *b, const struct glsl_type *type)
 }
 
 static void
-spec_constant_deocoration_cb(struct vtn_builder *b, struct vtn_value *v,
+spec_constant_decoration_cb(struct vtn_builder *b, struct vtn_value *v,
                              int member, const struct vtn_decoration *dec,
                              void *data)
 {
@@ -938,11 +950,14 @@ spec_constant_deocoration_cb(struct vtn_builder *b, struct vtn_value *v,
    if (dec->decoration != SpvDecorationSpecId)
       return;
 
-   uint32_t *const_value = data;
+   struct spec_constant_value *const_value = data;
 
    for (unsigned i = 0; i < b->num_specializations; i++) {
       if (b->specializations[i].id == dec->literals[0]) {
-         *const_value = b->specializations[i].data;
+         if (const_value->is_double)
+            const_value->data64 = b->specializations[i].data64;
+         else
+            const_value->data32 = b->specializations[i].data32;
          return;
       }
    }
@@ -952,8 +967,22 @@ static uint32_t
 get_specialization(struct vtn_builder *b, struct vtn_value *val,
                    uint32_t const_value)
 {
-   vtn_foreach_decoration(b, val, spec_constant_deocoration_cb, &const_value);
-   return const_value;
+   struct spec_constant_value data;
+   data.is_double = false;
+   data.data32 = const_value;
+   vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &data);
+   return data.data32;
+}
+
+static uint64_t
+get_specialization64(struct vtn_builder *b, struct vtn_value *val,
+                   uint64_t const_value)
+{
+   struct spec_constant_value data;
+   data.is_double = true;
+   data.data64 = const_value;
+   vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &data);
+   return data.data64;
 }
 
 static void
@@ -1001,14 +1030,29 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
-   case SpvOpConstant:
+   case SpvOpConstant: {
       assert(glsl_type_is_scalar(val->const_type));
-      val->constant->values[0].u32[0] = w[3];
+      int bit_size = glsl_get_bit_size(val->const_type);
+      if (bit_size == 64) {
+         val->constant->values->u32[0] = w[3];
+         val->constant->values->u32[1] = w[4];
+      } else {
+         assert(bit_size == 32);
+         val->constant->values->u32[0] = w[3];
+      }
       break;
-   case SpvOpSpecConstant:
+   }
+   case SpvOpSpecConstant: {
       assert(glsl_type_is_scalar(val->const_type));
       val->constant->values[0].u32[0] = get_specialization(b, val, w[3]);
+      int bit_size = glsl_get_bit_size(val->const_type);
+      if (bit_size == 64)
+         val->constant->values[0].u64[0] =
+            get_specialization64(b, val, vtn_u64_literal(&w[3]));
+      else
+         val->constant->values[0].u32[0] = get_specialization(b, val, w[3]);
       break;
+   }
    case SpvOpSpecConstantComposite:
    case SpvOpConstantComposite: {
       unsigned elem_count = count - 3;
@@ -1021,6 +1065,8 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       case GLSL_TYPE_INT:
       case GLSL_TYPE_FLOAT:
       case GLSL_TYPE_BOOL:
+      case GLSL_TYPE_DOUBLE: {
+         int bit_size = glsl_get_bit_size(val->const_type);
          if (glsl_type_is_matrix(val->const_type)) {
             assert(glsl_get_matrix_columns(val->const_type) == elem_count);
             for (unsigned i = 0; i < elem_count; i++)
@@ -1028,12 +1074,18 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
          } else {
             assert(glsl_type_is_vector(val->const_type));
             assert(glsl_get_vector_elements(val->const_type) == elem_count);
-            for (unsigned i = 0; i < elem_count; i++)
-               val->constant->values[0].u32[i] = elems[i]->values[0].u32[0];
+            for (unsigned i = 0; i < elem_count; i++) {
+               if (bit_size == 64) {
+                  val->constant->values[0].u64[i] = elems[i]->values[0].u64[0];
+               } else {
+                  assert(bit_size == 32);
+                  val->constant->values[0].u32[i] = elems[i]->values[0].u32[0];
+               }
+            }
          }
          ralloc_free(elems);
          break;
-
+      }
       case GLSL_TYPE_STRUCT:
       case GLSL_TYPE_ARRAY:
          ralloc_steal(val->constant, elems);
@@ -1056,18 +1108,46 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
          unsigned len0 = glsl_get_vector_elements(v0->const_type);
          unsigned len1 = glsl_get_vector_elements(v1->const_type);
 
-         uint32_t u[8];
-         for (unsigned i = 0; i < len0; i++)
-            u[i] = v0->constant->values[0].u32[i];
-         for (unsigned i = 0; i < len1; i++)
-            u[len0 + i] = v1->constant->values[0].u32[i];
-
-         for (unsigned i = 0; i < count - 6; i++) {
-            uint32_t comp = w[i + 6];
-            if (comp == (uint32_t)-1) {
-               val->constant->values[0].u32[i] = 0xdeadbeef;
-            } else {
-               val->constant->values[0].u32[i] = u[comp];
+         assert(len0 + len1 < 16);
+
+         unsigned bit_size = glsl_get_bit_size(val->const_type);
+         assert(bit_size == glsl_get_bit_size(v0->const_type) &&
+                bit_size == glsl_get_bit_size(v1->const_type));
+
+         if (bit_size == 64) {
+            uint64_t u64[8];
+            for (unsigned i = 0; i < len0; i++)
+               u64[i] = v0->constant->values[0].u64[i];
+            for (unsigned i = 0; i < len1; i++)
+               u64[len0 + i] = v1->constant->values[0].u64[i];
+
+            for (unsigned i = 0, j = 0; i < count - 6; i++, j++) {
+               uint32_t comp = w[i + 6];
+               /* If component is not used, set the value to a known constant
+                * to detect if it is wrongly used.
+                */
+               if (comp == (uint32_t)-1)
+                  val->constant->values[0].u64[j] = 0xdeadbeefdeadbeef;
+               else
+                  val->constant->values[0].u64[j] = u64[comp];
+            }
+         } else {
+            uint32_t u32[8];
+            for (unsigned i = 0; i < len0; i++)
+               u32[i] = v0->constant->values[0].u32[i];
+
+            for (unsigned i = 0; i < len1; i++)
+               u32[len0 + i] = v1->constant->values[0].u32[i];
+
+            for (unsigned i = 0, j = 0; i < count - 6; i++, j++) {
+               uint32_t comp = w[i + 6];
+               /* If component is not used, set the value to a known constant
+                * to detect if it is wrongly used.
+                */
+               if (comp == (uint32_t)-1)
+                  val->constant->values[0].u32[j] = 0xdeadbeef;
+               else
+                  val->constant->values[0].u32[j] = u32[comp];
             }
          }
          break;
@@ -1098,6 +1178,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             case GLSL_TYPE_UINT:
             case GLSL_TYPE_INT:
             case GLSL_TYPE_FLOAT:
+            case GLSL_TYPE_DOUBLE:
             case GLSL_TYPE_BOOL:
                /* If we hit this granularity, we're picking off an element */
                if (glsl_type_is_matrix(type)) {
@@ -1132,8 +1213,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
                val->constant = *c;
             } else {
                unsigned num_components = glsl_get_vector_elements(type);
+               unsigned bit_size = glsl_get_bit_size(type);
                for (unsigned i = 0; i < num_components; i++)
-                  val->constant->values[0].u32[i] = (*c)->values[col].u32[elem + i];
+                  if (bit_size == 64) {
+                     val->constant->values[0].u64[i] = (*c)->values[col].u64[elem + i];
+                  } else {
+                     assert(bit_size == 32);
+                     val->constant->values[0].u32[i] = (*c)->values[col].u32[elem + i];
+                  }
             }
          } else {
             struct vtn_value *insert =
@@ -1143,8 +1230,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
                *c = insert->constant;
             } else {
                unsigned num_components = glsl_get_vector_elements(type);
+               unsigned bit_size = glsl_get_bit_size(type);
                for (unsigned i = 0; i < num_components; i++)
-                  (*c)->values[col].u32[elem + i] = insert->constant->values[0].u32[i];
+                  if (bit_size == 64) {
+                     (*c)->values[col].u64[elem + i] = insert->constant->values[0].u64[i];
+                  } else {
+                     assert(bit_size == 32);
+                     (*c)->values[col].u32[elem + i] = insert->constant->values[0].u32[i];
+                  }
             }
          }
          break;
@@ -1152,7 +1245,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
 
       default: {
          bool swap;
-         nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap);
+         nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(val->const_type);
+         nir_alu_type src_alu_type = dst_alu_type;
+         nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
 
          unsigned num_components = glsl_get_vector_elements(val->const_type);
          unsigned bit_size =
@@ -1206,7 +1301,7 @@ vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
       struct vtn_value *arg = vtn_untyped_value(b, arg_id);
       if (arg->value_type == vtn_value_type_access_chain) {
          nir_deref_var *d = vtn_access_chain_to_deref(b, arg->access_chain);
-         call->params[i] = nir_deref_as_var(nir_copy_deref(call, &d->deref));
+         call->params[i] = nir_deref_var_clone(d, call);
       } else {
          struct vtn_ssa_value *arg_ssa = vtn_ssa_value(b, arg_id);
 
@@ -1542,15 +1637,15 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
    }
 
    nir_deref_var *sampler = vtn_access_chain_to_deref(b, sampled.sampler);
-   nir_deref *texture;
+   nir_deref_var *texture;
    if (sampled.image) {
       nir_deref_var *image = vtn_access_chain_to_deref(b, sampled.image);
-      texture = &image->deref;
+      texture = image;
    } else {
-      texture = &sampler->deref;
+      texture = sampler;
    }
 
-   instr->texture = nir_deref_as_var(nir_copy_deref(instr, texture));
+   instr->texture = nir_deref_var_clone(texture, instr);
 
    switch (instr->op) {
    case nir_texop_tex:
@@ -1558,7 +1653,7 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
    case nir_texop_txl:
    case nir_texop_txd:
       /* These operations require a sampler */
-      instr->sampler = nir_deref_as_var(nir_copy_deref(instr, &sampler->deref));
+      instr->sampler = nir_deref_var_clone(sampler, instr);
       break;
    case nir_texop_txf:
    case nir_texop_txf_ms:
@@ -1599,8 +1694,7 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
          instrs[i]->is_new_style_shadow = instr->is_new_style_shadow;
          instrs[i]->component = instr->component;
          instrs[i]->dest_type = instr->dest_type;
-         instrs[i]->texture =
-            nir_deref_as_var(nir_copy_deref(instrs[i], texture));
+         instrs[i]->texture = nir_deref_var_clone(texture, instrs[i]);
          instrs[i]->sampler = NULL;
 
          memcpy(instrs[i]->src, srcs, instr->num_srcs * sizeof(*instr->src));
@@ -1807,8 +1901,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
    nir_intrinsic_instr *intrin = nir_intrinsic_instr_create(b->shader, op);
 
    nir_deref_var *image_deref = vtn_access_chain_to_deref(b, image.image);
-   intrin->variables[0] =
-      nir_deref_as_var(nir_copy_deref(&intrin->instr, &image_deref->deref));
+   intrin->variables[0] = nir_deref_var_clone(image_deref, intrin);
 
    /* ImageQuerySize doesn't take any extra parameters */
    if (opcode != SpvOpImageQuerySize) {
@@ -1967,10 +2060,10 @@ vtn_handle_ssbo_or_shared_atomic(struct vtn_builder *b, SpvOp opcode,
 
    if (chain->var->mode == vtn_variable_mode_workgroup) {
       struct vtn_type *type = chain->var->type;
-      nir_deref *deref = &vtn_access_chain_to_deref(b, chain)->deref;
+      nir_deref_var *deref = vtn_access_chain_to_deref(b, chain);
       nir_intrinsic_op op = get_shared_nir_atomic_op(opcode);
       atomic = nir_intrinsic_instr_create(b->nb.shader, op);
-      atomic->variables[0] = nir_deref_as_var(nir_copy_deref(atomic, deref));
+      atomic->variables[0] = nir_deref_var_clone(deref, atomic);
 
       switch (opcode) {
       case SpvOpAtomicLoad: