sizeof(privBase) + sizeof(spillBase) + sizeof(ldsChunk) +
computeUnit->wfSize() * sizeof(ReconvergenceStackEntry);
}
+
+void
+Wavefront::getContext(const void *out)
+{
+ uint8_t *iter = (uint8_t *)out;
+ for (int i = 0; i < barCnt.size(); i++) {
+ *(int *)iter = barCnt[i]; iter += sizeof(barCnt[i]);
+ }
+ *(int *)iter = wfId; iter += sizeof(wfId);
+ *(int *)iter = maxBarCnt; iter += sizeof(maxBarCnt);
+ *(int *)iter = oldBarrierCnt; iter += sizeof(oldBarrierCnt);
+ *(int *)iter = barrierCnt; iter += sizeof(barrierCnt);
+ *(int *)iter = computeUnit->cu_id; iter += sizeof(computeUnit->cu_id);
+ *(uint32_t *)iter = wgId; iter += sizeof(wgId);
+ *(uint32_t *)iter = barrierId; iter += sizeof(barrierId);
+ *(uint64_t *)iter = initMask.to_ullong(); iter += sizeof(initMask.to_ullong());
+ *(Addr *)iter = privBase; iter += sizeof(privBase);
+ *(Addr *)iter = spillBase; iter += sizeof(spillBase);
+
+ int stackSize = reconvergenceStack.size();
+ ReconvergenceStackEntry empty = {std::numeric_limits<uint32_t>::max(),
+ std::numeric_limits<uint32_t>::max(),
+ std::numeric_limits<uint64_t>::max()};
+ for (int i = 0; i < workItemId[0].size(); i++) {
+ if (i < stackSize) {
+ *(ReconvergenceStackEntry *)iter = *reconvergenceStack.back();
+ iter += sizeof(ReconvergenceStackEntry);
+ reconvergenceStack.pop_back();
+ } else {
+ *(ReconvergenceStackEntry *)iter = empty;
+ iter += sizeof(ReconvergenceStackEntry);
+ }
+ }
+
+ int wf_size = computeUnit->wfSize();
+ for (int i = 0; i < maxSpVgprs; i++) {
+ uint32_t vgprIdx = remap(i, sizeof(uint32_t), 1);
+ for (int lane = 0; lane < wf_size; lane++) {
+ uint32_t regVal = computeUnit->vrf[simdId]->
+ read<uint32_t>(vgprIdx,lane);
+ *(uint32_t *)iter = regVal; iter += sizeof(regVal);
+ }
+ }
+
+ for (int i = 0; i < maxDpVgprs; i++) {
+ uint32_t vgprIdx = remap(i, sizeof(uint64_t), 1);
+ for (int lane = 0; lane < wf_size; lane++) {
+ uint64_t regVal = computeUnit->vrf[simdId]->
+ read<uint64_t>(vgprIdx,lane);
+ *(uint64_t *)iter = regVal; iter += sizeof(regVal);
+ }
+ }
+
+ for (int i = 0; i < condRegState->numRegs(); i++) {
+ for (int lane = 0; lane < wf_size; lane++) {
+ uint64_t regVal = condRegState->read<uint64_t>(i, lane);
+ *(uint64_t *)iter = regVal; iter += sizeof(regVal);
+ }
+ }
+
+ /* saving LDS content */
+ if (ldsChunk)
+ for (int i = 0; i < ldsChunk->size(); i++) {
+ char val = ldsChunk->read<char>(i);
+ *(char *) iter = val; iter += sizeof(val);
+ }
+}
+
+void
+Wavefront::setContext(const void *in)
+{
+ uint8_t *iter = (uint8_t *)in;
+ for (int i = 0; i < barCnt.size(); i++) {
+ barCnt[i] = *(int *)iter; iter += sizeof(barCnt[i]);
+ }
+ wfId = *(int *)iter; iter += sizeof(wfId);
+ maxBarCnt = *(int *)iter; iter += sizeof(maxBarCnt);
+ oldBarrierCnt = *(int *)iter; iter += sizeof(oldBarrierCnt);
+ barrierCnt = *(int *)iter; iter += sizeof(barrierCnt);
+ computeUnit->cu_id = *(int *)iter; iter += sizeof(computeUnit->cu_id);
+ wgId = *(uint32_t *)iter; iter += sizeof(wgId);
+ barrierId = *(uint32_t *)iter; iter += sizeof(barrierId);
+ initMask = VectorMask(*(uint64_t *)iter); iter += sizeof(initMask);
+ privBase = *(Addr *)iter; iter += sizeof(privBase);
+ spillBase = *(Addr *)iter; iter += sizeof(spillBase);
+
+ for (int i = 0; i < workItemId[0].size(); i++) {
+ ReconvergenceStackEntry newEntry = *(ReconvergenceStackEntry *)iter;
+ iter += sizeof(ReconvergenceStackEntry);
+ if (newEntry.pc != std::numeric_limits<uint32_t>::max()) {
+ pushToReconvergenceStack(newEntry.pc, newEntry.rpc,
+ newEntry.execMask);
+ }
+ }
+ int wf_size = computeUnit->wfSize();
+
+ for (int i = 0; i < maxSpVgprs; i++) {
+ uint32_t vgprIdx = remap(i, sizeof(uint32_t), 1);
+ for (int lane = 0; lane < wf_size; lane++) {
+ uint32_t regVal = *(uint32_t *)iter; iter += sizeof(regVal);
+ computeUnit->vrf[simdId]->write<uint32_t>(vgprIdx, regVal, lane);
+ }
+ }
+
+ for (int i = 0; i < maxDpVgprs; i++) {
+ uint32_t vgprIdx = remap(i, sizeof(uint64_t), 1);
+ for (int lane = 0; lane < wf_size; lane++) {
+ uint64_t regVal = *(uint64_t *)iter; iter += sizeof(regVal);
+ computeUnit->vrf[simdId]->write<uint64_t>(vgprIdx, regVal, lane);
+ }
+ }
+
+ for (int i = 0; i < condRegState->numRegs(); i++) {
+ for (int lane = 0; lane < wf_size; lane++) {
+ uint64_t regVal = *(uint64_t *)iter; iter += sizeof(regVal);
+ condRegState->write<uint64_t>(i, lane, regVal);
+ }
+ }
+ /** Restoring LDS contents */
+ if (ldsChunk)
+ for (int i = 0; i < ldsChunk->size(); i++) {
+ char val = *(char *) iter; iter += sizeof(val);
+ ldsChunk->write<char>(i, val);
+ }
+}