swr/rast: increase number of possible draws in flight
[mesa.git] / src / gallium / drivers / swr / rasterizer / core / frontend.cpp
index c8dce10c9dea05cd5c27ee02ec5c13a30474573c..8796878c5863d12f75db5833cc2bb4b0cf5e062d 100644 (file)
@@ -302,7 +302,7 @@ uint32_t GetNumPrims(
     case TOP_TRI_STRIP_REVERSE:
     case TOP_PATCHLIST_BASE:
     case TOP_UNKNOWN:
-        SWR_ASSERT(false, "Unsupported topology: %d", mode);
+        SWR_INVALID("Unsupported topology: %d", mode);
         return 0;
     }
 
@@ -378,7 +378,7 @@ uint32_t GetNumVerts(
     case TOP_TRI_STRIP_REVERSE:
     case TOP_PATCHLIST_BASE:
     case TOP_UNKNOWN:
-        SWR_ASSERT(false, "Unsupported topology: %d", mode);
+        SWR_INVALID("Unsupported topology: %d", mode);
         return 0;
     }
 
@@ -455,7 +455,7 @@ INLINE uint32_t NumVertsPerPrim(PRIMITIVE_TOPOLOGY topology, bool includeAdjVert
         numVerts = topology - TOP_PATCHLIST_BASE;
         break;
     default:
-        SWR_ASSERT(false, "Unsupported topology: %d", topology);
+        SWR_INVALID("Unsupported topology: %d", topology);
         break;
     }
 
