Skip to content

Commit 7087545

Browse files
author
XIN XIE
committed
Refactor NEON gate row helpers
1 parent d005491 commit 7087545

1 file changed

Lines changed: 102 additions & 105 deletions

File tree

lib/simulator_neon.h

Lines changed: 102 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -317,98 +317,103 @@ class SimulatorNEON final : public SimulatorBase {
317317
return 0;
318318
}
319319

320-
static unsigned SIMDRegisterSize() { return 4; }
320+
static constexpr unsigned SIMDRegisterSize() {
321+
return sizeof(float32x4_t) / sizeof(float32_t);
322+
}
321323

322324
private:
325+
struct Complex {
326+
float32x4_t re;
327+
float32x4_t im;
328+
};
329+
330+
template <unsigned Size, typename GateCoeff>
331+
static Complex ApplyGateRow(
332+
const float32x4_t* state_re, const float32x4_t* state_im,
333+
GateCoeff gate_coeff) {
334+
auto gate = gate_coeff(0);
335+
auto re = vmulq_f32(state_re[0], gate.re);
336+
auto im = vmulq_f32(state_re[0], gate.im);
337+
re = vfmsq_f32(re, state_im[0], gate.im);
338+
im = vfmaq_f32(im, state_im[0], gate.re);
339+
340+
for (unsigned in = 1; in < Size; ++in) {
341+
// Complex MAC: out += gate * state.
342+
// re += state.re * gate.re - state.im * gate.im
343+
// im += state.re * gate.im + state.im * gate.re
344+
gate = gate_coeff(in);
345+
re = vfmaq_f32(re, state_re[in], gate.re);
346+
im = vfmaq_f32(im, state_re[in], gate.im);
347+
re = vfmsq_f32(re, state_im[in], gate.im);
348+
im = vfmaq_f32(im, state_im[in], gate.re);
349+
}
350+
351+
return Complex{re, im};
352+
}
353+
354+
static void StoreStateAmplitudeRow(
355+
fp_type* state_block, const uint64_t* state_offsets,
356+
unsigned output_basis, const Complex& output_amplitudes) {
357+
const auto addr_re = state_block + state_offsets[output_basis];
358+
const auto addr_im = addr_re + SIMDRegisterSize();
359+
vst1q_f32(addr_re, output_amplitudes.re);
360+
vst1q_f32(addr_im, output_amplitudes.im);
361+
}
362+
323363
template <unsigned H>
324364
void ApplyGateH(
325365
const std::vector<unsigned>& qs, const fp_type* matrix,
326366
State& state) const {
327-
auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* v,
328-
const uint64_t* ms, const uint64_t* xss, fp_type* rstate) {
367+
auto f = [](unsigned n, unsigned m, uint64_t i,
368+
const fp_type* gate_matrix, const uint64_t* masks,
369+
const uint64_t* state_offsets, fp_type* state_data) {
329370
constexpr unsigned hsize = 1 << H;
330371

331-
float32x4_t rs[hsize];
332-
float32x4_t is[hsize];
372+
float32x4_t state_re[hsize];
373+
float32x4_t state_im[hsize];
333374

334375
i *= 4;
335376

336-
uint64_t ii = i & ms[0];
377+
uint64_t ii = i & masks[0];
337378
for (unsigned j = 1; j <= H; ++j) {
338379
i *= 2;
339-
ii |= i & ms[j];
380+
ii |= i & masks[j];
340381
}
341382

342-
auto p0 = rstate + 2 * ii;
383+
auto block = state_data + 2 * ii;
343384

344385
for (unsigned k = 0; k < hsize; ++k) {
345-
rs[k] = vld1q_f32(p0 + xss[k]);
346-
is[k] = vld1q_f32(p0 + xss[k] + 4);
386+
state_re[k] = vld1q_f32(block + state_offsets[k]);
387+
state_im[k] = vld1q_f32(block + state_offsets[k] + 4);
347388
}
348389

349-
unsigned k = 0;
350-
351-
for (; k + 1 < hsize; k += 2) {
352-
const fp_type* v0 = v + 2 * k * hsize;
353-
const fp_type* v1 = v0 + 2 * hsize;
354-
355-
float32x4_t ru0 = vdupq_n_f32(v0[0]);
356-
float32x4_t iu0 = vdupq_n_f32(v0[1]);
357-
float32x4_t rn0 = vmulq_f32(rs[0], ru0);
358-
float32x4_t in0 = vmulq_f32(rs[0], iu0);
359-
rn0 = vfmsq_f32(rn0, is[0], iu0);
360-
in0 = vfmaq_f32(in0, is[0], ru0);
361-
362-
float32x4_t ru1 = vdupq_n_f32(v1[0]);
363-
float32x4_t iu1 = vdupq_n_f32(v1[1]);
364-
float32x4_t rn1 = vmulq_f32(rs[0], ru1);
365-
float32x4_t in1 = vmulq_f32(rs[0], iu1);
366-
rn1 = vfmsq_f32(rn1, is[0], iu1);
367-
in1 = vfmaq_f32(in1, is[0], ru1);
368-
369-
for (unsigned l = 1; l < hsize; ++l) {
370-
ru0 = vdupq_n_f32(v0[2 * l]);
371-
iu0 = vdupq_n_f32(v0[2 * l + 1]);
372-
rn0 = vfmaq_f32(rn0, rs[l], ru0);
373-
in0 = vfmaq_f32(in0, rs[l], iu0);
374-
rn0 = vfmsq_f32(rn0, is[l], iu0);
375-
in0 = vfmaq_f32(in0, is[l], ru0);
376-
377-
ru1 = vdupq_n_f32(v1[2 * l]);
378-
iu1 = vdupq_n_f32(v1[2 * l + 1]);
379-
rn1 = vfmaq_f32(rn1, rs[l], ru1);
380-
in1 = vfmaq_f32(in1, rs[l], iu1);
381-
rn1 = vfmsq_f32(rn1, is[l], iu1);
382-
in1 = vfmaq_f32(in1, is[l], ru1);
383-
}
384-
385-
vst1q_f32(p0 + xss[k], rn0);
386-
vst1q_f32(p0 + xss[k] + 4, in0);
387-
vst1q_f32(p0 + xss[k + 1], rn1);
388-
vst1q_f32(p0 + xss[k + 1] + 4, in1);
390+
auto load_gate_row = [](const fp_type* gate_row) {
391+
return [gate_row](unsigned in) {
392+
return Complex{
393+
vdupq_n_f32(gate_row[2 * in]),
394+
vdupq_n_f32(gate_row[2 * in + 1]),
395+
};
396+
};
397+
};
398+
399+
unsigned out = 0;
400+
for (; out + 1 < hsize; out += 2) {
401+
const fp_type* gate_row0 = gate_matrix + 2 * out * hsize;
402+
const fp_type* gate_row1 = gate_row0 + 2 * hsize;
403+
404+
auto out0 = ApplyGateRow<hsize>(state_re, state_im, load_gate_row(gate_row0));
405+
auto out1 = ApplyGateRow<hsize>(state_re, state_im, load_gate_row(gate_row1));
406+
407+
StoreStateAmplitudeRow(block, state_offsets, out, out0);
408+
StoreStateAmplitudeRow(block, state_offsets, out + 1, out1);
389409
}
390410

391-
for (; k < hsize; ++k) {
392-
const fp_type* vk = v + 2 * k * hsize;
393-
394-
float32x4_t ru = vdupq_n_f32(vk[0]);
395-
float32x4_t iu = vdupq_n_f32(vk[1]);
396-
float32x4_t rn = vmulq_f32(rs[0], ru);
397-
float32x4_t in = vmulq_f32(rs[0], iu);
398-
rn = vfmsq_f32(rn, is[0], iu);
399-
in = vfmaq_f32(in, is[0], ru);
411+
for (; out < hsize; ++out) {
412+
const fp_type* gate_row = gate_matrix + 2 * out * hsize;
400413

401-
for (unsigned l = 1; l < hsize; ++l) {
402-
ru = vdupq_n_f32(vk[2 * l]);
403-
iu = vdupq_n_f32(vk[2 * l + 1]);
404-
rn = vfmaq_f32(rn, rs[l], ru);
405-
in = vfmaq_f32(in, rs[l], iu);
406-
rn = vfmsq_f32(rn, is[l], iu);
407-
in = vfmaq_f32(in, is[l], ru);
408-
}
414+
auto out_row = ApplyGateRow<hsize>(state_re, state_im, load_gate_row(gate_row));
409415

410-
vst1q_f32(p0 + xss[k], rn);
411-
vst1q_f32(p0 + xss[k] + 4, in);
416+
StoreStateAmplitudeRow(block, state_offsets, out, out_row);
412417
}
413418
};
414419

@@ -435,8 +440,8 @@ class SimulatorNEON final : public SimulatorBase {
435440
constexpr unsigned hsize = 1 << H;
436441
constexpr unsigned lsize = 1 << L;
437442

438-
float32x4_t rs[gsize];
439-
float32x4_t is[gsize];
443+
float32x4_t state_re[gsize];
444+
float32x4_t state_im[gsize];
440445

441446
i *= 4;
442447

@@ -451,49 +456,41 @@ class SimulatorNEON final : public SimulatorBase {
451456
for (unsigned k = 0; k < hsize; ++k) {
452457
unsigned k2 = lsize * k;
453458

454-
rs[k2] = vld1q_f32(p0 + xss[k]);
455-
is[k2] = vld1q_f32(p0 + xss[k] + 4);
459+
state_re[k2] = vld1q_f32(p0 + xss[k]);
460+
state_im[k2] = vld1q_f32(p0 + xss[k] + 4);
456461

457462
if (L == 1) {
458-
rs[k2 + 1] =
459-
q0 == 0 ? vrev64q_f32(rs[k2]) : vextq_f32(rs[k2], rs[k2], 2);
460-
is[k2 + 1] =
461-
q0 == 0 ? vrev64q_f32(is[k2]) : vextq_f32(is[k2], is[k2], 2);
463+
state_re[k2 + 1] =
464+
q0 == 0 ? vrev64q_f32(state_re[k2])
465+
: vextq_f32(state_re[k2], state_re[k2], 2);
466+
state_im[k2 + 1] =
467+
q0 == 0 ? vrev64q_f32(state_im[k2])
468+
: vextq_f32(state_im[k2], state_im[k2], 2);
462469
} else if (L == 2) {
463-
rs[k2 + 1] = vextq_f32(rs[k2], rs[k2], 1);
464-
is[k2 + 1] = vextq_f32(is[k2], is[k2], 1);
465-
rs[k2 + 2] = vextq_f32(rs[k2], rs[k2], 2);
466-
is[k2 + 2] = vextq_f32(is[k2], is[k2], 2);
467-
rs[k2 + 3] = vextq_f32(rs[k2], rs[k2], 3);
468-
is[k2 + 3] = vextq_f32(is[k2], is[k2], 3);
470+
state_re[k2 + 1] = vextq_f32(state_re[k2], state_re[k2], 1);
471+
state_im[k2 + 1] = vextq_f32(state_im[k2], state_im[k2], 1);
472+
state_re[k2 + 2] = vextq_f32(state_re[k2], state_re[k2], 2);
473+
state_im[k2 + 2] = vextq_f32(state_im[k2], state_im[k2], 2);
474+
state_re[k2 + 3] = vextq_f32(state_re[k2], state_re[k2], 3);
475+
state_im[k2 + 3] = vextq_f32(state_im[k2], state_im[k2], 3);
469476
}
470477
}
471478

472-
uint64_t j = 0;
479+
auto load_gate_row = [](const fp_type* gate_row) {
480+
return [gate_row](unsigned in) {
481+
return Complex{
482+
vld1q_f32(gate_row + 8 * in),
483+
vld1q_f32(gate_row + 8 * in + 4),
484+
};
485+
};
486+
};
473487

474488
for (unsigned k = 0; k < hsize; ++k) {
475-
float32x4_t wre = vld1q_f32(w + 4 * j);
476-
float32x4_t wim = vld1q_f32(w + 4 * (j + 1));
477-
float32x4_t rn = vmulq_f32(rs[0], wre);
478-
float32x4_t in = vmulq_f32(rs[0], wim);
479-
rn = vfmsq_f32(rn, is[0], wim);
480-
in = vfmaq_f32(in, is[0], wre);
489+
const fp_type* gate_row = w + 8 * k * gsize;
490+
auto out_row = ApplyGateRow<gsize>(
491+
state_re, state_im, load_gate_row(gate_row));
481492

482-
j += 2;
483-
484-
for (unsigned l = 1; l < gsize; ++l) {
485-
wre = vld1q_f32(w + 4 * j);
486-
wim = vld1q_f32(w + 4 * (j + 1));
487-
rn = vfmaq_f32(rn, rs[l], wre);
488-
in = vfmaq_f32(in, rs[l], wim);
489-
rn = vfmsq_f32(rn, is[l], wim);
490-
in = vfmaq_f32(in, is[l], wre);
491-
492-
j += 2;
493-
}
494-
495-
vst1q_f32(p0 + xss[k], rn);
496-
vst1q_f32(p0 + xss[k] + 4, in);
493+
StoreStateAmplitudeRow(p0, xss, k, out_row);
497494
}
498495
};
499496

0 commit comments

Comments
 (0)