zink: start using per-stage flags for new shaders, refcount shader modules
[mesa.git] / src / gallium / drivers / zink / zink_pipeline.c
index 261bdde5201a8141cfeba11edd52fa36a2c3d5af..637e880eb3ce6b5afa0821ee8f0826373ca50821 100644 (file)
@@ -49,7 +49,20 @@ zink_create_gfx_pipeline(struct zink_screen *screen,
    VkPipelineInputAssemblyStateCreateInfo primitive_state = {};
    primitive_state.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO;
    primitive_state.topology = primitive_topology;
-   primitive_state.primitiveRestartEnable = VK_FALSE;
+   switch (primitive_topology) {
+   case VK_PRIMITIVE_TOPOLOGY_POINT_LIST:
+   case VK_PRIMITIVE_TOPOLOGY_LINE_LIST:
+   case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST:
+   case VK_PRIMITIVE_TOPOLOGY_LINE_LIST_WITH_ADJACENCY:
+   case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST_WITH_ADJACENCY:
+   case VK_PRIMITIVE_TOPOLOGY_PATCH_LIST:
+      if (state->primitive_restart)
+         debug_printf("restart_index set with unsupported primitive topology %u\n", primitive_topology);
+      primitive_state.primitiveRestartEnable = VK_FALSE;
+      break;
+   default:
+      primitive_state.primitiveRestartEnable = state->primitive_restart ? VK_TRUE : VK_FALSE;
+   }
 
    VkPipelineColorBlendStateCreateInfo blend_state = {};
    blend_state.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO;
@@ -127,16 +140,16 @@ zink_create_gfx_pipeline(struct zink_screen *screen,
    pci.pDepthStencilState = &depth_stencil_state;
    pci.pDynamicState = &pipelineDynamicStateCreateInfo;
 
-   VkPipelineShaderStageCreateInfo shader_stages[PIPE_SHADER_TYPES - 1];
+   VkPipelineShaderStageCreateInfo shader_stages[ZINK_SHADER_COUNT];
    uint32_t num_stages = 0;
-   for (int i = 0; i < PIPE_SHADER_TYPES - 1; ++i) {
+   for (int i = 0; i < ZINK_SHADER_COUNT; ++i) {
       if (!prog->stages[i])
          continue;
 
       VkPipelineShaderStageCreateInfo stage = {};
       stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
       stage.stage = zink_shader_stage(i);
-      stage.module = prog->stages[i]->shader_module;
+      stage.module = prog->stages[i]->shader;
       stage.pName = "main";
       shader_stages[num_stages++] = stage;
    }