@@ -507,7 +507,7 @@ static void StreamOut(
     uint32_t soVertsPerPrim = NumVertsPerPrim(pa.binTopology, false);
 
     // The pPrimData buffer is sparse in that we allocate memory for all 32 attributes for each vertex.
-    uint32_t primDataDwordVertexStride = (KNOB_NUM_ATTRIBUTES * sizeof(float) * 4) / sizeof(uint32_t);
+    uint32_t primDataDwordVertexStride = (SWR_VTX_NUM_SLOTS * sizeof(float) * 4) / sizeof(uint32_t);
 
     SWR_STREAMOUT_CONTEXT soContext = { 0 };
 
@@ -518,6 +518,7 @@ static void StreamOut(
     }
 
     uint32_t numPrims = pa.NumPrims();
+
     for (uint32_t primIndex = 0; primIndex < numPrims; ++primIndex)
     {
         DWORD slot = 0;
@@ -526,8 +527,8 @@ static void StreamOut(
         // Write all entries into primitive data buffer for SOS.
         while (_BitScanForward(&slot, soMask))
         {
-            __m128 attrib[MAX_NUM_VERTS_PER_PRIM];    // prim attribs (always 4 wide)
-            uint32_t paSlot = slot + VERTEX_ATTRIB_START_SLOT;
+            simd4scalar attrib[MAX_NUM_VERTS_PER_PRIM];    // prim attribs (always 4 wide)
+            uint32_t paSlot = slot + soState.vertexAttribOffset[streamIndex];
             pa.AssembleSingle(paSlot, primIndex, attrib);
 
             // Attribute offset is relative offset from start of vertex.
@@ -543,6 +544,7 @@ static void StreamOut(
 
                 _mm_store_ps((float*)pPrimDataAttrib, attrib[v]);
             }
+
             soMask &= ~(1 << slot);
         }
 
@@ -575,6 +577,74 @@ static void StreamOut(
     AR_END(FEStreamout, 1);
 }
 
+#if USE_SIMD16_FRONTEND
+//////////////////////////////////////////////////////////////////////////
+/// Is value an even number (a multiple of two)
+///
+template <typename T>
+INLINE static bool IsEven(T value)
+{
+    return (value & 1) == 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+/// Round up value to an even number (a multiple of two)
+///
+template <typename T>
+INLINE static T RoundUpEven(T value)
+{
+    return (value + 1) & ~1;
+}
+
+//////////////////////////////////////////////////////////////////////////
+/// Round down value to an even number (a multiple of two)
+///
+template <typename T>
+INLINE static T RoundDownEven(T value)
+{
+    return value & ~1;
+}
+
+//////////////////////////////////////////////////////////////////////////
+/// Pack pairs of simdvertexes into simd16vertexes, assume non-overlapping
+///
+/// vertexCount is in terms of the source simdvertexes and must be even
+///
+/// attribCount will limit the vector copies to those attribs specified
+///
+/// note: the stride between vertexes is determinded by SWR_VTX_NUM_SLOTS
+///
+void PackPairsOfSimdVertexIntoSimd16Vertex(simd16vertex *vertex_simd16, const simdvertex *vertex, uint32_t vertexCount, uint32_t attribCount)
+{
+    SWR_ASSERT(vertex);
+    SWR_ASSERT(vertex_simd16);
+    SWR_ASSERT(attribCount <= SWR_VTX_NUM_SLOTS);
+
+    simd16vertex temp;
+
+    for (uint32_t i = 0; i < vertexCount; i += 2)
+    {
+        for (uint32_t j = 0; j < attribCount; j += 1)
+        {
+            for (uint32_t k = 0; k < 4; k += 1)
+            {
+                temp.attrib[j][k] = _simd16_insert_ps(_simd16_setzero_ps(), vertex[i].attrib[j][k], 0);
+
+                if ((i + 1) < vertexCount)
+                {
+                    temp.attrib[j][k] = _simd16_insert_ps(temp.attrib[j][k], vertex[i + 1].attrib[j][k], 1);
+                }
+            }
+        }
+
+        for (uint32_t j = 0; j < attribCount; j += 1)
+        {
+            vertex_simd16[i >> 1].attrib[j] = temp.attrib[j];
+        }
+    }
+}
+
+#endif
 //////////////////////////////////////////////////////////////////////////
 /// @brief Computes number of invocations. The current index represents
 ///        the start of the SIMD. The max index represents how much work
@@ -587,7 +657,11 @@ static INLINE uint32_t GetNumInvocations(
     uint32_t maxIndex)
 {
     uint32_t remainder = (maxIndex - curIndex);
+#if USE_SIMD16_FRONTEND
+    return (remainder >= KNOB_SIMD16_WIDTH) ? KNOB_SIMD16_WIDTH : remainder;
+#else
     return (remainder >= KNOB_SIMD_WIDTH) ? KNOB_SIMD_WIDTH : remainder;
+#endif
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -629,13 +703,53 @@ void ProcessStreamIdBuffer(uint32_t stream, uint8_t* pStreamIdBase, uint32_t num
             }
             curInputByte >>= 2;
         }
-        
+
         *pCutBuffer++ = outByte;
     }
 }
 
 THREAD SWR_GS_CONTEXT tlsGsContext;
 
+template<typename SIMDVERTEX, uint32_t SIMD_WIDTH>
+struct GsBufferInfo
+{
+    GsBufferInfo(const SWR_GS_STATE &gsState)
+    {
+        const uint32_t vertexCount = gsState.maxNumVerts;
+        const uint32_t vertexStride = sizeof(SIMDVERTEX);
+        const uint32_t numSimdBatches = (vertexCount + SIMD_WIDTH - 1) / SIMD_WIDTH;
+
+        vertexPrimitiveStride = vertexStride * numSimdBatches;
+        vertexInstanceStride = vertexPrimitiveStride * SIMD_WIDTH;
+
+        if (gsState.isSingleStream)
+        {
+            cutPrimitiveStride = (vertexCount + 7) / 8;
+            cutInstanceStride = cutPrimitiveStride * SIMD_WIDTH;
+
+            streamCutPrimitiveStride = 0;
+            streamCutInstanceStride = 0;
+        }
+        else
+        {
+            cutPrimitiveStride = AlignUp(vertexCount * 2 / 8, 4);
+            cutInstanceStride = cutPrimitiveStride * SIMD_WIDTH;
+
+            streamCutPrimitiveStride = (vertexCount + 7) / 8;
+            streamCutInstanceStride = streamCutPrimitiveStride * SIMD_WIDTH;
+        }
+    }
+
+    uint32_t vertexPrimitiveStride;
+    uint32_t vertexInstanceStride;
+
+    uint32_t cutPrimitiveStride;
+    uint32_t cutInstanceStride;
+
+    uint32_t streamCutPrimitiveStride;
+    uint32_t streamCutInstanceStride;
+};
+
 //////////////////////////////////////////////////////////////////////////
 /// @brief Implements GS stage.
 /// @param pDC - pointer to draw context.
@@ -653,6 +767,9 @@ static void GeometryShaderStage(
     void* pCutBuffer,
     void* pStreamCutBuffer,
     uint32_t* pSoPrimData,
+#if USE_SIMD16_FRONTEND
+    uint32_t numPrims_simd8,
+#endif
     simdscalari primID)
 {
     SWR_CONTEXT *pContext = pDC->pContext;
@@ -670,20 +787,20 @@ static void GeometryShaderStage(
     tlsGsContext.PrimitiveID = primID;
 
     uint32_t numVertsPerPrim = NumVertsPerPrim(pa.binTopology, true);
-    simdvector attrib[MAX_ATTRIBUTES];
+    simdvector attrib[MAX_NUM_VERTS_PER_PRIM];
 
     // assemble all attributes for the input primitive
     for (uint32_t slot = 0; slot < pState->numInputAttribs; ++slot)
     {
-        uint32_t attribSlot = VERTEX_ATTRIB_START_SLOT + slot;
+        uint32_t attribSlot = pState->vertexAttribOffset + slot;
         pa.Assemble(attribSlot, attrib);
 
         for (uint32_t i = 0; i < numVertsPerPrim; ++i)
         {
-            tlsGsContext.vert[i].attrib[attribSlot] = attrib[i];
+            tlsGsContext.vert[i].attrib[VERTEX_ATTRIB_START_SLOT + slot] = attrib[i];
         }
     }
-    
+
     // assemble position
     pa.Assemble(VERTEX_POSITION_SLOT, attrib);
     for (uint32_t i = 0; i < numVertsPerPrim; ++i)
@@ -691,27 +808,19 @@ static void GeometryShaderStage(
         tlsGsContext.vert[i].attrib[VERTEX_POSITION_SLOT] = attrib[i];
     }
 
-    const uint32_t vertexStride = sizeof(simdvertex);
-    const uint32_t numSimdBatches = (state.gsState.maxNumVerts + KNOB_SIMD_WIDTH - 1) / KNOB_SIMD_WIDTH;
-    const uint32_t inputPrimStride = numSimdBatches * vertexStride;
-    const uint32_t instanceStride = inputPrimStride * KNOB_SIMD_WIDTH;
-    uint32_t cutPrimStride;
-    uint32_t cutInstanceStride;
-
-    if (pState->isSingleStream)
-    {
-        cutPrimStride = (state.gsState.maxNumVerts + 7) / 8;
-        cutInstanceStride = cutPrimStride * KNOB_SIMD_WIDTH;
-    }
-    else
-    {
-        cutPrimStride = AlignUp(state.gsState.maxNumVerts * 2 / 8, 4);
-        cutInstanceStride = cutPrimStride * KNOB_SIMD_WIDTH;
-    }
+#if USE_SIMD16_FRONTEND
+    const GsBufferInfo<simd16vertex, KNOB_SIMD16_WIDTH> bufferInfo(state.gsState);
+#else
+    const GsBufferInfo<simdvertex, KNOB_SIMD_WIDTH> bufferInfo(state.gsState);
+#endif
 
     // record valid prims from the frontend to avoid over binning the newly generated
     // prims from the GS
+#if USE_SIMD16_FRONTEND
+    uint32_t numInputPrims = numPrims_simd8;
+#else
     uint32_t numInputPrims = pa.NumPrims();
+#endif
 
     for (uint32_t instance = 0; instance < pState->instanceCount; ++instance)
     {
@@ -721,11 +830,25 @@ static void GeometryShaderStage(
         // execute the geometry shader
         state.pfnGsFunc(GetPrivateState(pDC), &tlsGsContext);
 
-        tlsGsContext.pStream += instanceStride;
-        tlsGsContext.pCutOrStreamIdBuffer += cutInstanceStride;
+        tlsGsContext.pStream += bufferInfo.vertexInstanceStride;
+        tlsGsContext.pCutOrStreamIdBuffer += bufferInfo.cutInstanceStride;
     }
 
     // set up new binner and state for the GS output topology
+#if USE_SIMD16_FRONTEND
+    PFN_PROCESS_PRIMS_SIMD16 pfnClipFunc = nullptr;
+    if (HasRastT::value)
+    {
+        switch (pState->outputTopology)
+        {
+        case TOP_TRIANGLE_STRIP:    pfnClipFunc = ClipTriangles_simd16; break;
+        case TOP_LINE_STRIP:        pfnClipFunc = ClipLines_simd16; break;
+        case TOP_POINT_LIST:        pfnClipFunc = ClipPoints_simd16; break;
+        default: SWR_INVALID("Unexpected GS output topology: %d", pState->outputTopology);
+        }
+    }
+
+#else
     PFN_PROCESS_PRIMS pfnClipFunc = nullptr;
     if (HasRastT::value)
     {
@@ -734,10 +857,11 @@ static void GeometryShaderStage(
         case TOP_TRIANGLE_STRIP:    pfnClipFunc = ClipTriangles; break;
         case TOP_LINE_STRIP:        pfnClipFunc = ClipLines; break;
         case TOP_POINT_LIST:        pfnClipFunc = ClipPoints; break;
-        default: SWR_ASSERT(false, "Unexpected GS output topology: %d", pState->outputTopology);
+        default: SWR_INVALID("Unexpected GS output topology: %d", pState->outputTopology);
         }
     }
 
+#endif
     // foreach input prim:
     // - setup a new PA based on the emitted verts for that prim
     // - loop over the new verts, calling PA to assemble each prim
@@ -747,8 +871,9 @@ static void GeometryShaderStage(
     uint32_t totalPrimsGenerated = 0;
     for (uint32_t inputPrim = 0; inputPrim < numInputPrims; ++inputPrim)
     {
-        uint8_t* pInstanceBase = (uint8_t*)pGsOut + inputPrim * inputPrimStride;
-        uint8_t* pCutBufferBase = (uint8_t*)pCutBuffer + inputPrim * cutPrimStride;
+        uint8_t* pInstanceBase = (uint8_t*)pGsOut + inputPrim * bufferInfo.vertexPrimitiveStride;
+        uint8_t* pCutBufferBase = (uint8_t*)pCutBuffer + inputPrim * bufferInfo.cutPrimitiveStride;
+
         for (uint32_t instance = 0; instance < pState->instanceCount; ++instance)
         {
             uint32_t numEmittedVerts = pVertexCount[inputPrim];
@@ -757,9 +882,9 @@ static void GeometryShaderStage(
                 continue;
             }
 
-            uint8_t* pBase = pInstanceBase + instance * instanceStride;
-            uint8_t* pCutBase = pCutBufferBase + instance * cutInstanceStride;
-            
+            uint8_t* pBase = pInstanceBase + instance * bufferInfo.vertexInstanceStride;
+            uint8_t* pCutBase = pCutBufferBase + instance * bufferInfo.cutInstanceStride;
+
             uint32_t numAttribs = state.feNumAttributes;
 
             for (uint32_t stream = 0; stream < MAX_SO_STREAMS; ++stream)
@@ -790,58 +915,49 @@ static void GeometryShaderStage(
                     processCutVerts = false;
                 }
 
-                PA_STATE_CUT gsPa(pDC, pBase, numEmittedVerts, pCutBuffer, numEmittedVerts, numAttribs, pState->outputTopology, processCutVerts);
+#if USE_SIMD16_FRONTEND
+                PA_STATE_CUT gsPa(pDC, pBase, numEmittedVerts, SWR_VTX_NUM_SLOTS, reinterpret_cast<simd16mask *>(pCutBuffer), numEmittedVerts, numAttribs, pState->outputTopology, processCutVerts);
+
+#else
+                PA_STATE_CUT gsPa(pDC, pBase, numEmittedVerts, SWR_VTX_NUM_SLOTS, pCutBuffer, numEmittedVerts, numAttribs, pState->outputTopology, processCutVerts);
 
+#endif
                 while (gsPa.GetNextStreamOutput())
                 {
                     do
                     {
+#if USE_SIMD16_FRONTEND
+                        simd16vector attrib_simd16[3];
+
+                        bool assemble = gsPa.Assemble_simd16(VERTEX_POSITION_SLOT, attrib_simd16);
+
+#else
                         bool assemble = gsPa.Assemble(VERTEX_POSITION_SLOT, attrib);
 
+#endif
                         if (assemble)
                         {
                             totalPrimsGenerated += gsPa.NumPrims();
 
                             if (HasStreamOutT::value)
                             {
+#if ENABLE_AVX512_SIMD16
+                                gsPa.useAlternateOffset = false;
+#endif
                                 StreamOut(pDC, gsPa, workerId, pSoPrimData, stream);
                             }
 
                             if (HasRastT::value && state.soState.streamToRasterizer == stream)
                             {
-                                simdscalari vPrimId;
-                                // pull primitiveID from the GS output if available
-                                if (state.gsState.emitsPrimitiveID)
-                                {
-                                    simdvector primIdAttrib[3];
-                                    gsPa.Assemble(VERTEX_PRIMID_SLOT, primIdAttrib);
-                                    vPrimId = _simd_castps_si(primIdAttrib[0].x);
-                                }
-                                else
-                                {
-                                    vPrimId = _simd_set1_epi32(pPrimitiveId[inputPrim]);
-                                }
-
-                                // use viewport array index if GS declares it as an output attribute. Otherwise use index 0.
-                                simdscalari vViewPortIdx;
-                                if (state.gsState.emitsViewportArrayIndex)
-                                {
-                                    simdvector vpiAttrib[3];
-                                    gsPa.Assemble(VERTEX_VIEWPORT_ARRAY_INDEX_SLOT, vpiAttrib);
-
-                                    // OOB indices => forced to zero.
-                                    simdscalari vNumViewports = _simd_set1_epi32(KNOB_NUM_VIEWPORTS_SCISSORS);
-                                    simdscalari vClearMask = _simd_cmplt_epi32(_simd_castps_si(vpiAttrib[0].x), vNumViewports);
-                                    vpiAttrib[0].x = _simd_and_ps(_simd_castsi_ps(vClearMask), vpiAttrib[0].x);
-
-                                    vViewPortIdx = _simd_castps_si(vpiAttrib[0].x);
-                                }
-                                else
-                                {
-                                    vViewPortIdx = _simd_set1_epi32(0);
-                                }
-
-                                pfnClipFunc(pDC, gsPa, workerId, attrib, GenMask(gsPa.NumPrims()), vPrimId, vViewPortIdx);
+#if USE_SIMD16_FRONTEND
+                                simd16scalari vPrimId = _simd16_set1_epi32(pPrimitiveId[inputPrim]);
+
+                                gsPa.useAlternateOffset = false;
+                                pfnClipFunc(pDC, gsPa, workerId, attrib_simd16, GenMask(gsPa.NumPrims()), vPrimId);
+#else
+                                simdscalari vPrimId = _simd_set1_epi32(pPrimitiveId[inputPrim]);
+                                pfnClipFunc(pDC, gsPa, workerId, attrib, GenMask(gsPa.NumPrims()), vPrimId);
+#endif
                             }
                         }
                     } while (gsPa.NextPrim());
@@ -853,7 +969,7 @@ static void GeometryShaderStage(
     // update GS pipeline stats
     UPDATE_STAT_FE(GsInvocations, numInputPrims * pState->instanceCount);
     UPDATE_STAT_FE(GsPrimitives, totalPrimsGenerated);
-       AR_EVENT(GSPrimInfo(numInputPrims, totalPrimsGenerated, numVertsPerPrim*numInputPrims));
+    AR_EVENT(GSPrimInfo(numInputPrims, totalPrimsGenerated, numVertsPerPrim*numInputPrims));
     AR_END(FEGeometryShader, 1);
 }
 
@@ -863,24 +979,23 @@ static void GeometryShaderStage(
 /// @param state - API state
 /// @param ppGsOut - pointer to GS output buffer allocation
 /// @param ppCutBuffer - pointer to GS output cut buffer allocation
+template<typename SIMDVERTEX, uint32_t SIMD_WIDTH>
 static INLINE void AllocateGsBuffers(DRAW_CONTEXT* pDC, const API_STATE& state, void** ppGsOut, void** ppCutBuffer,
     void **ppStreamCutBuffer)
 {
     auto pArena = pDC->pArena;
     SWR_ASSERT(pArena != nullptr);
     SWR_ASSERT(state.gsState.gsEnable);
+
     // allocate arena space to hold GS output verts
     // @todo pack attribs
     // @todo support multiple streams
-    const uint32_t vertexStride = sizeof(simdvertex);
-    const uint32_t numSimdBatches = (state.gsState.maxNumVerts + KNOB_SIMD_WIDTH - 1) / KNOB_SIMD_WIDTH;
-    uint32_t size = state.gsState.instanceCount * numSimdBatches * vertexStride * KNOB_SIMD_WIDTH;
-    *ppGsOut = pArena->AllocAligned(size, KNOB_SIMD_WIDTH * sizeof(float));
 
-    const uint32_t cutPrimStride = (state.gsState.maxNumVerts + 7) / 8;
-    const uint32_t streamIdPrimStride = AlignUp(state.gsState.maxNumVerts * 2 / 8, 4);
-    const uint32_t cutBufferSize = cutPrimStride * state.gsState.instanceCount * KNOB_SIMD_WIDTH;
-    const uint32_t streamIdSize = streamIdPrimStride * state.gsState.instanceCount * KNOB_SIMD_WIDTH;
+    const GsBufferInfo<SIMDVERTEX, SIMD_WIDTH> bufferInfo(state.gsState);
+
+    const uint32_t vertexBufferSize = state.gsState.instanceCount * bufferInfo.vertexInstanceStride;
+
+    *ppGsOut = pArena->AllocAligned(vertexBufferSize, SIMD_WIDTH * sizeof(float));
 
     // allocate arena space to hold cut or streamid buffer, which is essentially a bitfield sized to the
     // maximum vertex output as defined by the GS state, per SIMD lane, per GS instance
@@ -888,15 +1003,19 @@ static INLINE void AllocateGsBuffers(DRAW_CONTEXT* pDC, const API_STATE& state,
     // allocate space for temporary per-stream cut buffer if multi-stream is enabled
     if (state.gsState.isSingleStream)
     {
-        *ppCutBuffer = pArena->AllocAligned(cutBufferSize, KNOB_SIMD_WIDTH * sizeof(float));
+        const uint32_t cutBufferSize = state.gsState.instanceCount * bufferInfo.cutInstanceStride;
+
+        *ppCutBuffer = pArena->AllocAligned(cutBufferSize, SIMD_WIDTH * sizeof(float));
         *ppStreamCutBuffer = nullptr;
     }
     else
     {
-        *ppCutBuffer = pArena->AllocAligned(streamIdSize, KNOB_SIMD_WIDTH * sizeof(float));
-        *ppStreamCutBuffer = pArena->AllocAligned(cutBufferSize, KNOB_SIMD_WIDTH * sizeof(float));
-    }
+        const uint32_t cutBufferSize = state.gsState.instanceCount * bufferInfo.cutInstanceStride;
+        const uint32_t streamCutBufferSize = state.gsState.instanceCount * bufferInfo.streamCutInstanceStride;
 
+        *ppCutBuffer = pArena->AllocAligned(cutBufferSize, SIMD_WIDTH * sizeof(float));
+        *ppStreamCutBuffer = pArena->AllocAligned(streamCutBufferSize, SIMD_WIDTH * sizeof(float));
+    }
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -947,6 +1066,9 @@ static void TessellationStages(
     void* pCutBuffer,
     void* pCutStreamBuffer,
     uint32_t* pSoPrimData,
+#if USE_SIMD16_FRONTEND
+    uint32_t numPrims_simd8,
+#endif
     simdscalari primID)
 {
     SWR_CONTEXT *pContext = pDC->pContext;
@@ -973,6 +1095,20 @@ static void TessellationStages(
     }
     SWR_ASSERT(tsCtx);
 
+#if USE_SIMD16_FRONTEND
+    PFN_PROCESS_PRIMS_SIMD16 pfnClipFunc = nullptr;
+    if (HasRastT::value)
+    {
+        switch (tsState.postDSTopology)
+        {
+        case TOP_TRIANGLE_LIST: pfnClipFunc = ClipTriangles_simd16; break;
+        case TOP_LINE_LIST:     pfnClipFunc = ClipLines_simd16; break;
+        case TOP_POINT_LIST:    pfnClipFunc = ClipPoints_simd16; break;
+        default: SWR_INVALID("Unexpected DS output topology: %d", tsState.postDSTopology);
+        }
+    }
+
+#else
     PFN_PROCESS_PRIMS pfnClipFunc = nullptr;
     if (HasRastT::value)
     {
@@ -981,10 +1117,11 @@ static void TessellationStages(
         case TOP_TRIANGLE_LIST: pfnClipFunc = ClipTriangles; break;
         case TOP_LINE_LIST:     pfnClipFunc = ClipLines; break;
         case TOP_POINT_LIST:    pfnClipFunc = ClipPoints; break;
-        default: SWR_ASSERT(false, "Unexpected DS output topology: %d", tsState.postDSTopology);
+        default: SWR_INVALID("Unexpected DS output topology: %d", tsState.postDSTopology);
         }
     }
 
+#endif
     SWR_HS_CONTEXT& hsContext = gt_pTessellationThreadData->hsContext;
     hsContext.pCPout = gt_pTessellationThreadData->patchData;
     hsContext.PrimitiveID = primID;
@@ -996,12 +1133,12 @@ static void TessellationStages(
     // assemble all attributes for the input primitives
     for (uint32_t slot = 0; slot < tsState.numHsInputAttribs; ++slot)
     {
-        uint32_t attribSlot = VERTEX_ATTRIB_START_SLOT + slot;
+        uint32_t attribSlot = tsState.vertexAttribOffset + slot;
         pa.Assemble(attribSlot, simdattrib);
 
         for (uint32_t i = 0; i < numVertsPerPrim; ++i)
         {
-            hsContext.vert[i].attrib[attribSlot] = simdattrib[i];
+            hsContext.vert[i].attrib[VERTEX_ATTRIB_START_SLOT + slot] = simdattrib[i];
         }
     }
 
@@ -1009,7 +1146,11 @@ static void TessellationStages(
     memset(hsContext.pCPout, 0x90, sizeof(ScalarPatch) * KNOB_SIMD_WIDTH);
 #endif
 
+#if USE_SIMD16_FRONTEND
+    uint32_t numPrims = numPrims_simd8;
+#else
     uint32_t numPrims = pa.NumPrims();
+#endif
     hsContext.mask = GenerateMask(numPrims);
 
     // Run the HS
@@ -1027,7 +1168,7 @@ static void TessellationStages(
         SWR_TS_TESSELLATED_DATA tsData = { 0 };
         AR_BEGIN(FETessellation, pDC->drawId);
         TSTessellate(tsCtx, hsContext.pCPout[p].tessFactors, tsData);
-               AR_EVENT(TessPrimCount(1));
+        AR_EVENT(TessPrimCount(1));
         AR_END(FETessellation, 0);
 
         if (tsData.NumPrimitives == 0)
@@ -1039,12 +1180,20 @@ static void TessellationStages(
         // Allocate DS Output memory
         uint32_t requiredDSVectorInvocations = AlignUp(tsData.NumDomainPoints, KNOB_SIMD_WIDTH) / KNOB_SIMD_WIDTH;
         size_t requiredDSOutputVectors = requiredDSVectorInvocations * tsState.numDsOutputAttribs;
+#if USE_SIMD16_FRONTEND
+        size_t requiredAllocSize = sizeof(simdvector) * RoundUpEven(requiredDSVectorInvocations) * tsState.numDsOutputAttribs;      // simd8 -> simd16, padding
+#else
         size_t requiredAllocSize = sizeof(simdvector) * requiredDSOutputVectors;
+#endif
         if (requiredDSOutputVectors > gt_pTessellationThreadData->numDSOutputVectors)
         {
             AlignedFree(gt_pTessellationThreadData->pDSOutput);
             gt_pTessellationThreadData->pDSOutput = (simdscalar*)AlignedMalloc(requiredAllocSize, 64);
+#if USE_SIMD16_FRONTEND
+            gt_pTessellationThreadData->numDSOutputVectors = RoundUpEven(requiredDSVectorInvocations) * tsState.numDsOutputAttribs; // simd8 -> simd16, padding
+#else
             gt_pTessellationThreadData->numDSOutputVectors = requiredDSOutputVectors;
+#endif
         }
         SWR_ASSERT(gt_pTessellationThreadData->pDSOutput);
         SWR_ASSERT(gt_pTessellationThreadData->numDSOutputVectors >= requiredDSOutputVectors);
@@ -1060,7 +1209,11 @@ static void TessellationStages(
         dsContext.pDomainU = (simdscalar*)tsData.pDomainPointsU;
         dsContext.pDomainV = (simdscalar*)tsData.pDomainPointsV;
         dsContext.pOutputData = gt_pTessellationThreadData->pDSOutput;
+#if USE_SIMD16_FRONTEND
+        dsContext.vectorStride = RoundUpEven(requiredDSVectorInvocations);      // simd8 -> simd16
+#else
         dsContext.vectorStride = requiredDSVectorInvocations;
+#endif
 
         uint32_t dsInvocations = 0;
 
@@ -1076,10 +1229,20 @@ static void TessellationStages(
         }
         UPDATE_STAT_FE(DsInvocations, tsData.NumDomainPoints);
 
+#if USE_SIMD16_FRONTEND
+        SWR_ASSERT(IsEven(dsContext.vectorStride));                             // simd8 -> simd16
+
+#endif
         PA_TESS tessPa(
             pDC,
+#if USE_SIMD16_FRONTEND
+            reinterpret_cast<const simd16scalar *>(dsContext.pOutputData),      // simd8 -> simd16
+            dsContext.vectorStride / 2,                                         // simd8 -> simd16
+#else
             dsContext.pOutputData,
             dsContext.vectorStride,
+#endif
+            SWR_VTX_NUM_SLOTS,
             tsState.numDsOutputAttribs,
             tsData.ppIndices,
             tsData.NumPrimitives,
@@ -1087,33 +1250,68 @@ static void TessellationStages(
 
         while (tessPa.HasWork())
         {
+#if USE_SIMD16_FRONTEND
+            const uint32_t numPrims = tessPa.NumPrims();
+            const uint32_t numPrims_lo = std::min<uint32_t>(numPrims, KNOB_SIMD_WIDTH);
+            const uint32_t numPrims_hi = std::max<uint32_t>(numPrims, KNOB_SIMD_WIDTH) - KNOB_SIMD_WIDTH;
+
+            const simd16scalari primID = _simd16_set1_epi32(dsContext.PrimitiveID);
+            const simdscalari primID_lo = _simd16_extract_si(primID, 0);
+            const simdscalari primID_hi = _simd16_extract_si(primID, 1);
+
+#endif
             if (HasGeometryShaderT::value)
             {
+#if USE_SIMD16_FRONTEND
+                tessPa.useAlternateOffset = false;
+                GeometryShaderStage<HasStreamOutT, HasRastT>(pDC, workerId, tessPa, pGsOut, pCutBuffer, pCutStreamBuffer, pSoPrimData, numPrims_lo, primID_lo);
+
+                if (numPrims_hi)
+                {
+                    tessPa.useAlternateOffset = true;
+                    GeometryShaderStage<HasStreamOutT, HasRastT>(pDC, workerId, tessPa, pGsOut, pCutBuffer, pCutStreamBuffer, pSoPrimData, numPrims_hi, primID_hi);
+                }
+#else
                 GeometryShaderStage<HasStreamOutT, HasRastT>(
                     pDC, workerId, tessPa, pGsOut, pCutBuffer, pCutStreamBuffer, pSoPrimData,
                     _simd_set1_epi32(dsContext.PrimitiveID));
+#endif
             }
             else
             {
                 if (HasStreamOutT::value)
                 {
+#if ENABLE_AVX512_SIMD16
+                    tessPa.useAlternateOffset = false;
+#endif
                     StreamOut(pDC, tessPa, workerId, pSoPrimData, 0);
                 }
 
                 if (HasRastT::value)
                 {
-                    simdvector prim[3]; // Only deal with triangles, lines, or points
+#if USE_SIMD16_FRONTEND
+                    simd16vector    prim_simd16[3]; // Only deal with triangles, lines, or points
+#else
+                    simdvector      prim[3];        // Only deal with triangles, lines, or points
+#endif
                     AR_BEGIN(FEPAAssemble, pDC->drawId);
-#if SWR_ENABLE_ASSERTS
                     bool assemble =
-#endif
+#if USE_SIMD16_FRONTEND
+                        tessPa.Assemble_simd16(VERTEX_POSITION_SLOT, prim_simd16);
+#else
                         tessPa.Assemble(VERTEX_POSITION_SLOT, prim);
+#endif
                     AR_END(FEPAAssemble, 1);
                     SWR_ASSERT(assemble);
 
                     SWR_ASSERT(pfnClipFunc);
+#if USE_SIMD16_FRONTEND
+                    tessPa.useAlternateOffset = false;
+                    pfnClipFunc(pDC, tessPa, workerId, prim_simd16, GenMask(numPrims), primID);
+#else
                     pfnClipFunc(pDC, tessPa, workerId, prim,
-                        GenMask(tessPa.NumPrims()), _simd_set1_epi32(dsContext.PrimitiveID), _simd_set1_epi32(0));
+                        GenMask(tessPa.NumPrims()), _simd_set1_epi32(dsContext.PrimitiveID));
+#endif
                 }
             }
 
@@ -1122,9 +1320,21 @@ static void TessellationStages(
         } // while (tessPa.HasWork())
     } // for (uint32_t p = 0; p < numPrims; ++p)
 
+#if USE_SIMD16_FRONTEND
+    if (gt_pTessellationThreadData->pDSOutput != nullptr)
+    {
+        AlignedFree(gt_pTessellationThreadData->pDSOutput);
+        gt_pTessellationThreadData->pDSOutput = nullptr;
+    }
+    gt_pTessellationThreadData->numDSOutputVectors = 0;
+
+#endif
     TSDestroyCtx(tsCtx);
 }
 
+THREAD PA_STATE::SIMDVERTEX *pVertexStore = nullptr;
+THREAD uint32_t gVertexStoreSize = 0;
+
 //////////////////////////////////////////////////////////////////////////
 /// @brief FE handler for SwrDraw.
 /// @tparam IsIndexedT - Is indexed drawing enabled
@@ -1161,12 +1371,9 @@ void ProcessDraw(
 
     DRAW_WORK&          work = *(DRAW_WORK*)pUserData;
     const API_STATE&    state = GetApiState(pDC);
-    __m256i             vScale = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
-    SWR_VS_CONTEXT      vsContext;
-    simdvertex          vin;
 
-    int indexSize = 0;
-    uint32_t endVertex = work.numVerts; 
+    uint32_t indexSize = 0;
+    uint32_t endVertex = work.numVerts;
 
     const int32_t* pLastRequestedIndex = nullptr;
     if (IsIndexedT::value)
@@ -1188,7 +1395,7 @@ void ProcessDraw(
             pLastRequestedIndex = (int32_t*)(&(((uint8_t*)work.pIB)[endVertex]));
             break;
         default:
-            SWR_ASSERT(0);
+            SWR_INVALID("Invalid work.type: %d", work.type);
         }
     }
     else
@@ -1197,30 +1404,6 @@ void ProcessDraw(
         endVertex = GetNumVerts(state.topology, GetNumPrims(state.topology, work.numVerts));
     }
 
-    SWR_FETCH_CONTEXT fetchInfo = { 0 };
-    fetchInfo.pStreams = &state.vertexBuffers[0];
-    fetchInfo.StartInstance = work.startInstance;
-    fetchInfo.StartVertex = 0;
-
-    vsContext.pVin = &vin;
-
-    if (IsIndexedT::value)
-    {
-        fetchInfo.BaseVertex = work.baseVertex;
-
-        // if the entire index buffer isn't being consumed, set the last index
-        // so that fetches < a SIMD wide will be masked off
-        fetchInfo.pLastIndex = (const int32_t*)(((uint8_t*)state.indexBuffer.pIndices) + state.indexBuffer.size);
-        if (pLastRequestedIndex < fetchInfo.pLastIndex)
-        {
-            fetchInfo.pLastIndex = pLastRequestedIndex;
-        }
-    }
-    else
-    {
-        fetchInfo.StartVertex = work.startVertex;
-    }
-
 #if defined(KNOB_ENABLE_RDTSC) || defined(KNOB_ENABLE_AR)
     uint32_t numPrims = GetNumPrims(state.topology, work.numVerts);
 #endif
@@ -1230,7 +1413,11 @@ void ProcessDraw(
     void* pStreamCutBuffer = nullptr;
     if (HasGeometryShaderT::value)
     {
-        AllocateGsBuffers(pDC, state, &pGsOut, &pCutBuffer, &pStreamCutBuffer);
+#if USE_SIMD16_FRONTEND
+        AllocateGsBuffers<simd16vertex, KNOB_SIMD16_WIDTH>(pDC, state, &pGsOut, &pCutBuffer, &pStreamCutBuffer);
+#else
+        AllocateGsBuffers<simdvertex, KNOB_SIMD_WIDTH>(pDC, state, &pGsOut, &pCutBuffer, &pStreamCutBuffer);
+#endif
     }
 
     if (HasTessellationT::value)
@@ -1255,10 +1442,284 @@ void ProcessDraw(
         pSoPrimData = (uint32_t*)pDC->pArena->AllocAligned(4096, 16);
     }
 
+    const uint32_t vertexCount = NumVertsPerPrim(state.topology, true);
+#if USE_SIMD16_FRONTEND
+    uint32_t simdVertexSizeBytes = state.frontendState.vsVertexSize * sizeof(simd16vector);
+#else
+    uint32_t simdVertexSizeBytes = state.frontendState.vsVertexSize * sizeof(simdvector);
+#endif
+
+    SWR_ASSERT(vertexCount <= MAX_NUM_VERTS_PER_PRIM);
+
+    // Compute storage requirements for vertex store
+    // TODO: allocation needs to be rethought for better cut support
+    uint32_t numVerts = vertexCount + 2; // Need extra space for PA state machine
+    uint32_t vertexStoreSize = numVerts * simdVertexSizeBytes;
+
+    // grow the vertex store for the PA as necessary
+    if (gVertexStoreSize < vertexStoreSize)
+    {
+        if (pVertexStore != nullptr)
+        {
+            AlignedFree(pVertexStore);
+        }
+
+        pVertexStore = reinterpret_cast<PA_STATE::SIMDVERTEX *>(AlignedMalloc(vertexStoreSize, 64));
+        gVertexStoreSize = vertexStoreSize;
+
+        SWR_ASSERT(pVertexStore != nullptr);
+    }
+
     // choose primitive assembler
-    PA_FACTORY<IsIndexedT, IsCutIndexEnabledT> paFactory(pDC, state.topology, work.numVerts);
+    PA_FACTORY<IsIndexedT, IsCutIndexEnabledT> paFactory(pDC, state.topology, work.numVerts, pVertexStore, numVerts, state.frontendState.vsVertexSize);
     PA_STATE& pa = paFactory.GetPA();
 
+#if USE_SIMD16_FRONTEND
+    simdvertex          vin_lo;
+    simdvertex          vin_hi;
+    SWR_VS_CONTEXT      vsContext_lo;
+    SWR_VS_CONTEXT      vsContext_hi;
+
+    vsContext_lo.pVin = &vin_lo;
+    vsContext_hi.pVin = &vin_hi;
+    vsContext_lo.AlternateOffset = 0;
+    vsContext_hi.AlternateOffset = 1;
+
+    SWR_FETCH_CONTEXT   fetchInfo_lo = { 0 };
+
+    fetchInfo_lo.pStreams = &state.vertexBuffers[0];
+    fetchInfo_lo.StartInstance = work.startInstance;
+    fetchInfo_lo.StartVertex = 0;
+
+    if (IsIndexedT::value)
+    {
+        fetchInfo_lo.BaseVertex = work.baseVertex;
+
+        // if the entire index buffer isn't being consumed, set the last index
+        // so that fetches < a SIMD wide will be masked off
+        fetchInfo_lo.pLastIndex = (const int32_t*)(((uint8_t*)state.indexBuffer.pIndices) + state.indexBuffer.size);
+        if (pLastRequestedIndex < fetchInfo_lo.pLastIndex)
+        {
+            fetchInfo_lo.pLastIndex = pLastRequestedIndex;
+        }
+    }
+    else
+    {
+        fetchInfo_lo.StartVertex = work.startVertex;
+    }
+
+    SWR_FETCH_CONTEXT   fetchInfo_hi = fetchInfo_lo;
+
+    const simd16scalari vScale = _simd16_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+
+    for (uint32_t instanceNum = 0; instanceNum < work.numInstances; instanceNum++)
+    {
+        uint32_t  i = 0;
+
+        simd16scalari vIndex;
+
+        if (IsIndexedT::value)
+        {
+            fetchInfo_lo.pIndices = work.pIB;
+            fetchInfo_hi.pIndices = (int32_t *)((uint8_t *)fetchInfo_lo.pIndices + KNOB_SIMD_WIDTH * indexSize);    // 1/2 of KNOB_SIMD16_WIDTH
+        }
+        else
+        {
+            vIndex = _simd16_add_epi32(_simd16_set1_epi32(work.startVertexID), vScale);
+
+            fetchInfo_lo.pIndices = (const int32_t *)&vIndex;
+            fetchInfo_hi.pIndices = (const int32_t *)&vIndex + KNOB_SIMD_WIDTH; // 1/2 of KNOB_SIMD16_WIDTH
+        }
+
+        fetchInfo_lo.CurInstance = instanceNum;
+        fetchInfo_hi.CurInstance = instanceNum;
+
+        vsContext_lo.InstanceID = instanceNum;
+        vsContext_hi.InstanceID = instanceNum;
+
+        while (pa.HasWork())
+        {
+            // GetNextVsOutput currently has the side effect of updating some PA state machine state.
+            // So we need to keep this outside of (i < endVertex) check.
+
+            simdmask *pvCutIndices_lo = nullptr;
+            simdmask *pvCutIndices_hi = nullptr;
+
+            if (IsIndexedT::value)
+            {
+                // simd16mask <=> simdmask[2]
+
+                pvCutIndices_lo = &reinterpret_cast<simdmask *>(&pa.GetNextVsIndices())[0];
+                pvCutIndices_hi = &reinterpret_cast<simdmask *>(&pa.GetNextVsIndices())[1];
+            }
+
+            simd16vertex &vout = pa.GetNextVsOutput();
+
+            vsContext_lo.pVout = reinterpret_cast<simdvertex *>(&vout);
+            vsContext_hi.pVout = reinterpret_cast<simdvertex *>(&vout);
+
+            if (i < endVertex)
+            {
+                // 1. Execute FS/VS for a single SIMD.
+                AR_BEGIN(FEFetchShader, pDC->drawId);
+                state.pfnFetchFunc(fetchInfo_lo, vin_lo);
+
+                if ((i + KNOB_SIMD_WIDTH) < endVertex)  // 1/2 of KNOB_SIMD16_WIDTH
+                {
+                    state.pfnFetchFunc(fetchInfo_hi, vin_hi);
+                }
+                AR_END(FEFetchShader, 0);
+
+                // forward fetch generated vertex IDs to the vertex shader
+                vsContext_lo.VertexID = fetchInfo_lo.VertexID;
+                vsContext_hi.VertexID = fetchInfo_hi.VertexID;
+
+                // Setup active mask for vertex shader.
+                vsContext_lo.mask = GenerateMask(endVertex - i);
+                vsContext_hi.mask = GenerateMask(endVertex - (i + KNOB_SIMD_WIDTH));
+
+                // forward cut mask to the PA
+                if (IsIndexedT::value)
+                {
+                    *pvCutIndices_lo = _simd_movemask_ps(_simd_castsi_ps(fetchInfo_lo.CutMask));
+                    *pvCutIndices_hi = _simd_movemask_ps(_simd_castsi_ps(fetchInfo_hi.CutMask));
+                }
+
+                UPDATE_STAT_FE(IaVertices, GetNumInvocations(i, endVertex));
+
+#if KNOB_ENABLE_TOSS_POINTS
+                if (!KNOB_TOSS_FETCH)
+#endif
+                {
+                    AR_BEGIN(FEVertexShader, pDC->drawId);
+                    state.pfnVertexFunc(GetPrivateState(pDC), &vsContext_lo);
+
+                    if ((i + KNOB_SIMD_WIDTH) < endVertex)  // 1/2 of KNOB_SIMD16_WIDTH
+                    {
+                        state.pfnVertexFunc(GetPrivateState(pDC), &vsContext_hi);
+                    }
+                    AR_END(FEVertexShader, 0);
+
+                    UPDATE_STAT_FE(VsInvocations, GetNumInvocations(i, endVertex));
+                }
+            }
+
+            // 2. Assemble primitives given the last two SIMD.
+            do
+            {
+                simd16vector prim_simd16[MAX_NUM_VERTS_PER_PRIM];
+
+                RDTSC_START(FEPAAssemble);
+                bool assemble = pa.Assemble_simd16(VERTEX_POSITION_SLOT, prim_simd16);
+                RDTSC_STOP(FEPAAssemble, 1, 0);
+
+#if KNOB_ENABLE_TOSS_POINTS
+                if (!KNOB_TOSS_FETCH)
+#endif
+                {
+#if KNOB_ENABLE_TOSS_POINTS
+                    if (!KNOB_TOSS_VS)
+#endif
+                    {
+                        if (assemble)
+                        {
+                            UPDATE_STAT_FE(IaPrimitives, pa.NumPrims());
+
+                            const uint32_t numPrims = pa.NumPrims();
+                            const uint32_t numPrims_lo = std::min<uint32_t>(numPrims, KNOB_SIMD_WIDTH);
+                            const uint32_t numPrims_hi = std::max<uint32_t>(numPrims, KNOB_SIMD_WIDTH) - KNOB_SIMD_WIDTH;
+
+                            const simd16scalari primID = pa.GetPrimID(work.startPrimID);
+                            const simdscalari primID_lo = _simd16_extract_si(primID, 0);
+                            const simdscalari primID_hi = _simd16_extract_si(primID, 1);
+
+                            if (HasTessellationT::value)
+                            {
+                                pa.useAlternateOffset = false;
+                                TessellationStages<HasGeometryShaderT, HasStreamOutT, HasRastT>(pDC, workerId, pa, pGsOut, pCutBuffer, pStreamCutBuffer, pSoPrimData, numPrims_lo, primID_lo);
+
+                                if (numPrims_hi)
+                                {
+                                    pa.useAlternateOffset = true;
+                                    TessellationStages<HasGeometryShaderT, HasStreamOutT, HasRastT>(pDC, workerId, pa, pGsOut, pCutBuffer, pStreamCutBuffer, pSoPrimData, numPrims_hi, primID_hi);
+                                }
+                            }
+                            else if (HasGeometryShaderT::value)
+                            {
+                                pa.useAlternateOffset = false;
+                                GeometryShaderStage<HasStreamOutT, HasRastT>(pDC, workerId, pa, pGsOut, pCutBuffer, pStreamCutBuffer, pSoPrimData, numPrims_lo, primID_lo);
+
+                                if (numPrims_hi)
+                                {
+                                    pa.useAlternateOffset = true;
+                                    GeometryShaderStage<HasStreamOutT, HasRastT>(pDC, workerId, pa, pGsOut, pCutBuffer, pStreamCutBuffer, pSoPrimData, numPrims_hi, primID_hi);
+                                }
+                            }
+                            else
+                            {
+                                // If streamout is enabled then stream vertices out to memory.
+                                if (HasStreamOutT::value)
+                                {
+                                    pa.useAlternateOffset = false;
+                                    StreamOut(pDC, pa, workerId, pSoPrimData, 0);
+                                }
+
+                                if (HasRastT::value)
+                                {
+                                    SWR_ASSERT(pDC->pState->pfnProcessPrims_simd16);
+
+                                    pa.useAlternateOffset = false;
+                                    pDC->pState->pfnProcessPrims_simd16(pDC, pa, workerId, prim_simd16, GenMask(numPrims), primID);
+                                }
+                            }
+                        }
+                    }
+                }
+            } while (pa.NextPrim());
+
+            if (IsIndexedT::value)
+            {
+                fetchInfo_lo.pIndices = (int32_t *)((uint8_t*)fetchInfo_lo.pIndices + KNOB_SIMD16_WIDTH * indexSize);
+                fetchInfo_hi.pIndices = (int32_t *)((uint8_t*)fetchInfo_hi.pIndices + KNOB_SIMD16_WIDTH * indexSize);
+            }
+            else
+            {
+                vIndex = _simd16_add_epi32(vIndex, _simd16_set1_epi32(KNOB_SIMD16_WIDTH));
+            }
+
+            i += KNOB_SIMD16_WIDTH;
+        }
+
+        pa.Reset();
+    }
+
+#else
+    SWR_VS_CONTEXT      vsContext;
+    SWR_FETCH_CONTEXT   fetchInfo = { 0 };
+
+    fetchInfo.pStreams = &state.vertexBuffers[0];
+    fetchInfo.StartInstance = work.startInstance;
+    fetchInfo.StartVertex = 0;
+
+    if (IsIndexedT::value)
+    {
+        fetchInfo.BaseVertex = work.baseVertex;
+
+        // if the entire index buffer isn't being consumed, set the last index
+        // so that fetches < a SIMD wide will be masked off
+        fetchInfo.pLastIndex = (const int32_t*)(((uint8_t*)state.indexBuffer.pIndices) + state.indexBuffer.size);
+        if (pLastRequestedIndex < fetchInfo.pLastIndex)
+        {
+            fetchInfo.pLastIndex = pLastRequestedIndex;
+        }
+    }
+    else
+    {
+        fetchInfo.StartVertex = work.startVertex;
+    }
+
+    const simdscalari   vScale = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+
     /// @todo: temporarily move instance loop in the FE to ensure SO ordering
     for (uint32_t instanceNum = 0; instanceNum < work.numInstances; instanceNum++)
     {
@@ -1280,7 +1741,7 @@ void ProcessDraw(
 
         while (pa.HasWork())
         {
-            // PaGetNextVsOutput currently has the side effect of updating some PA state machine state.
+            // GetNextVsOutput currently has the side effect of updating some PA state machine state.
             // So we need to keep this outside of (i < endVertex) check.
             simdmask* pvCutIndices = nullptr;
             if (IsIndexedT::value)
@@ -1289,6 +1750,7 @@ void ProcessDraw(
             }
 
             simdvertex& vout = pa.GetNextVsOutput();
+            vsContext.pVin = &vout;
             vsContext.pVout = &vout;
 
             if (i < endVertex)
@@ -1296,7 +1758,7 @@ void ProcessDraw(
 
                 // 1. Execute FS/VS for a single SIMD.
                 AR_BEGIN(FEFetchShader, pDC->drawId);
-                state.pfnFetchFunc(fetchInfo, vin);
+                state.pfnFetchFunc(fetchInfo, vout);
                 AR_END(FEFetchShader, 0);
 
                 // forward fetch generated vertex IDs to the vertex shader
@@ -1367,8 +1829,9 @@ void ProcessDraw(
                                 if (HasRastT::value)
                                 {
                                     SWR_ASSERT(pDC->pState->pfnProcessPrims);
+
                                     pDC->pState->pfnProcessPrims(pDC, pa, workerId, prim,
-                                        GenMask(pa.NumPrims()), pa.GetPrimID(work.startPrimID), _simd_set1_epi32(0));
+                                        GenMask(pa.NumPrims()), pa.GetPrimID(work.startPrimID));
                                 }
                             }
                         }
@@ -1376,7 +1839,6 @@ void ProcessDraw(
                 }
             } while (pa.NextPrim());
 
-            i += KNOB_SIMD_WIDTH;
             if (IsIndexedT::value)
             {
                 fetchInfo.pIndices = (int*)((uint8_t*)fetchInfo.pIndices + KNOB_SIMD_WIDTH * indexSize);
@@ -1385,10 +1847,13 @@ void ProcessDraw(
             {
                 vIndex = _simd_add_epi32(vIndex, _simd_set1_epi32(KNOB_SIMD_WIDTH));
             }
+
+            i += KNOB_SIMD_WIDTH;
         }
         pa.Reset();
     }
 
+#endif
 
     AR_END(FEProcessDraw, numPrims * work.numInstances);
 }
@@ -1415,4 +1880,4 @@ PFN_FE_WORK_FUNC GetProcessDrawFunc(
     bool HasRasterization)
 {
     return TemplateArgUnroller<FEDrawChooser>::GetFunc(IsIndexed, IsCutIndexEnabled, HasTessellation, HasGeometryShader, HasStreamOut, HasRasterization);
-}
\ No newline at end of file
+}