@@ -317,98 +317,103 @@ class SimulatorNEON final : public SimulatorBase {
317317 return 0 ;
318318 }
319319
320- static unsigned SIMDRegisterSize () { return 4 ; }
320+ static constexpr unsigned SIMDRegisterSize () {
321+ return sizeof (float32x4_t ) / sizeof (float32_t );
322+ }
321323
322324 private:
325+ struct Complex {
326+ float32x4_t re;
327+ float32x4_t im;
328+ };
329+
330+ template <unsigned Size, typename GateCoeff>
331+ static Complex ApplyGateRow (
332+ const float32x4_t * state_re, const float32x4_t * state_im,
333+ GateCoeff gate_coeff) {
334+ auto gate = gate_coeff (0 );
335+ auto re = vmulq_f32 (state_re[0 ], gate.re );
336+ auto im = vmulq_f32 (state_re[0 ], gate.im );
337+ re = vfmsq_f32 (re, state_im[0 ], gate.im );
338+ im = vfmaq_f32 (im, state_im[0 ], gate.re );
339+
340+ for (unsigned in = 1 ; in < Size; ++in) {
341+ // Complex MAC: out += gate * state.
342+ // re += state.re * gate.re - state.im * gate.im
343+ // im += state.re * gate.im + state.im * gate.re
344+ gate = gate_coeff (in);
345+ re = vfmaq_f32 (re, state_re[in], gate.re );
346+ im = vfmaq_f32 (im, state_re[in], gate.im );
347+ re = vfmsq_f32 (re, state_im[in], gate.im );
348+ im = vfmaq_f32 (im, state_im[in], gate.re );
349+ }
350+
351+ return Complex{re, im};
352+ }
353+
354+ static void StoreStateAmplitudeRow (
355+ fp_type* state_block, const uint64_t * state_offsets,
356+ unsigned output_basis, const Complex& output_amplitudes) {
357+ const auto addr_re = state_block + state_offsets[output_basis];
358+ const auto addr_im = addr_re + SIMDRegisterSize ();
359+ vst1q_f32 (addr_re, output_amplitudes.re );
360+ vst1q_f32 (addr_im, output_amplitudes.im );
361+ }
362+
323363 template <unsigned H>
324364 void ApplyGateH (
325365 const std::vector<unsigned >& qs, const fp_type* matrix,
326366 State& state) const {
327- auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* v,
328- const uint64_t * ms, const uint64_t * xss, fp_type* rstate) {
367+ auto f = [](unsigned n, unsigned m, uint64_t i,
368+ const fp_type* gate_matrix, const uint64_t * masks,
369+ const uint64_t * state_offsets, fp_type* state_data) {
329370 constexpr unsigned hsize = 1 << H;
330371
331- float32x4_t rs [hsize];
332- float32x4_t is [hsize];
372+ float32x4_t state_re [hsize];
373+ float32x4_t state_im [hsize];
333374
334375 i *= 4 ;
335376
336- uint64_t ii = i & ms [0 ];
377+ uint64_t ii = i & masks [0 ];
337378 for (unsigned j = 1 ; j <= H; ++j) {
338379 i *= 2 ;
339- ii |= i & ms [j];
380+ ii |= i & masks [j];
340381 }
341382
342- auto p0 = rstate + 2 * ii;
383+ auto block = state_data + 2 * ii;
343384
344385 for (unsigned k = 0 ; k < hsize; ++k) {
345- rs [k] = vld1q_f32 (p0 + xss [k]);
346- is [k] = vld1q_f32 (p0 + xss [k] + 4 );
386+ state_re [k] = vld1q_f32 (block + state_offsets [k]);
387+ state_im [k] = vld1q_f32 (block + state_offsets [k] + 4 );
347388 }
348389
349- unsigned k = 0 ;
350-
351- for (; k + 1 < hsize; k += 2 ) {
352- const fp_type* v0 = v + 2 * k * hsize;
353- const fp_type* v1 = v0 + 2 * hsize;
354-
355- float32x4_t ru0 = vdupq_n_f32 (v0[0 ]);
356- float32x4_t iu0 = vdupq_n_f32 (v0[1 ]);
357- float32x4_t rn0 = vmulq_f32 (rs[0 ], ru0);
358- float32x4_t in0 = vmulq_f32 (rs[0 ], iu0);
359- rn0 = vfmsq_f32 (rn0, is[0 ], iu0);
360- in0 = vfmaq_f32 (in0, is[0 ], ru0);
361-
362- float32x4_t ru1 = vdupq_n_f32 (v1[0 ]);
363- float32x4_t iu1 = vdupq_n_f32 (v1[1 ]);
364- float32x4_t rn1 = vmulq_f32 (rs[0 ], ru1);
365- float32x4_t in1 = vmulq_f32 (rs[0 ], iu1);
366- rn1 = vfmsq_f32 (rn1, is[0 ], iu1);
367- in1 = vfmaq_f32 (in1, is[0 ], ru1);
368-
369- for (unsigned l = 1 ; l < hsize; ++l) {
370- ru0 = vdupq_n_f32 (v0[2 * l]);
371- iu0 = vdupq_n_f32 (v0[2 * l + 1 ]);
372- rn0 = vfmaq_f32 (rn0, rs[l], ru0);
373- in0 = vfmaq_f32 (in0, rs[l], iu0);
374- rn0 = vfmsq_f32 (rn0, is[l], iu0);
375- in0 = vfmaq_f32 (in0, is[l], ru0);
376-
377- ru1 = vdupq_n_f32 (v1[2 * l]);
378- iu1 = vdupq_n_f32 (v1[2 * l + 1 ]);
379- rn1 = vfmaq_f32 (rn1, rs[l], ru1);
380- in1 = vfmaq_f32 (in1, rs[l], iu1);
381- rn1 = vfmsq_f32 (rn1, is[l], iu1);
382- in1 = vfmaq_f32 (in1, is[l], ru1);
383- }
384-
385- vst1q_f32 (p0 + xss[k], rn0);
386- vst1q_f32 (p0 + xss[k] + 4 , in0);
387- vst1q_f32 (p0 + xss[k + 1 ], rn1);
388- vst1q_f32 (p0 + xss[k + 1 ] + 4 , in1);
390+ auto load_gate_row = [](const fp_type* gate_row) {
391+ return [gate_row](unsigned in) {
392+ return Complex{
393+ vdupq_n_f32 (gate_row[2 * in]),
394+ vdupq_n_f32 (gate_row[2 * in + 1 ]),
395+ };
396+ };
397+ };
398+
399+ unsigned out = 0 ;
400+ for (; out + 1 < hsize; out += 2 ) {
401+ const fp_type* gate_row0 = gate_matrix + 2 * out * hsize;
402+ const fp_type* gate_row1 = gate_row0 + 2 * hsize;
403+
404+ auto out0 = ApplyGateRow<hsize>(state_re, state_im, load_gate_row (gate_row0));
405+ auto out1 = ApplyGateRow<hsize>(state_re, state_im, load_gate_row (gate_row1));
406+
407+ StoreStateAmplitudeRow (block, state_offsets, out, out0);
408+ StoreStateAmplitudeRow (block, state_offsets, out + 1 , out1);
389409 }
390410
391- for (; k < hsize; ++k) {
392- const fp_type* vk = v + 2 * k * hsize;
393-
394- float32x4_t ru = vdupq_n_f32 (vk[0 ]);
395- float32x4_t iu = vdupq_n_f32 (vk[1 ]);
396- float32x4_t rn = vmulq_f32 (rs[0 ], ru);
397- float32x4_t in = vmulq_f32 (rs[0 ], iu);
398- rn = vfmsq_f32 (rn, is[0 ], iu);
399- in = vfmaq_f32 (in, is[0 ], ru);
411+ for (; out < hsize; ++out) {
412+ const fp_type* gate_row = gate_matrix + 2 * out * hsize;
400413
401- for (unsigned l = 1 ; l < hsize; ++l) {
402- ru = vdupq_n_f32 (vk[2 * l]);
403- iu = vdupq_n_f32 (vk[2 * l + 1 ]);
404- rn = vfmaq_f32 (rn, rs[l], ru);
405- in = vfmaq_f32 (in, rs[l], iu);
406- rn = vfmsq_f32 (rn, is[l], iu);
407- in = vfmaq_f32 (in, is[l], ru);
408- }
414+ auto out_row = ApplyGateRow<hsize>(state_re, state_im, load_gate_row (gate_row));
409415
410- vst1q_f32 (p0 + xss[k], rn);
411- vst1q_f32 (p0 + xss[k] + 4 , in);
416+ StoreStateAmplitudeRow (block, state_offsets, out, out_row);
412417 }
413418 };
414419
@@ -435,8 +440,8 @@ class SimulatorNEON final : public SimulatorBase {
435440 constexpr unsigned hsize = 1 << H;
436441 constexpr unsigned lsize = 1 << L;
437442
438- float32x4_t rs [gsize];
439- float32x4_t is [gsize];
443+ float32x4_t state_re [gsize];
444+ float32x4_t state_im [gsize];
440445
441446 i *= 4 ;
442447
@@ -451,49 +456,41 @@ class SimulatorNEON final : public SimulatorBase {
451456 for (unsigned k = 0 ; k < hsize; ++k) {
452457 unsigned k2 = lsize * k;
453458
454- rs [k2] = vld1q_f32 (p0 + xss[k]);
455- is [k2] = vld1q_f32 (p0 + xss[k] + 4 );
459+ state_re [k2] = vld1q_f32 (p0 + xss[k]);
460+ state_im [k2] = vld1q_f32 (p0 + xss[k] + 4 );
456461
457462 if (L == 1 ) {
458- rs[k2 + 1 ] =
459- q0 == 0 ? vrev64q_f32 (rs[k2]) : vextq_f32 (rs[k2], rs[k2], 2 );
460- is[k2 + 1 ] =
461- q0 == 0 ? vrev64q_f32 (is[k2]) : vextq_f32 (is[k2], is[k2], 2 );
463+ state_re[k2 + 1 ] =
464+ q0 == 0 ? vrev64q_f32 (state_re[k2])
465+ : vextq_f32 (state_re[k2], state_re[k2], 2 );
466+ state_im[k2 + 1 ] =
467+ q0 == 0 ? vrev64q_f32 (state_im[k2])
468+ : vextq_f32 (state_im[k2], state_im[k2], 2 );
462469 } else if (L == 2 ) {
463- rs [k2 + 1 ] = vextq_f32 (rs [k2], rs [k2], 1 );
464- is [k2 + 1 ] = vextq_f32 (is [k2], is [k2], 1 );
465- rs [k2 + 2 ] = vextq_f32 (rs [k2], rs [k2], 2 );
466- is [k2 + 2 ] = vextq_f32 (is [k2], is [k2], 2 );
467- rs [k2 + 3 ] = vextq_f32 (rs [k2], rs [k2], 3 );
468- is [k2 + 3 ] = vextq_f32 (is [k2], is [k2], 3 );
470+ state_re [k2 + 1 ] = vextq_f32 (state_re [k2], state_re [k2], 1 );
471+ state_im [k2 + 1 ] = vextq_f32 (state_im [k2], state_im [k2], 1 );
472+ state_re [k2 + 2 ] = vextq_f32 (state_re [k2], state_re [k2], 2 );
473+ state_im [k2 + 2 ] = vextq_f32 (state_im [k2], state_im [k2], 2 );
474+ state_re [k2 + 3 ] = vextq_f32 (state_re [k2], state_re [k2], 3 );
475+ state_im [k2 + 3 ] = vextq_f32 (state_im [k2], state_im [k2], 3 );
469476 }
470477 }
471478
472- uint64_t j = 0 ;
479+ auto load_gate_row = [](const fp_type* gate_row) {
480+ return [gate_row](unsigned in) {
481+ return Complex{
482+ vld1q_f32 (gate_row + 8 * in),
483+ vld1q_f32 (gate_row + 8 * in + 4 ),
484+ };
485+ };
486+ };
473487
474488 for (unsigned k = 0 ; k < hsize; ++k) {
475- float32x4_t wre = vld1q_f32 (w + 4 * j);
476- float32x4_t wim = vld1q_f32 (w + 4 * (j + 1 ));
477- float32x4_t rn = vmulq_f32 (rs[0 ], wre);
478- float32x4_t in = vmulq_f32 (rs[0 ], wim);
479- rn = vfmsq_f32 (rn, is[0 ], wim);
480- in = vfmaq_f32 (in, is[0 ], wre);
489+ const fp_type* gate_row = w + 8 * k * gsize;
490+ auto out_row = ApplyGateRow<gsize>(
491+ state_re, state_im, load_gate_row (gate_row));
481492
482- j += 2 ;
483-
484- for (unsigned l = 1 ; l < gsize; ++l) {
485- wre = vld1q_f32 (w + 4 * j);
486- wim = vld1q_f32 (w + 4 * (j + 1 ));
487- rn = vfmaq_f32 (rn, rs[l], wre);
488- in = vfmaq_f32 (in, rs[l], wim);
489- rn = vfmsq_f32 (rn, is[l], wim);
490- in = vfmaq_f32 (in, is[l], wre);
491-
492- j += 2 ;
493- }
494-
495- vst1q_f32 (p0 + xss[k], rn);
496- vst1q_f32 (p0 + xss[k] + 4 , in);
493+ StoreStateAmplitudeRow (p0, xss, k, out_row);
497494 }
498495 };
499496
0 commit comments