From bd0c6ed15365fbf0acb398358609bb777bdec175 Mon Sep 17 00:00:00 2001 From: iliailia Date: Tue, 18 Jan 2022 23:43:13 +0100 Subject: [PATCH] Refactor tests and source --- include/keygen.h | 2 +- src/fft.cpp | 7 +- src/keygen.cpp | 22 +- src/lwehe.cpp | 82 +------- src/ntruhe.cpp | 131 ------------ test.cpp | 529 ++--------------------------------------------- 6 files changed, 28 insertions(+), 745 deletions(-) diff --git a/include/keygen.h b/include/keygen.h index 2fe0cbb..f355916 100644 --- a/include/keygen.h +++ b/include/keygen.h @@ -114,7 +114,7 @@ class KeyGen void get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot); /** - * Generate a bootstrapping key + * Generate a bootstrapping key (EXPERIMENTAL) * @param[out] bsk bootstrapping key. * @param[in] sk_base secret key of the base scheme. * @param[in] sk_boot secret key of the bootstrapping scheme. diff --git a/src/fft.cpp b/src/fft.cpp index 40d9531..fd7b1fa 100644 --- a/src/fft.cpp +++ b/src/fft.cpp @@ -20,7 +20,6 @@ FFT_engine::FFT_engine(const int dim): fft_dim(dim) for(int i = 0; i < dim; i++) { ModQPoly x_power(dim,0); - //x_power[0] = -1; x_power[i] += 1; FFTPoly x_power_fft(fft_dim2); to_fft(x_power_fft, x_power); @@ -30,7 +29,6 @@ FFT_engine::FFT_engine(const int dim): fft_dim(dim) to_fft(x_power_fft, x_power); neg_powers[i] = x_power_fft; } - //x_powers.insert({{-1,neg_powers}, {1,pos_powers}}); } void FFT_engine::to_fft(FFTPoly& out, const ModQPoly& in) const @@ -48,11 +46,10 @@ void FFT_engine::to_fft(FFTPoly& out, const ModQPoly& in) const } fftw_execute(plan_to_fft); int tmp = 1; - //for (int i = 0; i < fft_dim2; i++) for (auto it = out.begin(); it < out.end(); ++it) { fftw_complex& out_z = out_arr[tmp]; - complex& outi = *it; //out[i]; + complex& outi = *it; outi.real(out_z[0]); outi.imag(out_z[1]); tmp += 2; @@ -67,7 +64,6 @@ void FFT_engine::from_fft(vector& out, const FFTPoly& in) const int N = fft_dim; int Nd = double(N); - //for (int i = 0; i < fft_dim2; ++i) for (auto it = in.begin(); it < in.end(); ++it) { //std::cout << "i: " << i << ", number: " << in[i] << std::endl; @@ -163,7 +159,6 @@ void operator *=(FFTPoly& a, const FFTPoly& b) a[i]*=b[i]; } -// TODO: make a test FFTPoly operator *(const FFTPoly& a, const int b) { FFTPoly res(a.size()); diff --git a/src/keygen.cpp b/src/keygen.cpp index c5792f1..fa84341 100644 --- a/src/keygen.cpp +++ b/src/keygen.cpp @@ -44,18 +44,18 @@ void KeyGen::get_sk_base(SKey_base_LWE& sk_base) void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot) { - cout << "Started key-switching key generation" << endl; + //cout << "Started key-switching key generation" << endl; clock_t start = clock(); // reset key-switching key ksk.clear(); ksk = ModQMatrix(param.Nl, vector(param.n,0)); vector> ksk_long(param.Nl, vector(param.n,0L)); - cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + //cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; // noise matrix G as in the paper ModQMatrix G(param.Nl, vector(param.n,0L)); sampler.get_ternary_matrix(G); - cout << "G gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + //cout << "G gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; // matrix G + P * Phi(f) * E as in the paper int coef_w_pwr = sk_boot.sk[0]; @@ -73,7 +73,7 @@ void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_ coef_w_pwr *= Param::B_ksk; } } - cout << "G+P time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + //cout << "G+P time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; // parameters of the block optimization of matrix multiplication int block = 4; @@ -101,7 +101,7 @@ void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_ k_row[blocks+j] += (coef * f_row[blocks+j]); } } - cout << "After K time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + //cout << "After K time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; // reduce modulo q_base for (int i = 0; i < param.Nl; i++) @@ -111,7 +111,7 @@ void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_ void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot) { - cout << "Started key-switching key generation" << endl; + //cout << "Started key-switching key generation" << endl; clock_t start = clock(); // reset key-switching key ksk.A.clear(); @@ -122,11 +122,11 @@ void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_bo ksk.A.push_back(row); } ksk.b = vector(param.Nl, 0L); - cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + //cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; // noise matrix G as in the paper sampler.get_uniform_matrix(ksk.A); - cout << "A gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + //cout << "A gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; // matrix P * f_0 as in the paper vector Pf0(param.Nl, 0L); @@ -145,7 +145,7 @@ void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_bo coef_w_pwr *= Param::B_ksk; } } - cout << "Pf0 time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + //cout << "Pf0 time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; // A*s_base + e + Pf0 as in the paper normal_distribution gaussian_sampler(0.0, Param::e_st_dev); @@ -257,7 +257,7 @@ void KeyGen::get_bsk(BSKey_NTRU& bsk, const SKey_base_NTRU& sk_base, const SKey_ coef_counter += param.bsk_partition[iBase]; } - cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + cout << "Bootstrapping key generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl; } void KeyGen::get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot) @@ -344,7 +344,7 @@ void KeyGen::get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_bo coef_counter += param.bsk_partition[iBase]; } - cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + cout << "Bootstrapping key generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl; } void KeyGen::get_bsk2(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot) diff --git a/src/lwehe.cpp b/src/lwehe.cpp index 3a7804c..209eb24 100644 --- a/src/lwehe.cpp +++ b/src/lwehe.cpp @@ -118,17 +118,17 @@ void external_product(vector& res, const vector& poly, const vector bound) { poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit); abs_val >>= shift; - ++abs_val; //(abs_val - digit)/b + 1; + ++abs_val; } else { poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit; - abs_val >>= shift; //(abs_val - digit)/b; + abs_val >>= shift; } } fftN.to_fft(tmp_fft, poly_decomp); @@ -136,7 +136,6 @@ void external_product(vector& res, const vector& poly, const vector& vec) - { - for (size_t i = 0; i < vec.size(); i++) - { - printf("[%zu] %d ", i, vec[i]); - } - cout << endl; - } - - void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk) - { - FFTPoly sk_fft; - fftN.to_fft(sk_fft, sk.sk); - FFTPoly ct_fft; - fftN.to_fft(ct_fft, ct); - - FFTPoly output_fft; - output_fft = ct_fft * sk_fft; - vector output; - vector output_int; - fftN.from_fft(output, output_fft); - parLWE.mod_q_boot(output_int, output); - print(output_int); - } - - void decrypt_poly_base_and_print(const ModQPoly& ct, const SKey_boot& sk) - { - FFTPoly sk_fft; - fftN.to_fft(sk_fft, sk.sk); - FFTPoly ct_fft; - fftN.to_fft(ct_fft, ct); - - FFTPoly output_fft; - output_fft = ct_fft * sk_fft; - ModQPoly output; - fftN.from_fft(output, output_fft); - mod_q_base(output); - print(output); - } - -void decryptN2(const Ctxt_LWE& ct, const SKey_base_LWE& sk) -{ - int output = ct.b; - for (int i = 0; i < lwe_he::n; i++) - { - output += ct.a[i] * sk[i]; - } - output = output%parLWE.N2; - if (output > parLWE.N) - output -= parLWE.N2; - if (output <= -parLWE.N) - output += parLWE.N2; - cout << output << endl; -} -// end debugger functions -*/ - void SchemeLWE::bootstrap(Ctxt_LWE& ct) const { //clock_t start = clock(); @@ -304,10 +244,6 @@ void SchemeLWE::bootstrap(Ctxt_LWE& ct) const Bd = double(B); shift = parLWE.shift_bsk[iBase]; l = parLWE.l_bsk[iBase]; - //vector> w_powers(l); - //w_powers[0] = complex(1.0,0.0); - //for (int i = 1; i < l; i++) - // w_powers[i] = w_powers[i-1] * Bd; const vector& bk_coef_row = boot_key[iBase]; for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; ++iCoef) { @@ -359,13 +295,9 @@ void SchemeLWE::bootstrap(Ctxt_LWE& ct) const //mod q_boot of the accumulator mod_q_boot(acc); - - //decrypt_poly_boot_and_print(acc, sk_boot); //mod switch to q_base modulo_switch_to_base_lwe(acc); - - //decrypt_poly_boot_and_print(acc, sk_boot); //key switch //auto start = clock(); @@ -420,10 +352,6 @@ void SchemeLWE::bootstrap2(Ctxt_LWE& ct) const Bd = double(B); shift = parLWE.shift_bsk[iBase]; l = parLWE.l_bsk[iBase]; - //vector> w_powers(l); - //w_powers[0] = complex(1.0,0.0); - //for (int i = 1; i < l; i++) - // w_powers[i] = w_powers[i-1] * Bd; const vector& bk_coef_row = boot_key[iBase]; vector mux_fft(l,FFTPoly(N2p1,complex(0.0,0.0))); for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; iCoef+=2) @@ -508,13 +436,9 @@ void SchemeLWE::bootstrap2(Ctxt_LWE& ct) const //mod q_boot of the accumulator mod_q_boot(acc); - - //decrypt_poly_boot_and_print(acc, sk_boot); //mod switch to q_base modulo_switch_to_base_lwe(acc); - - //decrypt_poly_boot_and_print(acc, sk_boot); //key switch //auto start = clock(); diff --git a/src/ntruhe.cpp b/src/ntruhe.cpp index 0026969..012536a 100644 --- a/src/ntruhe.cpp +++ b/src/ntruhe.cpp @@ -81,52 +81,6 @@ int SchemeNTRU::decrypt(const Ctxt_NTRU& ct) const return output; } -/* -void SchemeNTRU::external_product(vector& res, const vector& poly, const vector& poly_vector, const int b, const int shift, const int l) const -{ - int N = Param::N; - int N2p1 = Param::N2p1; - - ModQPoly poly_sign(N,0L); - ModQPoly poly_abs(N,0L); - for (int i = 0; i < N; i++) - { - const int& polyi = poly[i]; - poly_abs[i] = abs(polyi); - poly_sign[i] = (polyi < 0)? -1 : 1; - } - FFTPoly res_fft(N2p1); - FFTPoly tmp_fft(N2p1); - int mask = b-1; - int bound = b >> 1; - int digit, sgn, abs_val; - vector poly_decomp(N); - for (int j = 0; j < l; j++) - { - for (int i = 0; i < N; i++) - { - abs_val = poly_abs[i]; - digit = abs_val & mask; //poly_abs[i] % b; - if (digit > bound) - { - poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit); - poly_abs[i] = (abs_val >> shift) + 1; //(abs_val - digit)/b + 1; - } - else - { - poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit; - poly_abs[i] = abs_val >> shift; //(abs_val - digit)/b; - } - } - fftN.to_fft(tmp_fft, poly_decomp); - tmp_fft *= poly_vector[j]; - res_fft += tmp_fft; - } - fftN.from_fft(res, res_fft); - //mod_q_boot(poly); -} -*/ - void SchemeNTRU::key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const { int N = Param::N; @@ -181,87 +135,6 @@ void SchemeNTRU::key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const parNTRU.mod_q_base(ct.data, ct_long); } -/* -// debugger functions -void print(const vector& vec) -{ - for (size_t i = 0; i < vec.size(); i++) - { - printf("[%zu] %d ", i, vec[i]); - } - cout << endl; -} -void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk, const Param& param) -{ - FFTPoly sk_fft(Param::N2p1); - fftN.to_fft(sk_fft, sk.sk); - FFTPoly ct_fft(Param::N2p1); - fftN.to_fft(ct_fft, ct); - - FFTPoly output_fft; - output_fft = ct_fft * sk_fft; - ModQPoly output; - vector output_long; - fftN.from_fft(output_long, output_fft); - mod_q_boot(output, output_long); - print(output); -} - -void decrypt_poly_base_and_print(const ModQPoly& ct, const Param& param, const SKey_boot& sk) -{ - FFTPoly sk_fft(Param::N2p1); - fftN.to_fft(sk_fft, sk.sk); - FFTPoly ct_fft(Param::N2p1); - fftN.to_fft(ct_fft, ct); - - FFTPoly output_fft; - output_fft = ct_fft * sk_fft; - ModQPoly output; - vector output_long; - fftN.from_fft(output_long, output_fft); - param.mod_q_base(output, output_long); - print(output); -} - -void decryptN2(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk) -{ - int N = Param::N; - int N2 = Param::N2; - int n = parNTRU.n; - - int output = 0; - for (int i = 0; i < n; i++) - { - output += ct.data[i] * sk.sk[i][0]; - } - output = output%N2; - if (output > N) - output -= N2; - if (output <= -N) - output += N2; - cout << output << endl; -} - -void decrypt_base(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk) -{ - int n = parNTRU.n; - int q_base = parNTRU.q_base; - int half_q_base= parNTRU.half_q_base; - - int output = 0; - for (int i = 0; i < n; i++) - { - output += ct.data[i] * sk.sk[i][0]; - } - output = output%q_base; - if (output > half_q_base) - output -= q_base; - if (output <= -half_q_base) - output += q_base; - cout << output << endl; -} -// end debugger functions -*/ void SchemeNTRU::mask_constant(Ctxt_NTRU& ct, int constant) { int n = parNTRU.n; @@ -313,10 +186,6 @@ void SchemeNTRU::bootstrap(Ctxt_NTRU& ct) const Bd = double(B); shift = parNTRU.shift_bsk[iBase]; l = parNTRU.l_bsk[iBase]; - //vector> w_power_fft(l); - //w_power_fft[0] = complex(1.0,0.0); - //for (int i = 1; i < l; i++) - // w_power_fft[i] = w_power_fft[i-1] * Bd; const vector>& bk_coef_row = boot_key[iBase]; vector mux_fft(l, FFTPoly(N2p1)); for (int iCoef = 0; iCoef < parNTRU.bsk_partition[iBase]; ++iCoef) diff --git a/test.cpp b/test.cpp index eee41a8..b09aff5 100644 --- a/test.cpp +++ b/test.cpp @@ -270,510 +270,6 @@ void test_sampler() cout << "SAMPLER IS OK" << endl; } -void test_ntru_key_gen() -{ - Param param(NTRU); - int n = param.n; - int Nl = param.Nl; - int half_q_base = param.half_q_base; - int q_base = param.q_base; - int l_ksk = param.l_ksk; - int N = Param::N; - int t = Param::t; - int B_ksk = Param::B_ksk; - int B_bsk_size = Param::B_bsk_size; - int N2p1 = Param::N2p1; - - SKey_base_NTRU sk_base; - KeyGen k(param); - k.get_sk_base(sk_base); - cout << "Secret key of the base scheme is generated" << endl; - assert(sk_base.sk.size() == n && sk_base.sk[0].size() == n - && sk_base.sk_inv.size() == n && sk_base.sk_inv[0].size() == n); - for (int i = 0; i < n; i++) - for (int j = 0; j < n; j++) - { - assert((sk_base.sk[i][j]==0) || (sk_base.sk[i][j]==-1) || (sk_base.sk[i][j]==1) ); - } - - SKey_boot sk_boot; - k.get_sk_boot(sk_boot); - cout << "Secret key of the bootstrapping scheme is generated" << endl; - assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N); - assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1))); - for (int i = 1; i < N; i++) - { - assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) ); - } - - KSKey_NTRU ksk; - k.get_ksk(ksk, sk_base, sk_boot); - cout << "Key-switching key is generated" << endl; - assert(ksk.size() == Nl && ksk[0].size() == n); - for (int i = 0; i < Nl; i++) - for (int j = 0; j < n; j++) - { - //cout << ksk[i][j] << endl; - assert(ksk[i][j] <= half_q_base && ksk[i][j] >= -half_q_base); - } - - vector q4_decomp; - decompose(q4_decomp, q_base/4, B_ksk, l_ksk); - vector ks_res(n,0L); - for (int i = 0; i < l_ksk; i++) - { - int tmp_int = q4_decomp[i]; - vector& ksk_row = ksk[i]; - for (int j = 0; j < n; j++) - { - ks_res[j] += ksk_row[j] * tmp_int; - } - } - param.mod_q_base(ks_res); - int ks_int = 0; - for (int i = 0; i < n; i++) - { - ks_int += ks_res[i] * sk_base.sk[i][0]; - } - ks_int = param.mod_q_base(ks_int); - ks_int = int(round(double(ks_int*4)/double(q_base))); - assert(ks_int == 1L); - - // bootstrapping key test - BSKey_NTRU bsk; - k.get_bsk(bsk, sk_base, sk_boot); - cout << "Bootstrapping key is generated" << endl; - - // check dimensions - assert(bsk.size() == B_bsk_size); - for (int i = 0; i < bsk.size(); i++) - { - assert(bsk[i].size() == param.bsk_partition[i]); - for (int j = 0; j < bsk[i].size(); j++) - { - assert(bsk[i][j].size() == 2); - assert(bsk[i][j][0].size() == param.l_bsk[i]); - assert(bsk[i][j][1].size() == param.l_bsk[i]); - } - } - // convert sk_boot to FFT - vector> sk_boot_fft(N2p1); - fftN.to_fft(sk_boot_fft, sk_boot.sk); - - int coef_counter = 0; - for (int iBase = 0; iBase < B_bsk_size; iBase++) - { - decompose(q4_decomp, q_boot/4, param.B_bsk[iBase], param.l_bsk[iBase]); - for (size_t iCoef = 0; iCoef < bsk[iBase].size(); iCoef++) - { - int sk_coef = 0; - int sk_base_coef_bits[2]; - for (int iBit = 0; iBit < 2; iBit++) - { - vector> tmp_fft(N2p1, complex(0.0,0.0)); - for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++) - { - tmp_fft = tmp_fft + bsk[iBase][iCoef][iBit][iPart] * q4_decomp[iPart]; - } - tmp_fft = tmp_fft * sk_boot_fft; - vector tmp_int; - vector tmp_long; - fftN.from_fft(tmp_long, tmp_fft); - mod_q_boot(tmp_int, tmp_long); - sk_base_coef_bits[iBit] = int(round(double(tmp_int[0]*4)/double(q_boot))); - } - if (sk_base_coef_bits[1] == 1) - sk_coef = -1; - else if (sk_base_coef_bits[0] == 1) - sk_coef = 1; - - assert(sk_coef == sk_base.sk[coef_counter + iCoef][0]); - } - coef_counter += param.bsk_partition[iBase]; - } - - cout << "KEYGEN IS OK" << endl; -} - -void test_lwe_key_gen() -{ - Param param(LWE); - int n = param.n; - int N = Param::N; - int t = Param::t; - - SKey_base_LWE sk_base; - KeyGen k(param); - k.get_sk_base(sk_base); - cout << "Secret key of the base scheme is generated" << endl; - assert(sk_base.size() == n); - for (int j = 0; j < n; j++) - { - assert((sk_base[j]==0) || (sk_base[j]==1)); - } - - SKey_boot sk_boot; - k.get_sk_boot(sk_boot); - cout << "Secret key of the bootstrapping scheme is generated" << endl; - assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N); - assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1))); - for (int i = 1; i < N; i++) - { - assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) ); - } - - cout << "KEYGEN IS OK" << endl; -} - -void test_fft() -{ - int N = Param::N; - int N2p1 = Param::N2p1; - - FFT_engine fft_engine(N); - { - vector in(N,0L); - vector> out(N2p1); - clock_t start = clock(); - fft_engine.to_fft(out, in); - cout << "Forward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl; - for (size_t i = 0; i < N/2; i++) - { - if (int(round(real(out[i])))!=0 || int(round(imag(out[i])))!=0) - { - cout << i << " " << out[i] << endl; - assert(false); - } - } - } - - { - vector out; - vector> in(N2p1, complex(0.0,0.0)); - clock_t start = clock(); - fft_engine.from_fft(out, in); - cout << "Backward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl; - for (size_t i = 0; i < N/2; i++) - { - assert(out[i] == 0L); - } - } - - { - vector in(N,0L); - in[0] = 1L; - vector> out(N2p1); - clock_t start = clock(); - fft_engine.to_fft(out, in); - cout << "Forward FFT (1,0,...0): " << float(clock()-start)/CLOCKS_PER_SEC << endl; - for (size_t i = 0; i < N/2; i++) - { - assert(int(round(real(out[i])))==1 && int(round(imag(out[i])))==0); - } - } - - { - vector out; - vector> in(N2p1, complex(1.0,0.0)); - clock_t start = clock(); - fft_engine.from_fft(out, in); - cout << "Backward FFT (1,1,...1): " << float(clock()-start)/CLOCKS_PER_SEC << endl; - assert(out[0] == 1L); - for (size_t i = 1; i < N; i++) - { - assert(out[i] == 0L); - } - } - - { - uniform_int_distribution sampler(INT_MIN, INT_MAX); - int coef = sampler(rand_engine); - vector out; - vector> in(N2p1, complex(double(coef),0.0)); - clock_t start = clock(); - fft_engine.from_fft(out, in); - cout << "Backward FFT (a,a,...a): " << float(clock()-start)/CLOCKS_PER_SEC << endl; - assert(out[0] == coef); - for (size_t i = 1; i < N; i++) - { - assert(out[i] == 0L); - } - } - - { - uniform_int_distribution sampler(INT_MIN, INT_MAX); - vector in; - for (int i = 0; i < N; i++) - in.push_back(sampler(rand_engine)); - vector> interm(N2p1); - vector out; - clock_t start = clock(); - fft_engine.to_fft(interm, in); - cout << "Forward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl; - start = clock(); - fft_engine.from_fft(out, interm); - cout << "Backward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl; - for (size_t i = 0; i < N; i++) - { - //cout << "i: " << i << "in[i]: " << in[i] << " out[i]: " << out[i] << endl; - assert(in[i] == out[i]); - } - } - - { - uniform_int_distribution sampler(-100, 100); - vector in1, in2, res; - for (int i = 0; i < N; i++) - { - in1.push_back(sampler(rand_engine)); - in2.push_back(sampler(rand_engine)); - res.push_back(in1[i]+in2[i]); - } - vector> interm1(N2p1); - vector> interm2(N2p1); - vector> intermres(N2p1); - vector out; - - fft_engine.to_fft(interm1, in1); - fft_engine.to_fft(interm2, in2); - - clock_t start = clock(); - intermres = interm1 + interm2; - cout << "FFT addition: " << float(clock()-start)/CLOCKS_PER_SEC << endl; - - fft_engine.from_fft(out, intermres); - - for (size_t i = 0; i < N; i++) - { - //cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl; - assert(res[i] == out[i]); - } - } - - { - uniform_int_distribution sampler(-100, 100); - vector in1, in2, res; - for (int i = 0; i < N; i++) - { - in1.push_back(sampler(rand_engine)); - in2.push_back(sampler(rand_engine)); - } - ZZX poly1, poly2, poly_res; - for (int i = 0; i < N; i++) - { - SetCoeff(poly1, i, in1[i]); - SetCoeff(poly2, i, in2[i]); - } - ZZX poly_mod; - SetCoeff(poly_mod, 0, 1); - SetCoeff(poly_mod, N, 1); - - clock_t start = clock(); - MulMod(poly_res, poly1, poly2, poly_mod); - cout << "NTL multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl; - - for (int i = 0; i < N; i++) - { - res.push_back(conv(poly_res[i])); - } - - vector> interm1(N2p1); - vector> interm2(N2p1); - vector> intermres(N2p1); - vector out; - - fft_engine.to_fft(interm1, in1); - fft_engine.to_fft(interm2, in2); - - start = clock(); - intermres = interm1 * interm2; - cout << "FFT multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl; - - fft_engine.from_fft(out, intermres); - - for (size_t i = 0; i < N; i++) - { - //cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl; - assert(res[i] == out[i]); - } - } - - cout << "FFT is OK" << endl; -} - -void test_ntruhe_encrypt() -{ - SchemeNTRU s; - - { - int input = 0; - Ctxt_NTRU ct; - s.encrypt(ct, input); - int output = s.decrypt(ct); - assert(output == input); - } - { - int input = 1; - Ctxt_NTRU ct; - s.encrypt(ct, input); - int output = s.decrypt(ct); - assert(output == input); - } - cout << "NTRU ENCRYPTION IS OK" << endl; -} - -void test_lwehe_encrypt() -{ - SchemeLWE s; - - { - int input = 0; - Ctxt_LWE ct; - s.encrypt(ct, input); - int output = s.decrypt(ct); - assert(output == input); - } - { - int input = 1; - Ctxt_LWE ct; - s.encrypt(ct, input); - int output = s.decrypt(ct); - assert(output == input); - } - cout << "LWE ENCRYPTION IS OK" << endl; -} - -/* -void test_mod_switch() -{ - SchemeNTRU s; - { - int input = 0; - Ctxt_NTRU ct; - s.encrypt(ct, input); - s.modulo_switch_to_base(ct.data); - int output = 0; - for (int i = 0; i < ntru_he::n; i++) - { - output += ct.data[i] * sk_base.sk[i][0]; - } - output = output%Param::N2; - if (output > Param::N) - output -= Param::N2; - else if (output <= -Param::N) - output += Param::N2; - output = int(round(double(output*t)/double(Param::N2))); - assert(output == input); - } - { - int input = 1; - ntru_he::Ctxt ct; - ntru_he::encrypt(ct, input, sk_base); - ntru_he::modulo_switch(ct, ntru_he::q_base, Param::N2); - int output = 0; - for (int i = 0; i < ntru_he::n; i++) - { - output += ct[i] * sk_base.sk[i][0]; - } - output = output%Param::N2; - if (output > Param::N) - output -= Param::N2; - else if (output <= -Param::N) - output += Param::N2; - output = int(round(double(output*t)/double(Param::N2))); - assert(output == input); - } - { - int input = 0; - ntru_he::Ctxt ct; - ntru_he::encrypt(ct, input, sk_base); - ntru_he::modulo_switch_to_boot(ct); - int output = 0; - for (int i = 0; i < ntru_he::n; i++) - { - output += ct[i] * sk_base.sk[i][0]; - } - output = output%Param::N2; - if (output > Param::N) - output -= Param::N2; - else if (output <= -Param::N) - output += Param::N2; - output = int(round(double(output*t)/double(Param::N2))); - assert(output == input); - } - { - int input = 1; - ntru_he::Ctxt ct; - ntru_he::encrypt(ct, input, sk_base); - ntru_he::modulo_switch_to_boot(ct); - int output = 0; - for (int i = 0; i < ntru_he::n; i++) - { - output += ct[i] * sk_base.sk[i][0]; - } - output = output%Param::N2; - if (output > Param::N) - output -= Param::N2; - else if (output <= -Param::N) - output += Param::N2; - output = int(round(double(output*t)/double(Param::N2))); - assert(output == input); - } - cout << "MODULO SWITCHING IS OK" << endl; -} -*/ - -void test_bootstrap() -{ - SchemeNTRU s; - { - int input = 2; - Ctxt_NTRU ct; - s.encrypt(ct, input); - - s.bootstrap(ct); - - int output = s.decrypt(ct); - cout << "Bootstrapping output: " << output << endl; - assert(output == 1L); - } - { - int input = 0; - Ctxt_NTRU ct; - s.encrypt(ct, input); - - s.bootstrap(ct); - - int output = s.decrypt(ct); - cout << "Bootstrapping output: " << output << endl; - assert(output == 0L); - } - - cout << "BOOTSTRAPPING IS OK" << endl; -} - -/* -void test_nand_aux() -{ - ntru_he::SKey_base sk_base; - ntru_he::get_sk_base(sk_base); - - ntru_he::Ctxt ct; - ntru_he::get_nand_aux(ct, sk_base); - int output = 0; - for (int i = 0; i < ntru_he::n; i++) - { - output += ct[i] * sk_base.sk[i][0]; - } - output = ntru_he::mod_q_base(output); - assert( - output == (ntru_he::nand_const-ntru_he::q_base) - || output == (ntru_he::nand_const-ntru_he::q_base+1) - || output == (ntru_he::nand_const-ntru_he::q_base-1) - ); - cout << "NAND ENCRYPTION IS OK" << endl; -}*/ - enum GateType {NAND, AND, OR}; void test_ntruhe_gate_helper(int in1, int in2, const SchemeNTRU& s, GateType g) @@ -940,21 +436,20 @@ void test_lwehe_or() int main() { - //test_params(); - //test_sampler(); - //test_ntru_key_gen(); - //test_lwe_key_gen(); - //test_fft(); - //test_ntruhe_encrypt(); - //test_lwehe_encrypt(); - //test_mod_switch(); - //test_bootstrap(); - //test_nand_aux(); - //test_ntruhe_nand(); + test_params(); + test_sampler(); + + cout << "NTRU tests" << endl; + test_ntruhe_nand(); + test_ntruhe_and(); + test_ntruhe_or(); + cout << "NTRU tests PASSED" << endl; + + cout << "LWE tests" << endl; test_lwehe_nand(); - //test_ntruhe_and(); test_lwehe_and(); - //test_ntruhe_or(); test_lwehe_or(); + cout << "LWE tests PASSED" << endl; + return 0; } \ No newline at end of file