nir: Embed the shader_info in the nir_shader again
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index aecda172717ccc3deb5b27fdd318daad28f48b99..0a5eb0eb6b063beb10ed2e632ad3c0501c596c4e 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -104,6 +104,8 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
    switch (glsl_get_base_type(type)) {
    case GLSL_TYPE_INT:
    case GLSL_TYPE_UINT:
+   case GLSL_TYPE_INT64:
+   case GLSL_TYPE_UINT64:
    case GLSL_TYPE_BOOL:
    case GLSL_TYPE_FLOAT:
    case GLSL_TYPE_DOUBLE: {
@@ -420,6 +422,8 @@ vtn_type_copy(struct vtn_builder *b, struct vtn_type *src)
       switch (glsl_get_base_type(src->type)) {
       case GLSL_TYPE_INT:
       case GLSL_TYPE_UINT:
+      case GLSL_TYPE_INT64:
+      case GLSL_TYPE_UINT64:
       case GLSL_TYPE_BOOL:
       case GLSL_TYPE_FLOAT:
       case GLSL_TYPE_DOUBLE:
@@ -561,6 +565,9 @@ struct_member_decoration_cb(struct vtn_builder *b,
       vtn_warn("Decoration only allowed for CL-style kernels: %s",
                spirv_decoration_to_string(dec->decoration));
       break;
+
+   default:
+      unreachable("Unhandled decoration");
    }
 }
 
@@ -609,7 +616,7 @@ type_decoration_cb(struct vtn_builder *b,
    case SpvDecorationOffset:
    case SpvDecorationXfbBuffer:
    case SpvDecorationXfbStride:
-      vtn_warn("Decoraiton only allowed for struct members: %s",
+      vtn_warn("Decoration only allowed for struct members: %s",
                spirv_decoration_to_string(dec->decoration));
       break;
 
@@ -625,7 +632,7 @@ type_decoration_cb(struct vtn_builder *b,
    case SpvDecorationLinkageAttributes:
    case SpvDecorationNoContraction:
    case SpvDecorationInputAttachmentIndex:
-      vtn_warn("Decoraiton not allowed on types: %s",
+      vtn_warn("Decoration not allowed on types: %s",
                spirv_decoration_to_string(dec->decoration));
       break;
 
@@ -635,9 +642,12 @@ type_decoration_cb(struct vtn_builder *b,
    case SpvDecorationFPRoundingMode:
    case SpvDecorationFPFastMathMode:
    case SpvDecorationAlignment:
-      vtn_warn("Decoraiton only allowed for CL-style kernels: %s",
+      vtn_warn("Decoration only allowed for CL-style kernels: %s",
                spirv_decoration_to_string(dec->decoration));
       break;
+
+   default:
+      unreachable("Unhandled decoration");
    }
 }
 
@@ -709,8 +719,12 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
       val->type->type = glsl_bool_type();
       break;
    case SpvOpTypeInt: {
+      int bit_size = w[2];
       const bool signedness = w[3];
-      val->type->type = (signedness ? glsl_int_type() : glsl_uint_type());
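+      /* w[2] is the integer bit width: 64-bit maps to the GLSL (u)int64
+       * types, anything else falls back to the 32-bit (u)int types.
+       */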
+      if (bit_size == 64)
+         val->type->type = (signedness ? glsl_int64_t_type() : glsl_uint64_t_type());
+      else
+         val->type->type = (signedness ? glsl_int_type() : glsl_uint_type());
       break;
    }
    case SpvOpTypeFloat: {
@@ -868,8 +882,6 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
          val->type->type = glsl_sampler_type(dim, is_shadow, is_array,
                                              glsl_get_base_type(sampled_type));
       } else if (sampled == 2) {
-         assert((dim == GLSL_SAMPLER_DIM_SUBPASS ||
-                 dim == GLSL_SAMPLER_DIM_SUBPASS_MS) || format);
          assert(!is_shadow);
          val->type->type = glsl_image_type(dim, is_array,
                                            glsl_get_base_type(sampled_type));
@@ -913,6 +925,8 @@ vtn_null_constant(struct vtn_builder *b, const struct glsl_type *type)
    switch (glsl_get_base_type(type)) {
    case GLSL_TYPE_INT:
    case GLSL_TYPE_UINT:
+   case GLSL_TYPE_INT64:
+   case GLSL_TYPE_UINT64:
    case GLSL_TYPE_BOOL:
    case GLSL_TYPE_FLOAT:
    case GLSL_TYPE_DOUBLE:
@@ -1003,9 +1017,9 @@ handle_workgroup_size_decoration_cb(struct vtn_builder *b,
 
    assert(val->const_type == glsl_vector_type(GLSL_TYPE_UINT, 3));
 
-   b->shader->info->cs.local_size[0] = val->constant->values[0].u32[0];
-   b->shader->info->cs.local_size[1] = val->constant->values[0].u32[1];
-   b->shader->info->cs.local_size[2] = val->constant->values[0].u32[2];
+   b->shader->info.cs.local_size[0] = val->constant->values[0].u32[0];
+   b->shader->info.cs.local_size[1] = val->constant->values[0].u32[1];
+   b->shader->info.cs.local_size[2] = val->constant->values[0].u32[2];
 }
 
 static void
@@ -1067,6 +1081,8 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       switch (glsl_get_base_type(val->const_type)) {
       case GLSL_TYPE_UINT:
       case GLSL_TYPE_INT:
+      case GLSL_TYPE_UINT64:
+      case GLSL_TYPE_INT64:
       case GLSL_TYPE_FLOAT:
       case GLSL_TYPE_BOOL:
       case GLSL_TYPE_DOUBLE: {
@@ -1107,23 +1123,44 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       SpvOp opcode = get_specialization(b, val, w[3]);
       switch (opcode) {
       case SpvOpVectorShuffle: {
-         struct vtn_value *v0 = vtn_value(b, w[4], vtn_value_type_constant);
-         struct vtn_value *v1 = vtn_value(b, w[5], vtn_value_type_constant);
-         unsigned len0 = glsl_get_vector_elements(v0->const_type);
-         unsigned len1 = glsl_get_vector_elements(v1->const_type);
+         struct vtn_value *v0 = &b->values[w[4]];
+         struct vtn_value *v1 = &b->values[w[5]];
+
+         assert(v0->value_type == vtn_value_type_constant ||
+                v0->value_type == vtn_value_type_undef);
+         assert(v1->value_type == vtn_value_type_constant ||
+                v1->value_type == vtn_value_type_undef);
+
+         unsigned len0 = v0->value_type == vtn_value_type_constant ?
+                         glsl_get_vector_elements(v0->const_type) :
+                         glsl_get_vector_elements(v0->type->type);
+         unsigned len1 = v1->value_type == vtn_value_type_constant ?
+                         glsl_get_vector_elements(v1->const_type) :
+                         glsl_get_vector_elements(v1->type->type);
 
          assert(len0 + len1 < 16);
 
          unsigned bit_size = glsl_get_bit_size(val->const_type);
-         assert(bit_size == glsl_get_bit_size(v0->const_type) &&
-                bit_size == glsl_get_bit_size(v1->const_type));
+         unsigned bit_size0 = v0->value_type == vtn_value_type_constant ?
+                              glsl_get_bit_size(v0->const_type) :
+                              glsl_get_bit_size(v0->type->type);
+         unsigned bit_size1 = v1->value_type == vtn_value_type_constant ?
+                              glsl_get_bit_size(v1->const_type) :
+                              glsl_get_bit_size(v1->type->type);
+
+         assert(bit_size == bit_size0 && bit_size == bit_size1);
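+         /* bit_size0/1 are only read by the assert above; the casts keep
+          * release (NDEBUG) builds free of unused-variable warnings.
+          */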
+         (void)bit_size0; (void)bit_size1;
 
          if (bit_size == 64) {
             uint64_t u64[8];
-            for (unsigned i = 0; i < len0; i++)
-               u64[i] = v0->constant->values[0].u64[i];
-            for (unsigned i = 0; i < len1; i++)
-               u64[len0 + i] = v1->constant->values[0].u64[i];
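+            /* Constituents that are OpUndef contribute undefined components;
+             * their slots in u64[] are simply left unwritten.
+             */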
+            if (v0->value_type == vtn_value_type_constant) {
+               for (unsigned i = 0; i < len0; i++)
+                  u64[i] = v0->constant->values[0].u64[i];
+            }
+            if (v1->value_type == vtn_value_type_constant) {
+               for (unsigned i = 0; i < len1; i++)
+                  u64[len0 + i] = v1->constant->values[0].u64[i];
+            }
 
             for (unsigned i = 0, j = 0; i < count - 6; i++, j++) {
                uint32_t comp = w[i + 6];
@@ -1137,11 +1174,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             }
          } else {
             uint32_t u32[8];
-            for (unsigned i = 0; i < len0; i++)
-               u32[i] = v0->constant->values[0].u32[i];
-
-            for (unsigned i = 0; i < len1; i++)
-               u32[len0 + i] = v1->constant->values[0].u32[i];
+            if (v0->value_type == vtn_value_type_constant) {
+               for (unsigned i = 0; i < len0; i++)
+                  u32[i] = v0->constant->values[0].u32[i];
+            }
+            if (v1->value_type == vtn_value_type_constant) {
+               for (unsigned i = 0; i < len1; i++)
+                  u32[len0 + i] = v1->constant->values[0].u32[i];
+            }
 
             for (unsigned i = 0, j = 0; i < count - 6; i++, j++) {
                uint32_t comp = w[i + 6];
@@ -1181,6 +1221,8 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             switch (glsl_get_base_type(type)) {
             case GLSL_TYPE_UINT:
             case GLSL_TYPE_INT:
+            case GLSL_TYPE_UINT64:
+            case GLSL_TYPE_INT64:
             case GLSL_TYPE_FLOAT:
             case GLSL_TYPE_DOUBLE:
             case GLSL_TYPE_BOOL:
@@ -1350,6 +1392,8 @@ vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
          switch (glsl_get_base_type(type)) {
          case GLSL_TYPE_INT:
          case GLSL_TYPE_UINT:
+         case GLSL_TYPE_INT64:
+         case GLSL_TYPE_UINT64:
          case GLSL_TYPE_BOOL:
          case GLSL_TYPE_FLOAT:
          case GLSL_TYPE_DOUBLE:
@@ -1524,7 +1568,8 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
          coord_components++;
 
       coord = vtn_ssa_value(b, w[idx++])->def;
-      p->src = nir_src_for_ssa(coord);
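+      /* nir_tex_src_coord expects exactly coord_components channels, so trim
+       * the (possibly wider) SPIR-V coordinate with nir_channels().
+       */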
+      p->src = nir_src_for_ssa(nir_channels(&b->nb, coord,
+                                            (1 << coord_components) - 1));
       p->src_type = nir_tex_src_coord;
       p++;
       break;
@@ -1953,17 +1998,21 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
    if (opcode != SpvOpImageWrite) {
       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
       struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
-      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 4, 32, NULL);
+
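+      /* Most image intrinsics have a fixed destination size; image_size is
+       * variable-sized, so take its component count from the result type.
+       */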
+      unsigned dest_components =
+         nir_intrinsic_infos[intrin->intrinsic].dest_components;
+      if (intrin->intrinsic == nir_intrinsic_image_size) {
+         dest_components = intrin->num_components =
+            glsl_get_vector_elements(type->type);
+      }
+
+      nir_ssa_dest_init(&intrin->instr, &intrin->dest,
+                        dest_components, 32, NULL);
 
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-      /* The image intrinsics always return 4 channels but we may not want
-       * that many.  Emit a mov to trim it down.
-       */
-      unsigned swiz[4] = {0, 1, 2, 3};
       val->ssa = vtn_create_ssa_value(b, type->type);
-      val->ssa->def = nir_swizzle(&b->nb, &intrin->dest.ssa, swiz,
-                                  glsl_get_vector_elements(type->type), false);
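+      /* The intrinsic destination now has exactly the number of components
+       * the result type wants, so no trimming swizzle is needed.
+       */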
+      val->ssa->def = &intrin->dest.ssa;
    } else {
       nir_builder_instr_insert(&b->nb, &intrin->instr);
    }
@@ -2305,9 +2354,17 @@ vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
    nir_alu_instr *vec = create_vec(b->shader, num_components,
                                    srcs[0]->bit_size);
 
+   /* From the SPIR-V 1.1 spec for OpCompositeConstruct:
+    *
+    *    "When constructing a vector, there must be at least two Constituent
+    *    operands."
+    */
+   assert(num_srcs >= 2);
+
    unsigned dest_idx = 0;
    for (unsigned i = 0; i < num_srcs; i++) {
       nir_ssa_def *src = srcs[i];
+      assert(dest_idx + src->num_components <= num_components);
       for (unsigned j = 0; j < src->num_components; j++) {
          vec->src[dest_idx].src = nir_src_for_ssa(src);
          vec->src[dest_idx].swizzle[0] = j;
@@ -2315,6 +2372,13 @@ vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
       }
    }
 
+   /* From the SPIR-V 1.1 spec for OpCompositeConstruct:
+    *
+    *    "When constructing a vector, the total number of components in all
+    *    the operands must equal the number of components in Result Type."
+    */
+   assert(dest_idx == num_components);
+
    nir_builder_instr_insert(&b->nb, &vec->instr);
 
    return &vec->dest.dest.ssa;
@@ -2611,7 +2675,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityVector16:
       case SpvCapabilityFloat16Buffer:
       case SpvCapabilityFloat16:
-      case SpvCapabilityInt64:
       case SpvCapabilityInt64Atomics:
       case SpvCapabilityAtomicStorage:
       case SpvCapabilityInt16:
@@ -2621,8 +2684,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilitySparseResidency:
       case SpvCapabilityMinLod:
       case SpvCapabilityTransformFeedback:
-      case SpvCapabilityStorageImageReadWithoutFormat:
-      case SpvCapabilityStorageImageWriteWithoutFormat:
          vtn_warn("Unsupported SPIR-V capability: %s",
                   spirv_capability_to_string(cap));
          break;
@@ -2630,6 +2691,9 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityFloat64:
          spv_check_supported(float64, cap);
          break;
+      case SpvCapabilityInt64:
+         spv_check_supported(int64, cap);
+         break;
 
       case SpvCapabilityAddresses:
       case SpvCapabilityKernel:
@@ -2653,6 +2717,25 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityTessellationPointSize:
          spv_check_supported(tessellation, cap);
          break;
+
+      case SpvCapabilityDrawParameters:
+         spv_check_supported(draw_parameters, cap);
+         break;
+
+      case SpvCapabilityStorageImageReadWithoutFormat:
+         spv_check_supported(image_read_without_format, cap);
+         break;
+
+      case SpvCapabilityStorageImageWriteWithoutFormat:
+         spv_check_supported(image_write_without_format, cap);
+         break;
+
+      case SpvCapabilityMultiView:
+         spv_check_supported(multiview, cap);
+         break;
+
+      default:
+         unreachable("Unhandled capability");
       }
       break;
    }
@@ -2725,36 +2808,36 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
 
    case SpvExecutionModeEarlyFragmentTests:
       assert(b->shader->stage == MESA_SHADER_FRAGMENT);
-      b->shader->info->fs.early_fragment_tests = true;
+      b->shader->info.fs.early_fragment_tests = true;
       break;
 
    case SpvExecutionModeInvocations:
       assert(b->shader->stage == MESA_SHADER_GEOMETRY);
-      b->shader->info->gs.invocations = MAX2(1, mode->literals[0]);
+      b->shader->info.gs.invocations = MAX2(1, mode->literals[0]);
       break;
 
    case SpvExecutionModeDepthReplacing:
       assert(b->shader->stage == MESA_SHADER_FRAGMENT);
-      b->shader->info->fs.depth_layout = FRAG_DEPTH_LAYOUT_ANY;
+      b->shader->info.fs.depth_layout = FRAG_DEPTH_LAYOUT_ANY;
       break;
    case SpvExecutionModeDepthGreater:
       assert(b->shader->stage == MESA_SHADER_FRAGMENT);
-      b->shader->info->fs.depth_layout = FRAG_DEPTH_LAYOUT_GREATER;
+      b->shader->info.fs.depth_layout = FRAG_DEPTH_LAYOUT_GREATER;
       break;
    case SpvExecutionModeDepthLess:
       assert(b->shader->stage == MESA_SHADER_FRAGMENT);
-      b->shader->info->fs.depth_layout = FRAG_DEPTH_LAYOUT_LESS;
+      b->shader->info.fs.depth_layout = FRAG_DEPTH_LAYOUT_LESS;
       break;
    case SpvExecutionModeDepthUnchanged:
       assert(b->shader->stage == MESA_SHADER_FRAGMENT);
-      b->shader->info->fs.depth_layout = FRAG_DEPTH_LAYOUT_UNCHANGED;
+      b->shader->info.fs.depth_layout = FRAG_DEPTH_LAYOUT_UNCHANGED;
       break;
 
    case SpvExecutionModeLocalSize:
       assert(b->shader->stage == MESA_SHADER_COMPUTE);
-      b->shader->info->cs.local_size[0] = mode->literals[0];
-      b->shader->info->cs.local_size[1] = mode->literals[1];
-      b->shader->info->cs.local_size[2] = mode->literals[2];
+      b->shader->info.cs.local_size[0] = mode->literals[0];
+      b->shader->info.cs.local_size[1] = mode->literals[1];
+      b->shader->info.cs.local_size[2] = mode->literals[2];
       break;
    case SpvExecutionModeLocalSizeHint:
       break; /* Nothing to do with this */
@@ -2762,10 +2845,10 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
    case SpvExecutionModeOutputVertices:
       if (b->shader->stage == MESA_SHADER_TESS_CTRL ||
           b->shader->stage == MESA_SHADER_TESS_EVAL) {
-         b->shader->info->tess.tcs_vertices_out = mode->literals[0];
+         b->shader->info.tess.tcs_vertices_out = mode->literals[0];
       } else {
          assert(b->shader->stage == MESA_SHADER_GEOMETRY);
-         b->shader->info->gs.vertices_out = mode->literals[0];
+         b->shader->info.gs.vertices_out = mode->literals[0];
       }
       break;
 
@@ -2778,11 +2861,11 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
    case SpvExecutionModeIsolines:
       if (b->shader->stage == MESA_SHADER_TESS_CTRL ||
           b->shader->stage == MESA_SHADER_TESS_EVAL) {
-         b->shader->info->tess.primitive_mode =
+         b->shader->info.tess.primitive_mode =
             gl_primitive_from_spv_execution_mode(mode->exec_mode);
       } else {
          assert(b->shader->stage == MESA_SHADER_GEOMETRY);
-         b->shader->info->gs.vertices_in =
+         b->shader->info.gs.vertices_in =
             vertices_in_from_spv_execution_mode(mode->exec_mode);
       }
       break;
@@ -2791,24 +2874,24 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
    case SpvExecutionModeOutputLineStrip:
    case SpvExecutionModeOutputTriangleStrip:
       assert(b->shader->stage == MESA_SHADER_GEOMETRY);
-      b->shader->info->gs.output_primitive =
+      b->shader->info.gs.output_primitive =
          gl_primitive_from_spv_execution_mode(mode->exec_mode);
       break;
 
    case SpvExecutionModeSpacingEqual:
       assert(b->shader->stage == MESA_SHADER_TESS_CTRL ||
              b->shader->stage == MESA_SHADER_TESS_EVAL);
-      b->shader->info->tess.spacing = TESS_SPACING_EQUAL;
+      b->shader->info.tess.spacing = TESS_SPACING_EQUAL;
       break;
    case SpvExecutionModeSpacingFractionalEven:
       assert(b->shader->stage == MESA_SHADER_TESS_CTRL ||
              b->shader->stage == MESA_SHADER_TESS_EVAL);
-      b->shader->info->tess.spacing = TESS_SPACING_FRACTIONAL_EVEN;
+      b->shader->info.tess.spacing = TESS_SPACING_FRACTIONAL_EVEN;
       break;
    case SpvExecutionModeSpacingFractionalOdd:
       assert(b->shader->stage == MESA_SHADER_TESS_CTRL ||
              b->shader->stage == MESA_SHADER_TESS_EVAL);
-      b->shader->info->tess.spacing = TESS_SPACING_FRACTIONAL_ODD;
+      b->shader->info.tess.spacing = TESS_SPACING_FRACTIONAL_ODD;
       break;
    case SpvExecutionModeVertexOrderCw:
       assert(b->shader->stage == MESA_SHADER_TESS_CTRL ||
@@ -2817,18 +2900,18 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
        * but be the opposite of OpenGL.  Currently NIR follows GL semantics,
        * so we set it backwards here.
        */
-      b->shader->info->tess.ccw = true;
+      b->shader->info.tess.ccw = true;
       break;
    case SpvExecutionModeVertexOrderCcw:
       assert(b->shader->stage == MESA_SHADER_TESS_CTRL ||
              b->shader->stage == MESA_SHADER_TESS_EVAL);
       /* Backwards; see above */
-      b->shader->info->tess.ccw = false;
+      b->shader->info.tess.ccw = false;
       break;
    case SpvExecutionModePointMode:
       assert(b->shader->stage == MESA_SHADER_TESS_CTRL ||
              b->shader->stage == MESA_SHADER_TESS_EVAL);
-      b->shader->info->tess.point_mode = true;
+      b->shader->info.tess.point_mode = true;
       break;
 
    case SpvExecutionModePixelCenterInteger:
@@ -2842,6 +2925,9 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
    case SpvExecutionModeVecTypeHint:
    case SpvExecutionModeContractionOff:
       break; /* OpenCL */
+
+   default:
+      unreachable("Unhandled execution mode");
    }
 }
 
@@ -2907,6 +2993,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_constant(b, opcode, w, count);
       break;
 
+   case SpvOpUndef:
    case SpvOpVariable:
       vtn_handle_variables(b, opcode, w, count);
       break;
@@ -3200,7 +3287,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
    b->shader = nir_shader_create(NULL, stage, options, NULL);
 
    /* Set shader info defaults */
-   b->shader->info->gs.invocations = 1;
+   b->shader->info.gs.invocations = 1;
 
    /* Parse execution modes */
    vtn_foreach_execution_mode(b, b->entry_point,