swr/rast: increase number of possible draws in flight
[mesa.git] / src / gallium / drivers / swr / rasterizer / core / frontend.cpp
index 99d047da42032121d2b388aa15ac0b959b6ea4fa..8796878c5863d12f75db5833cc2bb4b0cf5e062d 100644 (file)
@@ -507,7 +507,7 @@ static void StreamOut(
     uint32_t soVertsPerPrim = NumVertsPerPrim(pa.binTopology, false);
 
     // The pPrimData buffer is sparse in that we allocate memory for all 32 attributes for each vertex.
-    uint32_t primDataDwordVertexStride = (KNOB_NUM_ATTRIBUTES * sizeof(float) * 4) / sizeof(uint32_t);
+    uint32_t primDataDwordVertexStride = (SWR_VTX_NUM_SLOTS * sizeof(float) * 4) / sizeof(uint32_t);
 
     SWR_STREAMOUT_CONTEXT soContext = { 0 };
 
@@ -518,6 +518,7 @@ static void StreamOut(
     }
 
     uint32_t numPrims = pa.NumPrims();
+
     for (uint32_t primIndex = 0; primIndex < numPrims; ++primIndex)
     {
         DWORD slot = 0;
@@ -526,8 +527,8 @@ static void StreamOut(
         // Write all entries into primitive data buffer for SOS.
         while (_BitScanForward(&slot, soMask))
         {
-            __m128 attrib[MAX_NUM_VERTS_PER_PRIM];    // prim attribs (always 4 wide)
-            uint32_t paSlot = slot + VERTEX_ATTRIB_START_SLOT;
+            simd4scalar attrib[MAX_NUM_VERTS_PER_PRIM];    // prim attribs (always 4 wide)
+            uint32_t paSlot = slot + soState.vertexAttribOffset[streamIndex];
             pa.AssembleSingle(paSlot, primIndex, attrib);
 
             // Attribute offset is relative offset from start of vertex.
@@ -543,6 +544,7 @@ static void StreamOut(
 
                 _mm_store_ps((float*)pPrimDataAttrib, attrib[v]);
             }
+
             soMask &= ~(1 << slot);
         }
 
@@ -604,17 +606,19 @@ INLINE static T RoundDownEven(T value)
 }
 
 //////////////////////////////////////////////////////////////////////////
-/// Pack pairs of simdvertexes into simd16vertexes, in-place
+/// Pack pairs of simdvertexes into simd16vertexes, assume non-overlapping
 ///
 /// vertexCount is in terms of the source simdvertexes and must be even
 ///
 /// attribCount will limit the vector copies to those attribs specified
 ///
-void PackPairsOfSimdVertexIntoSimd16VertexInPlace(simdvertex *vertex, uint32_t vertexCount, uint32_t attribCount)
+/// note: the stride between vertexes is determinded by SWR_VTX_NUM_SLOTS
+///
+void PackPairsOfSimdVertexIntoSimd16Vertex(simd16vertex *vertex_simd16, const simdvertex *vertex, uint32_t vertexCount, uint32_t attribCount)
 {
     SWR_ASSERT(vertex);
-    SWR_ASSERT(IsEven(vertexCount));
-    SWR_ASSERT(attribCount <= KNOB_NUM_ATTRIBUTES);
+    SWR_ASSERT(vertex_simd16);
+    SWR_ASSERT(attribCount <= SWR_VTX_NUM_SLOTS);
 
     simd16vertex temp;
 
@@ -624,14 +628,18 @@ void PackPairsOfSimdVertexIntoSimd16VertexInPlace(simdvertex *vertex, uint32_t v
         {
             for (uint32_t k = 0; k < 4; k += 1)
             {
-                temp.attrib[j][k] = _simd16_insert_ps(_simd16_setzero_ps(),  vertex[i].attrib[j][k], 0);
-                temp.attrib[j][k] = _simd16_insert_ps(temp.attrib[j][k], vertex[i + 1].attrib[j][k], 1);
+                temp.attrib[j][k] = _simd16_insert_ps(_simd16_setzero_ps(), vertex[i].attrib[j][k], 0);
+
+                if ((i + 1) < vertexCount)
+                {
+                    temp.attrib[j][k] = _simd16_insert_ps(temp.attrib[j][k], vertex[i + 1].attrib[j][k], 1);
+                }
             }
         }
 
         for (uint32_t j = 0; j < attribCount; j += 1)
         {
-            reinterpret_cast<simd16vertex *>(vertex)[i >> 1].attrib[j] = temp.attrib[j];
+            vertex_simd16[i >> 1].attrib[j] = temp.attrib[j];
         }
     }
 }
@@ -695,7 +703,7 @@ void ProcessStreamIdBuffer(uint32_t stream, uint8_t* pStreamIdBase, uint32_t num
             }
             curInputByte >>= 2;
         }
-        
+
         *pCutBuffer++ = outByte;
     }
 }
@@ -707,12 +715,7 @@ struct GsBufferInfo
 {
     GsBufferInfo(const SWR_GS_STATE &gsState)
     {
-#if USE_SIMD16_FRONTEND
-        // TEMPORARY: pad up to multiple of two, to support in-place conversion from simdvertex to simd16vertex
-        const uint32_t vertexCount = RoundUpEven(gsState.maxNumVerts);
-#else
         const uint32_t vertexCount = gsState.maxNumVerts;
-#endif
         const uint32_t vertexStride = sizeof(SIMDVERTEX);
         const uint32_t numSimdBatches = (vertexCount + SIMD_WIDTH - 1) / SIMD_WIDTH;
 
@@ -784,20 +787,20 @@ static void GeometryShaderStage(
     tlsGsContext.PrimitiveID = primID;
 
     uint32_t numVertsPerPrim = NumVertsPerPrim(pa.binTopology, true);
-    simdvector attrib[MAX_ATTRIBUTES];
+    simdvector attrib[MAX_NUM_VERTS_PER_PRIM];
 
     // assemble all attributes for the input primitive
     for (uint32_t slot = 0; slot < pState->numInputAttribs; ++slot)
     {
-        uint32_t attribSlot = VERTEX_ATTRIB_START_SLOT + slot;
+        uint32_t attribSlot = pState->vertexAttribOffset + slot;
         pa.Assemble(attribSlot, attrib);
 
         for (uint32_t i = 0; i < numVertsPerPrim; ++i)
         {
-            tlsGsContext.vert[i].attrib[attribSlot] = attrib[i];
+            tlsGsContext.vert[i].attrib[VERTEX_ATTRIB_START_SLOT + slot] = attrib[i];
         }
     }
-    
+
     // assemble position
     pa.Assemble(VERTEX_POSITION_SLOT, attrib);
     for (uint32_t i = 0; i < numVertsPerPrim; ++i)
@@ -805,7 +808,11 @@ static void GeometryShaderStage(
         tlsGsContext.vert[i].attrib[VERTEX_POSITION_SLOT] = attrib[i];
     }
 
+#if USE_SIMD16_FRONTEND
+    const GsBufferInfo<simd16vertex, KNOB_SIMD16_WIDTH> bufferInfo(state.gsState);
+#else
     const GsBufferInfo<simdvertex, KNOB_SIMD_WIDTH> bufferInfo(state.gsState);
+#endif
 
     // record valid prims from the frontend to avoid over binning the newly generated
     // prims from the GS
@@ -828,6 +835,20 @@ static void GeometryShaderStage(
     }
 
     // set up new binner and state for the GS output topology
+#if USE_SIMD16_FRONTEND
+    PFN_PROCESS_PRIMS_SIMD16 pfnClipFunc = nullptr;
+    if (HasRastT::value)
+    {
+        switch (pState->outputTopology)
+        {
+        case TOP_TRIANGLE_STRIP:    pfnClipFunc = ClipTriangles_simd16; break;
+        case TOP_LINE_STRIP:        pfnClipFunc = ClipLines_simd16; break;
+        case TOP_POINT_LIST:        pfnClipFunc = ClipPoints_simd16; break;
+        default: SWR_INVALID("Unexpected GS output topology: %d", pState->outputTopology);
+        }
+    }
+
+#else
     PFN_PROCESS_PRIMS pfnClipFunc = nullptr;
     if (HasRastT::value)
     {
@@ -840,6 +861,7 @@ static void GeometryShaderStage(
         }
     }
 
+#endif
     // foreach input prim:
     // - setup a new PA based on the emitted verts for that prim
     // - loop over the new verts, calling PA to assemble each prim
@@ -862,7 +884,7 @@ static void GeometryShaderStage(
 
             uint8_t* pBase = pInstanceBase + instance * bufferInfo.vertexInstanceStride;
             uint8_t* pCutBase = pCutBufferBase + instance * bufferInfo.cutInstanceStride;
-            
+
             uint32_t numAttribs = state.feNumAttributes;
 
             for (uint32_t stream = 0; stream < MAX_SO_STREAMS; ++stream)
@@ -894,21 +916,10 @@ static void GeometryShaderStage(
                 }
 
 #if USE_SIMD16_FRONTEND
-                // TEMPORARY: GS outputs simdvertex, PA inputs simd16vertex, so convert simdvertex to simd16vertex, in-place
-
-                const uint32_t attribCount = VERTEX_ATTRIB_START_SLOT + pState->numInputAttribs;
-
-                PackPairsOfSimdVertexIntoSimd16VertexInPlace(
-                    reinterpret_cast<simdvertex *>(pBase),
-                    RoundUpEven(numEmittedVerts),                               // simd8 -> simd16
-                    attribCount);
-
-#endif
-#if USE_SIMD16_FRONTEND
-                PA_STATE_CUT gsPa(pDC, pBase, numEmittedVerts, reinterpret_cast<simd16mask *>(pCutBuffer), numEmittedVerts, numAttribs, pState->outputTopology, processCutVerts);
+                PA_STATE_CUT gsPa(pDC, pBase, numEmittedVerts, SWR_VTX_NUM_SLOTS, reinterpret_cast<simd16mask *>(pCutBuffer), numEmittedVerts, numAttribs, pState->outputTopology, processCutVerts);
 
 #else
-                PA_STATE_CUT gsPa(pDC, pBase, numEmittedVerts, pCutBuffer, numEmittedVerts, numAttribs, pState->outputTopology, processCutVerts);
+                PA_STATE_CUT gsPa(pDC, pBase, numEmittedVerts, SWR_VTX_NUM_SLOTS, pCutBuffer, numEmittedVerts, numAttribs, pState->outputTopology, processCutVerts);
 
 #endif
                 while (gsPa.GetNextStreamOutput())
@@ -930,111 +941,22 @@ static void GeometryShaderStage(
 
                             if (HasStreamOutT::value)
                             {
+#if ENABLE_AVX512_SIMD16
+                                gsPa.useAlternateOffset = false;
+#endif
                                 StreamOut(pDC, gsPa, workerId, pSoPrimData, stream);
                             }
 
                             if (HasRastT::value && state.soState.streamToRasterizer == stream)
                             {
 #if USE_SIMD16_FRONTEND
-                                simd16scalari vPrimId;
-                                // pull primitiveID from the GS output if available
-                                if (state.gsState.emitsPrimitiveID)
-                                {
-                                    simd16vector primIdAttrib[3];
-                                    gsPa.Assemble_simd16(VERTEX_PRIMID_SLOT, primIdAttrib);
-                                    vPrimId = _simd16_castps_si(primIdAttrib[state.frontendState.topologyProvokingVertex].x);
-                                }
-                                else
-                                {
-                                    vPrimId = _simd16_set1_epi32(pPrimitiveId[inputPrim]);
-                                }
-
-                                // use viewport array index if GS declares it as an output attribute. Otherwise use index 0.
-                                simd16scalari vViewPortIdx;
-                                if (state.gsState.emitsViewportArrayIndex)
-                                {
-                                    simd16vector vpiAttrib[3];
-                                    gsPa.Assemble_simd16(VERTEX_VIEWPORT_ARRAY_INDEX_SLOT, vpiAttrib);
-
-                                    // OOB indices => forced to zero.
-                                    simd16scalari vNumViewports = _simd16_set1_epi32(KNOB_NUM_VIEWPORTS_SCISSORS);
-                                    simd16scalari vClearMask = _simd16_cmplt_epi32(_simd16_castps_si(vpiAttrib[0].x), vNumViewports);
-                                    vpiAttrib[0].x = _simd16_and_ps(_simd16_castsi_ps(vClearMask), vpiAttrib[0].x);
-
-                                    vViewPortIdx = _simd16_castps_si(vpiAttrib[0].x);
-                                }
-                                else
-                                {
-                                    vViewPortIdx = _simd16_set1_epi32(0);
-                                }
-
-                                const uint32_t primMask = GenMask(gsPa.NumPrims());
-                                const uint32_t primMask_lo = primMask & 255;
-                                const uint32_t primMask_hi = (primMask >> 8) & 255;
-
-                                const simd16scalari primID = vPrimId;
-                                const simdscalari primID_lo = _simd16_extract_si(primID, 0);
-                                const simdscalari primID_hi = _simd16_extract_si(primID, 1);
-
-                                for (uint32_t i = 0; i < 3; i += 1)
-                                {
-                                    for (uint32_t j = 0; j < 4; j += 1)
-                                    {
-                                        attrib[i][j] = _simd16_extract_ps(attrib_simd16[i][j], 0);
-                                    }
-                                }
+                                simd16scalari vPrimId = _simd16_set1_epi32(pPrimitiveId[inputPrim]);
 
                                 gsPa.useAlternateOffset = false;
-                                pfnClipFunc(pDC, gsPa, workerId, attrib, primMask_lo, primID_lo, _simd16_extract_si(vViewPortIdx, 0));
-
-                                if (primMask_hi)
-                                {
-                                    for (uint32_t i = 0; i < 3; i += 1)
-                                    {
-                                        for (uint32_t j = 0; j < 4; j += 1)
-                                        {
-                                            attrib[i][j] = _simd16_extract_ps(attrib_simd16[i][j], 1);
-                                        }
-                                    }
-
-                                    gsPa.useAlternateOffset = true;
-                                    pfnClipFunc(pDC, gsPa, workerId, attrib, primMask_hi, primID_hi, _simd16_extract_si(vViewPortIdx, 1));
-                                }
-
+                                pfnClipFunc(pDC, gsPa, workerId, attrib_simd16, GenMask(gsPa.NumPrims()), vPrimId);
 #else
-                                simdscalari vPrimId;
-                                // pull primitiveID from the GS output if available
-                                if (state.gsState.emitsPrimitiveID)
-                                {
-                                    simdvector primIdAttrib[3];
-                                    gsPa.Assemble(VERTEX_PRIMID_SLOT, primIdAttrib);
-                                    vPrimId = _simd_castps_si(primIdAttrib[state.frontendState.topologyProvokingVertex].x);
-                                }
-                                else
-                                {
-                                    vPrimId = _simd_set1_epi32(pPrimitiveId[inputPrim]);
-                                }
-
-                                // use viewport array index if GS declares it as an output attribute. Otherwise use index 0.
-                                simdscalari vViewPortIdx;
-                                if (state.gsState.emitsViewportArrayIndex)
-                                {
-                                    simdvector vpiAttrib[3];
-                                    gsPa.Assemble(VERTEX_VIEWPORT_ARRAY_INDEX_SLOT, vpiAttrib);
-
-                                    // OOB indices => forced to zero.
-                                    simdscalari vNumViewports = _simd_set1_epi32(KNOB_NUM_VIEWPORTS_SCISSORS);
-                                    simdscalari vClearMask = _simd_cmplt_epi32(_simd_castps_si(vpiAttrib[0].x), vNumViewports);
-                                    vpiAttrib[0].x = _simd_and_ps(_simd_castsi_ps(vClearMask), vpiAttrib[0].x);
-
-                                    vViewPortIdx = _simd_castps_si(vpiAttrib[0].x);
-                                }
-                                else
-                                {
-                                    vViewPortIdx = _simd_set1_epi32(0);
-                                }
-
-                                pfnClipFunc(pDC, gsPa, workerId, attrib, GenMask(gsPa.NumPrims()), vPrimId, vViewPortIdx);
+                                simdscalari vPrimId = _simd_set1_epi32(pPrimitiveId[inputPrim]);
+                                pfnClipFunc(pDC, gsPa, workerId, attrib, GenMask(gsPa.NumPrims()), vPrimId);
 #endif
                             }
                         }
@@ -1173,6 +1095,20 @@ static void TessellationStages(
     }
     SWR_ASSERT(tsCtx);
 
+#if USE_SIMD16_FRONTEND
+    PFN_PROCESS_PRIMS_SIMD16 pfnClipFunc = nullptr;
+    if (HasRastT::value)
+    {
+        switch (tsState.postDSTopology)
+        {
+        case TOP_TRIANGLE_LIST: pfnClipFunc = ClipTriangles_simd16; break;
+        case TOP_LINE_LIST:     pfnClipFunc = ClipLines_simd16; break;
+        case TOP_POINT_LIST:    pfnClipFunc = ClipPoints_simd16; break;
+        default: SWR_INVALID("Unexpected DS output topology: %d", tsState.postDSTopology);
+        }
+    }
+
+#else
     PFN_PROCESS_PRIMS pfnClipFunc = nullptr;
     if (HasRastT::value)
     {
@@ -1185,6 +1121,7 @@ static void TessellationStages(
         }
     }
 
+#endif
     SWR_HS_CONTEXT& hsContext = gt_pTessellationThreadData->hsContext;
     hsContext.pCPout = gt_pTessellationThreadData->patchData;
     hsContext.PrimitiveID = primID;
@@ -1196,12 +1133,12 @@ static void TessellationStages(
     // assemble all attributes for the input primitives
     for (uint32_t slot = 0; slot < tsState.numHsInputAttribs; ++slot)
     {
-        uint32_t attribSlot = VERTEX_ATTRIB_START_SLOT + slot;
+        uint32_t attribSlot = tsState.vertexAttribOffset + slot;
         pa.Assemble(attribSlot, simdattrib);
 
         for (uint32_t i = 0; i < numVertsPerPrim; ++i)
         {
-            hsContext.vert[i].attrib[attribSlot] = simdattrib[i];
+            hsContext.vert[i].attrib[VERTEX_ATTRIB_START_SLOT + slot] = simdattrib[i];
         }
     }
 
@@ -1244,7 +1181,7 @@ static void TessellationStages(
         uint32_t requiredDSVectorInvocations = AlignUp(tsData.NumDomainPoints, KNOB_SIMD_WIDTH) / KNOB_SIMD_WIDTH;
         size_t requiredDSOutputVectors = requiredDSVectorInvocations * tsState.numDsOutputAttribs;
 #if USE_SIMD16_FRONTEND
-        size_t requiredAllocSize = sizeof(simdvector) * RoundUpEven(requiredDSOutputVectors);       // simd8 -> simd16, padding
+        size_t requiredAllocSize = sizeof(simdvector) * RoundUpEven(requiredDSVectorInvocations) * tsState.numDsOutputAttribs;      // simd8 -> simd16, padding
 #else
         size_t requiredAllocSize = sizeof(simdvector) * requiredDSOutputVectors;
 #endif
@@ -1253,7 +1190,7 @@ static void TessellationStages(
             AlignedFree(gt_pTessellationThreadData->pDSOutput);
             gt_pTessellationThreadData->pDSOutput = (simdscalar*)AlignedMalloc(requiredAllocSize, 64);
 #if USE_SIMD16_FRONTEND
-            gt_pTessellationThreadData->numDSOutputVectors = RoundUpEven(requiredDSOutputVectors);  // simd8 -> simd16, padding
+            gt_pTessellationThreadData->numDSOutputVectors = RoundUpEven(requiredDSVectorInvocations) * tsState.numDsOutputAttribs; // simd8 -> simd16, padding
 #else
             gt_pTessellationThreadData->numDSOutputVectors = requiredDSOutputVectors;
 #endif
@@ -1272,7 +1209,11 @@ static void TessellationStages(
         dsContext.pDomainU = (simdscalar*)tsData.pDomainPointsU;
         dsContext.pDomainV = (simdscalar*)tsData.pDomainPointsV;
         dsContext.pOutputData = gt_pTessellationThreadData->pDSOutput;
+#if USE_SIMD16_FRONTEND
+        dsContext.vectorStride = RoundUpEven(requiredDSVectorInvocations);      // simd8 -> simd16
+#else
         dsContext.vectorStride = requiredDSVectorInvocations;
+#endif
 
         uint32_t dsInvocations = 0;
 
@@ -1289,23 +1230,19 @@ static void TessellationStages(
         UPDATE_STAT_FE(DsInvocations, tsData.NumDomainPoints);
 
 #if USE_SIMD16_FRONTEND
-        // TEMPORARY: DS outputs simdvertex, PA inputs simd16vertex, so convert simdvertex to simd16vertex, in-place
-
-        PackPairsOfSimdVertexIntoSimd16VertexInPlace(
-            reinterpret_cast<simdvertex *>(dsContext.pOutputData),
-            RoundUpEven(dsContext.vectorStride),                                // simd8 -> simd16
-            tsState.numDsOutputAttribs);
+        SWR_ASSERT(IsEven(dsContext.vectorStride));                             // simd8 -> simd16
 
 #endif
         PA_TESS tessPa(
             pDC,
 #if USE_SIMD16_FRONTEND
             reinterpret_cast<const simd16scalar *>(dsContext.pOutputData),      // simd8 -> simd16
-            RoundUpEven(dsContext.vectorStride) / 2,                            // simd8 -> simd16
+            dsContext.vectorStride / 2,                                         // simd8 -> simd16
 #else
             dsContext.pOutputData,
             dsContext.vectorStride,
 #endif
+            SWR_VTX_NUM_SLOTS,
             tsState.numDsOutputAttribs,
             tsData.ppIndices,
             tsData.NumPrimitives,
@@ -1318,10 +1255,6 @@ static void TessellationStages(
             const uint32_t numPrims_lo = std::min<uint32_t>(numPrims, KNOB_SIMD_WIDTH);
             const uint32_t numPrims_hi = std::max<uint32_t>(numPrims, KNOB_SIMD_WIDTH) - KNOB_SIMD_WIDTH;
 
-            const uint32_t primMask = GenMask(numPrims);
-            const uint32_t primMask_lo = primMask & 255;
-            const uint32_t primMask_hi = (primMask >> 8) & 255;
-
             const simd16scalari primID = _simd16_set1_epi32(dsContext.PrimitiveID);
             const simdscalari primID_lo = _simd16_extract_si(primID, 0);
             const simdscalari primID_hi = _simd16_extract_si(primID, 1);
@@ -1348,14 +1281,18 @@ static void TessellationStages(
             {
                 if (HasStreamOutT::value)
                 {
+#if ENABLE_AVX512_SIMD16
+                    tessPa.useAlternateOffset = false;
+#endif
                     StreamOut(pDC, tessPa, workerId, pSoPrimData, 0);
                 }
 
                 if (HasRastT::value)
                 {
-                    simdvector      prim[3]; // Only deal with triangles, lines, or points
 #if USE_SIMD16_FRONTEND
-                    simd16vector    prim_simd16[3];
+                    simd16vector    prim_simd16[3]; // Only deal with triangles, lines, or points
+#else
+                    simdvector      prim[3];        // Only deal with triangles, lines, or points
 #endif
                     AR_BEGIN(FEPAAssemble, pDC->drawId);
                     bool assemble =
@@ -1369,33 +1306,11 @@ static void TessellationStages(
 
                     SWR_ASSERT(pfnClipFunc);
 #if USE_SIMD16_FRONTEND
-                    for (uint32_t i = 0; i < 3; i += 1)
-                    {
-                        for (uint32_t j = 0; j < 4; j += 1)
-                        {
-                            prim[i][j] = _simd16_extract_ps(prim_simd16[i][j], 0);
-                        }
-                    }
-
                     tessPa.useAlternateOffset = false;
-                    pfnClipFunc(pDC, tessPa, workerId, prim, primMask_lo, primID_lo, _simd_set1_epi32(0));
-
-                    if (primMask_hi)
-                    {
-                        for (uint32_t i = 0; i < 3; i += 1)
-                        {
-                            for (uint32_t j = 0; j < 4; j += 1)
-                            {
-                                prim[i][j] = _simd16_extract_ps(prim_simd16[i][j], 1);
-                            }
-                        }
-
-                        tessPa.useAlternateOffset = true;
-                        pfnClipFunc(pDC, tessPa, workerId, prim, primMask_hi, primID_hi, _simd_set1_epi32(0));
-                    }
+                    pfnClipFunc(pDC, tessPa, workerId, prim_simd16, GenMask(numPrims), primID);
 #else
                     pfnClipFunc(pDC, tessPa, workerId, prim,
-                        GenMask(tessPa.NumPrims()), _simd_set1_epi32(dsContext.PrimitiveID), _simd_set1_epi32(0));
+                        GenMask(tessPa.NumPrims()), _simd_set1_epi32(dsContext.PrimitiveID));
 #endif
                 }
             }
@@ -1405,9 +1320,21 @@ static void TessellationStages(
         } // while (tessPa.HasWork())
     } // for (uint32_t p = 0; p < numPrims; ++p)
 
+#if USE_SIMD16_FRONTEND
+    if (gt_pTessellationThreadData->pDSOutput != nullptr)
+    {
+        AlignedFree(gt_pTessellationThreadData->pDSOutput);
+        gt_pTessellationThreadData->pDSOutput = nullptr;
+    }
+    gt_pTessellationThreadData->numDSOutputVectors = 0;
+
+#endif
     TSDestroyCtx(tsCtx);
 }
 
+THREAD PA_STATE::SIMDVERTEX *pVertexStore = nullptr;
+THREAD uint32_t gVertexStoreSize = 0;
+
 //////////////////////////////////////////////////////////////////////////
 /// @brief FE handler for SwrDraw.
 /// @tparam IsIndexedT - Is indexed drawing enabled
@@ -1486,7 +1413,11 @@ void ProcessDraw(
     void* pStreamCutBuffer = nullptr;
     if (HasGeometryShaderT::value)
     {
+#if USE_SIMD16_FRONTEND
+        AllocateGsBuffers<simd16vertex, KNOB_SIMD16_WIDTH>(pDC, state, &pGsOut, &pCutBuffer, &pStreamCutBuffer);
+#else
         AllocateGsBuffers<simdvertex, KNOB_SIMD_WIDTH>(pDC, state, &pGsOut, &pCutBuffer, &pStreamCutBuffer);
+#endif
     }
 
     if (HasTessellationT::value)
@@ -1511,8 +1442,36 @@ void ProcessDraw(
         pSoPrimData = (uint32_t*)pDC->pArena->AllocAligned(4096, 16);
     }
 
+    const uint32_t vertexCount = NumVertsPerPrim(state.topology, true);
+#if USE_SIMD16_FRONTEND
+    uint32_t simdVertexSizeBytes = state.frontendState.vsVertexSize * sizeof(simd16vector);
+#else
+    uint32_t simdVertexSizeBytes = state.frontendState.vsVertexSize * sizeof(simdvector);
+#endif
+
+    SWR_ASSERT(vertexCount <= MAX_NUM_VERTS_PER_PRIM);
+
+    // Compute storage requirements for vertex store
+    // TODO: allocation needs to be rethought for better cut support
+    uint32_t numVerts = vertexCount + 2; // Need extra space for PA state machine
+    uint32_t vertexStoreSize = numVerts * simdVertexSizeBytes;
+
+    // grow the vertex store for the PA as necessary
+    if (gVertexStoreSize < vertexStoreSize)
+    {
+        if (pVertexStore != nullptr)
+        {
+            AlignedFree(pVertexStore);
+        }
+
+        pVertexStore = reinterpret_cast<PA_STATE::SIMDVERTEX *>(AlignedMalloc(vertexStoreSize, 64));
+        gVertexStoreSize = vertexStoreSize;
+
+        SWR_ASSERT(pVertexStore != nullptr);
+    }
+
     // choose primitive assembler
-    PA_FACTORY<IsIndexedT, IsCutIndexEnabledT> paFactory(pDC, state.topology, work.numVerts);
+    PA_FACTORY<IsIndexedT, IsCutIndexEnabledT> paFactory(pDC, state.topology, work.numVerts, pVertexStore, numVerts, state.frontendState.vsVertexSize);
     PA_STATE& pa = paFactory.GetPA();
 
 #if USE_SIMD16_FRONTEND
@@ -1523,6 +1482,8 @@ void ProcessDraw(
 
     vsContext_lo.pVin = &vin_lo;
     vsContext_hi.pVin = &vin_hi;
+    vsContext_lo.AlternateOffset = 0;
+    vsContext_hi.AlternateOffset = 1;
 
     SWR_FETCH_CONTEXT   fetchInfo_lo = { 0 };
 
@@ -1592,20 +1553,18 @@ void ProcessDraw(
                 pvCutIndices_hi = &reinterpret_cast<simdmask *>(&pa.GetNextVsIndices())[1];
             }
 
-            simdvertex vout_lo;
-            simdvertex vout_hi;
-
-            vsContext_lo.pVout = &vout_lo;
-            vsContext_hi.pVout = &vout_hi;
-
             simd16vertex &vout = pa.GetNextVsOutput();
 
+            vsContext_lo.pVout = reinterpret_cast<simdvertex *>(&vout);
+            vsContext_hi.pVout = reinterpret_cast<simdvertex *>(&vout);
+
             if (i < endVertex)
             {
                 // 1. Execute FS/VS for a single SIMD.
                 AR_BEGIN(FEFetchShader, pDC->drawId);
                 state.pfnFetchFunc(fetchInfo_lo, vin_lo);
-                if ((i + KNOB_SIMD_WIDTH) < endVertex)
+
+                if ((i + KNOB_SIMD_WIDTH) < endVertex)  // 1/2 of KNOB_SIMD16_WIDTH
                 {
                     state.pfnFetchFunc(fetchInfo_hi, vin_hi);
                 }
@@ -1635,35 +1594,9 @@ void ProcessDraw(
                     AR_BEGIN(FEVertexShader, pDC->drawId);
                     state.pfnVertexFunc(GetPrivateState(pDC), &vsContext_lo);
 
-                    // copy SIMD vout_lo to lo part of SIMD16 vout
-                    {
-                        const uint32_t voutNumSlots = VERTEX_ATTRIB_START_SLOT + state.feNumAttributes;
-
-                        for (uint32_t i = 0; i < voutNumSlots; i += 1)
-                        {
-                            for (uint32_t j = 0; j < 4; j += 1)
-                            {
-                                vout.attrib[i][j] = _simd16_insert_ps(_simd16_setzero_ps(), vout_lo.attrib[i][j], 0);
-                            }
-                        }
-                    }
-
-                    if ((i + KNOB_SIMD_WIDTH) < endVertex)
+                    if ((i + KNOB_SIMD_WIDTH) < endVertex)  // 1/2 of KNOB_SIMD16_WIDTH
                     {
                         state.pfnVertexFunc(GetPrivateState(pDC), &vsContext_hi);
-
-                        // copy SIMD vout_hi to hi part of SIMD16 vout
-                        {
-                            const uint32_t voutNumSlots = VERTEX_ATTRIB_START_SLOT + state.feNumAttributes;
-
-                            for (uint32_t i = 0; i < voutNumSlots; i += 1)
-                            {
-                                for (uint32_t j = 0; j < 4; j += 1)
-                                {
-                                    vout.attrib[i][j] = _simd16_insert_ps(vout.attrib[i][j], vout_hi.attrib[i][j], 1);
-                                }
-                            }
-                        }
                     }
                     AR_END(FEVertexShader, 0);
 
@@ -1696,10 +1629,6 @@ void ProcessDraw(
                             const uint32_t numPrims_lo = std::min<uint32_t>(numPrims, KNOB_SIMD_WIDTH);
                             const uint32_t numPrims_hi = std::max<uint32_t>(numPrims, KNOB_SIMD_WIDTH) - KNOB_SIMD_WIDTH;
 
-                            const uint32_t primMask = GenMask(numPrims);
-                            const uint32_t primMask_lo = primMask & 255;
-                            const uint32_t primMask_hi = (primMask >> 8) & 255;
-
                             const simd16scalari primID = pa.GetPrimID(work.startPrimID);
                             const simdscalari primID_lo = _simd16_extract_si(primID, 0);
                             const simdscalari primID_hi = _simd16_extract_si(primID, 1);
@@ -1731,40 +1660,16 @@ void ProcessDraw(
                                 // If streamout is enabled then stream vertices out to memory.
                                 if (HasStreamOutT::value)
                                 {
-                                    pa.useAlternateOffset = false;  // StreamOut() is SIMD16-compatible..
+                                    pa.useAlternateOffset = false;
                                     StreamOut(pDC, pa, workerId, pSoPrimData, 0);
                                 }
 
                                 if (HasRastT::value)
                                 {
-                                    SWR_ASSERT(pDC->pState->pfnProcessPrims);
-
-                                    simdvector prim[MAX_NUM_VERTS_PER_PRIM];
-
-                                    for (uint32_t i = 0; i < 3; i += 1)
-                                    {
-                                        for (uint32_t j = 0; j < 4; j += 1)
-                                        {
-                                            prim[i][j] = _simd16_extract_ps(prim_simd16[i][j], 0);
-                                        }
-                                    }
+                                    SWR_ASSERT(pDC->pState->pfnProcessPrims_simd16);
 
                                     pa.useAlternateOffset = false;
-                                    pDC->pState->pfnProcessPrims(pDC, pa, workerId, prim, primMask_lo, primID_lo, _simd_setzero_si());
-
-                                    if (primMask_hi)
-                                    {
-                                        for (uint32_t i = 0; i < 3; i += 1)
-                                        {
-                                            for (uint32_t j = 0; j < 4; j += 1)
-                                            {
-                                                prim[i][j] = _simd16_extract_ps(prim_simd16[i][j], 1);
-                                            }
-                                        }
-
-                                        pa.useAlternateOffset = true;
-                                        pDC->pState->pfnProcessPrims(pDC, pa, workerId, prim, primMask_hi, primID_hi, _simd_setzero_si());
-                                    }
+                                    pDC->pState->pfnProcessPrims_simd16(pDC, pa, workerId, prim_simd16, GenMask(numPrims), primID);
                                 }
                             }
                         }
@@ -1789,11 +1694,7 @@ void ProcessDraw(
     }
 
 #else
-    simdvertex          vin;
     SWR_VS_CONTEXT      vsContext;
-
-    vsContext.pVin = &vin;
-
     SWR_FETCH_CONTEXT   fetchInfo = { 0 };
 
     fetchInfo.pStreams = &state.vertexBuffers[0];
@@ -1849,6 +1750,7 @@ void ProcessDraw(
             }
 
             simdvertex& vout = pa.GetNextVsOutput();
+            vsContext.pVin = &vout;
             vsContext.pVout = &vout;
 
             if (i < endVertex)
@@ -1856,7 +1758,7 @@ void ProcessDraw(
 
                 // 1. Execute FS/VS for a single SIMD.
                 AR_BEGIN(FEFetchShader, pDC->drawId);
-                state.pfnFetchFunc(fetchInfo, vin);
+                state.pfnFetchFunc(fetchInfo, vout);
                 AR_END(FEFetchShader, 0);
 
                 // forward fetch generated vertex IDs to the vertex shader
@@ -1929,7 +1831,7 @@ void ProcessDraw(
                                     SWR_ASSERT(pDC->pState->pfnProcessPrims);
 
                                     pDC->pState->pfnProcessPrims(pDC, pa, workerId, prim,
-                                        GenMask(pa.NumPrims()), pa.GetPrimID(work.startPrimID), _simd_set1_epi32(0));
+                                        GenMask(pa.NumPrims()), pa.GetPrimID(work.startPrimID));
                                 }
                             }
                         }