mirror of
https://github.com/mii443/FINAL.git
synced 2025-08-22 23:15:28 +00:00
Refactor tests and source
This commit is contained in:
@ -114,7 +114,7 @@ class KeyGen
|
|||||||
void get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot);
|
void get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a bootstrapping key
|
* Generate a bootstrapping key (EXPERIMENTAL)
|
||||||
* @param[out] bsk bootstrapping key.
|
* @param[out] bsk bootstrapping key.
|
||||||
* @param[in] sk_base secret key of the base scheme.
|
* @param[in] sk_base secret key of the base scheme.
|
||||||
* @param[in] sk_boot secret key of the bootstrapping scheme.
|
* @param[in] sk_boot secret key of the bootstrapping scheme.
|
||||||
|
@ -20,7 +20,6 @@ FFT_engine::FFT_engine(const int dim): fft_dim(dim)
|
|||||||
for(int i = 0; i < dim; i++)
|
for(int i = 0; i < dim; i++)
|
||||||
{
|
{
|
||||||
ModQPoly x_power(dim,0);
|
ModQPoly x_power(dim,0);
|
||||||
//x_power[0] = -1;
|
|
||||||
x_power[i] += 1;
|
x_power[i] += 1;
|
||||||
FFTPoly x_power_fft(fft_dim2);
|
FFTPoly x_power_fft(fft_dim2);
|
||||||
to_fft(x_power_fft, x_power);
|
to_fft(x_power_fft, x_power);
|
||||||
@ -30,7 +29,6 @@ FFT_engine::FFT_engine(const int dim): fft_dim(dim)
|
|||||||
to_fft(x_power_fft, x_power);
|
to_fft(x_power_fft, x_power);
|
||||||
neg_powers[i] = x_power_fft;
|
neg_powers[i] = x_power_fft;
|
||||||
}
|
}
|
||||||
//x_powers.insert({{-1,neg_powers}, {1,pos_powers}});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void FFT_engine::to_fft(FFTPoly& out, const ModQPoly& in) const
|
void FFT_engine::to_fft(FFTPoly& out, const ModQPoly& in) const
|
||||||
@ -48,11 +46,10 @@ void FFT_engine::to_fft(FFTPoly& out, const ModQPoly& in) const
|
|||||||
}
|
}
|
||||||
fftw_execute(plan_to_fft);
|
fftw_execute(plan_to_fft);
|
||||||
int tmp = 1;
|
int tmp = 1;
|
||||||
//for (int i = 0; i < fft_dim2; i++)
|
|
||||||
for (auto it = out.begin(); it < out.end(); ++it)
|
for (auto it = out.begin(); it < out.end(); ++it)
|
||||||
{
|
{
|
||||||
fftw_complex& out_z = out_arr[tmp];
|
fftw_complex& out_z = out_arr[tmp];
|
||||||
complex<double>& outi = *it; //out[i];
|
complex<double>& outi = *it;
|
||||||
outi.real(out_z[0]);
|
outi.real(out_z[0]);
|
||||||
outi.imag(out_z[1]);
|
outi.imag(out_z[1]);
|
||||||
tmp += 2;
|
tmp += 2;
|
||||||
@ -67,7 +64,6 @@ void FFT_engine::from_fft(vector<long>& out, const FFTPoly& in) const
|
|||||||
int N = fft_dim;
|
int N = fft_dim;
|
||||||
int Nd = double(N);
|
int Nd = double(N);
|
||||||
|
|
||||||
//for (int i = 0; i < fft_dim2; ++i)
|
|
||||||
for (auto it = in.begin(); it < in.end(); ++it)
|
for (auto it = in.begin(); it < in.end(); ++it)
|
||||||
{
|
{
|
||||||
//std::cout << "i: " << i << ", number: " << in[i] << std::endl;
|
//std::cout << "i: " << i << ", number: " << in[i] << std::endl;
|
||||||
@ -163,7 +159,6 @@ void operator *=(FFTPoly& a, const FFTPoly& b)
|
|||||||
a[i]*=b[i];
|
a[i]*=b[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: make a test
|
|
||||||
FFTPoly operator *(const FFTPoly& a, const int b)
|
FFTPoly operator *(const FFTPoly& a, const int b)
|
||||||
{
|
{
|
||||||
FFTPoly res(a.size());
|
FFTPoly res(a.size());
|
||||||
|
@ -44,18 +44,18 @@ void KeyGen::get_sk_base(SKey_base_LWE& sk_base)
|
|||||||
|
|
||||||
void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot)
|
void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot)
|
||||||
{
|
{
|
||||||
cout << "Started key-switching key generation" << endl;
|
//cout << "Started key-switching key generation" << endl;
|
||||||
clock_t start = clock();
|
clock_t start = clock();
|
||||||
// reset key-switching key
|
// reset key-switching key
|
||||||
ksk.clear();
|
ksk.clear();
|
||||||
ksk = ModQMatrix(param.Nl, vector<int>(param.n,0));
|
ksk = ModQMatrix(param.Nl, vector<int>(param.n,0));
|
||||||
vector<vector<long>> ksk_long(param.Nl, vector<long>(param.n,0L));
|
vector<vector<long>> ksk_long(param.Nl, vector<long>(param.n,0L));
|
||||||
cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
//cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
// noise matrix G as in the paper
|
// noise matrix G as in the paper
|
||||||
ModQMatrix G(param.Nl, vector<int>(param.n,0L));
|
ModQMatrix G(param.Nl, vector<int>(param.n,0L));
|
||||||
sampler.get_ternary_matrix(G);
|
sampler.get_ternary_matrix(G);
|
||||||
cout << "G gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
//cout << "G gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
// matrix G + P * Phi(f) * E as in the paper
|
// matrix G + P * Phi(f) * E as in the paper
|
||||||
int coef_w_pwr = sk_boot.sk[0];
|
int coef_w_pwr = sk_boot.sk[0];
|
||||||
@ -73,7 +73,7 @@ void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_
|
|||||||
coef_w_pwr *= Param::B_ksk;
|
coef_w_pwr *= Param::B_ksk;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << "G+P time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
//cout << "G+P time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
// parameters of the block optimization of matrix multiplication
|
// parameters of the block optimization of matrix multiplication
|
||||||
int block = 4;
|
int block = 4;
|
||||||
@ -101,7 +101,7 @@ void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_
|
|||||||
k_row[blocks+j] += (coef * f_row[blocks+j]);
|
k_row[blocks+j] += (coef * f_row[blocks+j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << "After K time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
//cout << "After K time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
// reduce modulo q_base
|
// reduce modulo q_base
|
||||||
for (int i = 0; i < param.Nl; i++)
|
for (int i = 0; i < param.Nl; i++)
|
||||||
@ -111,7 +111,7 @@ void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_
|
|||||||
|
|
||||||
void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
||||||
{
|
{
|
||||||
cout << "Started key-switching key generation" << endl;
|
//cout << "Started key-switching key generation" << endl;
|
||||||
clock_t start = clock();
|
clock_t start = clock();
|
||||||
// reset key-switching key
|
// reset key-switching key
|
||||||
ksk.A.clear();
|
ksk.A.clear();
|
||||||
@ -122,11 +122,11 @@ void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_bo
|
|||||||
ksk.A.push_back(row);
|
ksk.A.push_back(row);
|
||||||
}
|
}
|
||||||
ksk.b = vector<int>(param.Nl, 0L);
|
ksk.b = vector<int>(param.Nl, 0L);
|
||||||
cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
//cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
// noise matrix G as in the paper
|
// noise matrix G as in the paper
|
||||||
sampler.get_uniform_matrix(ksk.A);
|
sampler.get_uniform_matrix(ksk.A);
|
||||||
cout << "A gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
//cout << "A gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
// matrix P * f_0 as in the paper
|
// matrix P * f_0 as in the paper
|
||||||
vector<int> Pf0(param.Nl, 0L);
|
vector<int> Pf0(param.Nl, 0L);
|
||||||
@ -145,7 +145,7 @@ void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_bo
|
|||||||
coef_w_pwr *= Param::B_ksk;
|
coef_w_pwr *= Param::B_ksk;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << "Pf0 time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
//cout << "Pf0 time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
// A*s_base + e + Pf0 as in the paper
|
// A*s_base + e + Pf0 as in the paper
|
||||||
normal_distribution<double> gaussian_sampler(0.0, Param::e_st_dev);
|
normal_distribution<double> gaussian_sampler(0.0, Param::e_st_dev);
|
||||||
@ -257,7 +257,7 @@ void KeyGen::get_bsk(BSKey_NTRU& bsk, const SKey_base_NTRU& sk_base, const SKey_
|
|||||||
coef_counter += param.bsk_partition[iBase];
|
coef_counter += param.bsk_partition[iBase];
|
||||||
}
|
}
|
||||||
|
|
||||||
cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
cout << "Bootstrapping key generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void KeyGen::get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
void KeyGen::get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
||||||
@ -344,7 +344,7 @@ void KeyGen::get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_bo
|
|||||||
coef_counter += param.bsk_partition[iBase];
|
coef_counter += param.bsk_partition[iBase];
|
||||||
}
|
}
|
||||||
|
|
||||||
cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
cout << "Bootstrapping key generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void KeyGen::get_bsk2(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
void KeyGen::get_bsk2(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
||||||
|
@ -118,17 +118,17 @@ void external_product(vector<long>& res, const vector<int>& poly, const vector<F
|
|||||||
for (int i = 0; i < N; ++i)
|
for (int i = 0; i < N; ++i)
|
||||||
{
|
{
|
||||||
int& abs_val = poly_abs[i];
|
int& abs_val = poly_abs[i];
|
||||||
digit = abs_val & mask; //poly_abs[i] % b;
|
digit = abs_val & mask;
|
||||||
if (digit > bound)
|
if (digit > bound)
|
||||||
{
|
{
|
||||||
poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit);
|
poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit);
|
||||||
abs_val >>= shift;
|
abs_val >>= shift;
|
||||||
++abs_val; //(abs_val - digit)/b + 1;
|
++abs_val;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit;
|
poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit;
|
||||||
abs_val >>= shift; //(abs_val - digit)/b;
|
abs_val >>= shift;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fftN.to_fft(tmp_fft, poly_decomp);
|
fftN.to_fft(tmp_fft, poly_decomp);
|
||||||
@ -136,7 +136,6 @@ void external_product(vector<long>& res, const vector<int>& poly, const vector<F
|
|||||||
res_fft += tmp_fft;
|
res_fft += tmp_fft;
|
||||||
}
|
}
|
||||||
fftN.from_fft(res, res_fft);
|
fftN.from_fft(res, res_fft);
|
||||||
//mod_q_boot(poly);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SchemeLWE::key_switch(Ctxt_LWE& ct, const ModQPoly& poly) const
|
void SchemeLWE::key_switch(Ctxt_LWE& ct, const ModQPoly& poly) const
|
||||||
@ -200,65 +199,6 @@ void SchemeLWE::key_switch(Ctxt_LWE& ct, const ModQPoly& poly) const
|
|||||||
ct.b = parLWE.mod_q_base(b);
|
ct.b = parLWE.mod_q_base(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
// debugger functions
|
|
||||||
void print(const vector<int>& vec)
|
|
||||||
{
|
|
||||||
for (size_t i = 0; i < vec.size(); i++)
|
|
||||||
{
|
|
||||||
printf("[%zu] %d ", i, vec[i]);
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk)
|
|
||||||
{
|
|
||||||
FFTPoly sk_fft;
|
|
||||||
fftN.to_fft(sk_fft, sk.sk);
|
|
||||||
FFTPoly ct_fft;
|
|
||||||
fftN.to_fft(ct_fft, ct);
|
|
||||||
|
|
||||||
FFTPoly output_fft;
|
|
||||||
output_fft = ct_fft * sk_fft;
|
|
||||||
vector<long> output;
|
|
||||||
vector<int> output_int;
|
|
||||||
fftN.from_fft(output, output_fft);
|
|
||||||
parLWE.mod_q_boot(output_int, output);
|
|
||||||
print(output_int);
|
|
||||||
}
|
|
||||||
|
|
||||||
void decrypt_poly_base_and_print(const ModQPoly& ct, const SKey_boot& sk)
|
|
||||||
{
|
|
||||||
FFTPoly sk_fft;
|
|
||||||
fftN.to_fft(sk_fft, sk.sk);
|
|
||||||
FFTPoly ct_fft;
|
|
||||||
fftN.to_fft(ct_fft, ct);
|
|
||||||
|
|
||||||
FFTPoly output_fft;
|
|
||||||
output_fft = ct_fft * sk_fft;
|
|
||||||
ModQPoly output;
|
|
||||||
fftN.from_fft(output, output_fft);
|
|
||||||
mod_q_base(output);
|
|
||||||
print(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
void decryptN2(const Ctxt_LWE& ct, const SKey_base_LWE& sk)
|
|
||||||
{
|
|
||||||
int output = ct.b;
|
|
||||||
for (int i = 0; i < lwe_he::n; i++)
|
|
||||||
{
|
|
||||||
output += ct.a[i] * sk[i];
|
|
||||||
}
|
|
||||||
output = output%parLWE.N2;
|
|
||||||
if (output > parLWE.N)
|
|
||||||
output -= parLWE.N2;
|
|
||||||
if (output <= -parLWE.N)
|
|
||||||
output += parLWE.N2;
|
|
||||||
cout << output << endl;
|
|
||||||
}
|
|
||||||
// end debugger functions
|
|
||||||
*/
|
|
||||||
|
|
||||||
void SchemeLWE::bootstrap(Ctxt_LWE& ct) const
|
void SchemeLWE::bootstrap(Ctxt_LWE& ct) const
|
||||||
{
|
{
|
||||||
//clock_t start = clock();
|
//clock_t start = clock();
|
||||||
@ -304,10 +244,6 @@ void SchemeLWE::bootstrap(Ctxt_LWE& ct) const
|
|||||||
Bd = double(B);
|
Bd = double(B);
|
||||||
shift = parLWE.shift_bsk[iBase];
|
shift = parLWE.shift_bsk[iBase];
|
||||||
l = parLWE.l_bsk[iBase];
|
l = parLWE.l_bsk[iBase];
|
||||||
//vector<complex<double>> w_powers(l);
|
|
||||||
//w_powers[0] = complex<double>(1.0,0.0);
|
|
||||||
//for (int i = 1; i < l; i++)
|
|
||||||
// w_powers[i] = w_powers[i-1] * Bd;
|
|
||||||
const vector<NGSFFTctxt>& bk_coef_row = boot_key[iBase];
|
const vector<NGSFFTctxt>& bk_coef_row = boot_key[iBase];
|
||||||
for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; ++iCoef)
|
for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; ++iCoef)
|
||||||
{
|
{
|
||||||
@ -359,13 +295,9 @@ void SchemeLWE::bootstrap(Ctxt_LWE& ct) const
|
|||||||
|
|
||||||
//mod q_boot of the accumulator
|
//mod q_boot of the accumulator
|
||||||
mod_q_boot(acc);
|
mod_q_boot(acc);
|
||||||
|
|
||||||
//decrypt_poly_boot_and_print(acc, sk_boot);
|
|
||||||
|
|
||||||
//mod switch to q_base
|
//mod switch to q_base
|
||||||
modulo_switch_to_base_lwe(acc);
|
modulo_switch_to_base_lwe(acc);
|
||||||
|
|
||||||
//decrypt_poly_boot_and_print(acc, sk_boot);
|
|
||||||
|
|
||||||
//key switch
|
//key switch
|
||||||
//auto start = clock();
|
//auto start = clock();
|
||||||
@ -420,10 +352,6 @@ void SchemeLWE::bootstrap2(Ctxt_LWE& ct) const
|
|||||||
Bd = double(B);
|
Bd = double(B);
|
||||||
shift = parLWE.shift_bsk[iBase];
|
shift = parLWE.shift_bsk[iBase];
|
||||||
l = parLWE.l_bsk[iBase];
|
l = parLWE.l_bsk[iBase];
|
||||||
//vector<complex<double>> w_powers(l);
|
|
||||||
//w_powers[0] = complex<double>(1.0,0.0);
|
|
||||||
//for (int i = 1; i < l; i++)
|
|
||||||
// w_powers[i] = w_powers[i-1] * Bd;
|
|
||||||
const vector<NGSFFTctxt>& bk_coef_row = boot_key[iBase];
|
const vector<NGSFFTctxt>& bk_coef_row = boot_key[iBase];
|
||||||
vector<FFTPoly> mux_fft(l,FFTPoly(N2p1,complex<double>(0.0,0.0)));
|
vector<FFTPoly> mux_fft(l,FFTPoly(N2p1,complex<double>(0.0,0.0)));
|
||||||
for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; iCoef+=2)
|
for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; iCoef+=2)
|
||||||
@ -508,13 +436,9 @@ void SchemeLWE::bootstrap2(Ctxt_LWE& ct) const
|
|||||||
|
|
||||||
//mod q_boot of the accumulator
|
//mod q_boot of the accumulator
|
||||||
mod_q_boot(acc);
|
mod_q_boot(acc);
|
||||||
|
|
||||||
//decrypt_poly_boot_and_print(acc, sk_boot);
|
|
||||||
|
|
||||||
//mod switch to q_base
|
//mod switch to q_base
|
||||||
modulo_switch_to_base_lwe(acc);
|
modulo_switch_to_base_lwe(acc);
|
||||||
|
|
||||||
//decrypt_poly_boot_and_print(acc, sk_boot);
|
|
||||||
|
|
||||||
//key switch
|
//key switch
|
||||||
//auto start = clock();
|
//auto start = clock();
|
||||||
|
131
src/ntruhe.cpp
131
src/ntruhe.cpp
@ -81,52 +81,6 @@ int SchemeNTRU::decrypt(const Ctxt_NTRU& ct) const
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
void SchemeNTRU::external_product(vector<long>& res, const vector<int>& poly, const vector<FFTPoly>& poly_vector, const int b, const int shift, const int l) const
|
|
||||||
{
|
|
||||||
int N = Param::N;
|
|
||||||
int N2p1 = Param::N2p1;
|
|
||||||
|
|
||||||
ModQPoly poly_sign(N,0L);
|
|
||||||
ModQPoly poly_abs(N,0L);
|
|
||||||
for (int i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
const int& polyi = poly[i];
|
|
||||||
poly_abs[i] = abs(polyi);
|
|
||||||
poly_sign[i] = (polyi < 0)? -1 : 1;
|
|
||||||
}
|
|
||||||
FFTPoly res_fft(N2p1);
|
|
||||||
FFTPoly tmp_fft(N2p1);
|
|
||||||
int mask = b-1;
|
|
||||||
int bound = b >> 1;
|
|
||||||
int digit, sgn, abs_val;
|
|
||||||
vector<int> poly_decomp(N);
|
|
||||||
for (int j = 0; j < l; j++)
|
|
||||||
{
|
|
||||||
for (int i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
abs_val = poly_abs[i];
|
|
||||||
digit = abs_val & mask; //poly_abs[i] % b;
|
|
||||||
if (digit > bound)
|
|
||||||
{
|
|
||||||
poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit);
|
|
||||||
poly_abs[i] = (abs_val >> shift) + 1; //(abs_val - digit)/b + 1;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit;
|
|
||||||
poly_abs[i] = abs_val >> shift; //(abs_val - digit)/b;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fftN.to_fft(tmp_fft, poly_decomp);
|
|
||||||
tmp_fft *= poly_vector[j];
|
|
||||||
res_fft += tmp_fft;
|
|
||||||
}
|
|
||||||
fftN.from_fft(res, res_fft);
|
|
||||||
//mod_q_boot(poly);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
void SchemeNTRU::key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const
|
void SchemeNTRU::key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const
|
||||||
{
|
{
|
||||||
int N = Param::N;
|
int N = Param::N;
|
||||||
@ -181,87 +135,6 @@ void SchemeNTRU::key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const
|
|||||||
parNTRU.mod_q_base(ct.data, ct_long);
|
parNTRU.mod_q_base(ct.data, ct_long);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
// debugger functions
|
|
||||||
void print(const vector<int>& vec)
|
|
||||||
{
|
|
||||||
for (size_t i = 0; i < vec.size(); i++)
|
|
||||||
{
|
|
||||||
printf("[%zu] %d ", i, vec[i]);
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk, const Param& param)
|
|
||||||
{
|
|
||||||
FFTPoly sk_fft(Param::N2p1);
|
|
||||||
fftN.to_fft(sk_fft, sk.sk);
|
|
||||||
FFTPoly ct_fft(Param::N2p1);
|
|
||||||
fftN.to_fft(ct_fft, ct);
|
|
||||||
|
|
||||||
FFTPoly output_fft;
|
|
||||||
output_fft = ct_fft * sk_fft;
|
|
||||||
ModQPoly output;
|
|
||||||
vector<long> output_long;
|
|
||||||
fftN.from_fft(output_long, output_fft);
|
|
||||||
mod_q_boot(output, output_long);
|
|
||||||
print(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
void decrypt_poly_base_and_print(const ModQPoly& ct, const Param& param, const SKey_boot& sk)
|
|
||||||
{
|
|
||||||
FFTPoly sk_fft(Param::N2p1);
|
|
||||||
fftN.to_fft(sk_fft, sk.sk);
|
|
||||||
FFTPoly ct_fft(Param::N2p1);
|
|
||||||
fftN.to_fft(ct_fft, ct);
|
|
||||||
|
|
||||||
FFTPoly output_fft;
|
|
||||||
output_fft = ct_fft * sk_fft;
|
|
||||||
ModQPoly output;
|
|
||||||
vector<long> output_long;
|
|
||||||
fftN.from_fft(output_long, output_fft);
|
|
||||||
param.mod_q_base(output, output_long);
|
|
||||||
print(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
void decryptN2(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk)
|
|
||||||
{
|
|
||||||
int N = Param::N;
|
|
||||||
int N2 = Param::N2;
|
|
||||||
int n = parNTRU.n;
|
|
||||||
|
|
||||||
int output = 0;
|
|
||||||
for (int i = 0; i < n; i++)
|
|
||||||
{
|
|
||||||
output += ct.data[i] * sk.sk[i][0];
|
|
||||||
}
|
|
||||||
output = output%N2;
|
|
||||||
if (output > N)
|
|
||||||
output -= N2;
|
|
||||||
if (output <= -N)
|
|
||||||
output += N2;
|
|
||||||
cout << output << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
void decrypt_base(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk)
|
|
||||||
{
|
|
||||||
int n = parNTRU.n;
|
|
||||||
int q_base = parNTRU.q_base;
|
|
||||||
int half_q_base= parNTRU.half_q_base;
|
|
||||||
|
|
||||||
int output = 0;
|
|
||||||
for (int i = 0; i < n; i++)
|
|
||||||
{
|
|
||||||
output += ct.data[i] * sk.sk[i][0];
|
|
||||||
}
|
|
||||||
output = output%q_base;
|
|
||||||
if (output > half_q_base)
|
|
||||||
output -= q_base;
|
|
||||||
if (output <= -half_q_base)
|
|
||||||
output += q_base;
|
|
||||||
cout << output << endl;
|
|
||||||
}
|
|
||||||
// end debugger functions
|
|
||||||
*/
|
|
||||||
void SchemeNTRU::mask_constant(Ctxt_NTRU& ct, int constant)
|
void SchemeNTRU::mask_constant(Ctxt_NTRU& ct, int constant)
|
||||||
{
|
{
|
||||||
int n = parNTRU.n;
|
int n = parNTRU.n;
|
||||||
@ -313,10 +186,6 @@ void SchemeNTRU::bootstrap(Ctxt_NTRU& ct) const
|
|||||||
Bd = double(B);
|
Bd = double(B);
|
||||||
shift = parNTRU.shift_bsk[iBase];
|
shift = parNTRU.shift_bsk[iBase];
|
||||||
l = parNTRU.l_bsk[iBase];
|
l = parNTRU.l_bsk[iBase];
|
||||||
//vector<complex<double>> w_power_fft(l);
|
|
||||||
//w_power_fft[0] = complex<double>(1.0,0.0);
|
|
||||||
//for (int i = 1; i < l; i++)
|
|
||||||
// w_power_fft[i] = w_power_fft[i-1] * Bd;
|
|
||||||
const vector<vector<NGSFFTctxt>>& bk_coef_row = boot_key[iBase];
|
const vector<vector<NGSFFTctxt>>& bk_coef_row = boot_key[iBase];
|
||||||
vector<FFTPoly> mux_fft(l, FFTPoly(N2p1));
|
vector<FFTPoly> mux_fft(l, FFTPoly(N2p1));
|
||||||
for (int iCoef = 0; iCoef < parNTRU.bsk_partition[iBase]; ++iCoef)
|
for (int iCoef = 0; iCoef < parNTRU.bsk_partition[iBase]; ++iCoef)
|
||||||
|
529
test.cpp
529
test.cpp
@ -270,510 +270,6 @@ void test_sampler()
|
|||||||
cout << "SAMPLER IS OK" << endl;
|
cout << "SAMPLER IS OK" << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_ntru_key_gen()
|
|
||||||
{
|
|
||||||
Param param(NTRU);
|
|
||||||
int n = param.n;
|
|
||||||
int Nl = param.Nl;
|
|
||||||
int half_q_base = param.half_q_base;
|
|
||||||
int q_base = param.q_base;
|
|
||||||
int l_ksk = param.l_ksk;
|
|
||||||
int N = Param::N;
|
|
||||||
int t = Param::t;
|
|
||||||
int B_ksk = Param::B_ksk;
|
|
||||||
int B_bsk_size = Param::B_bsk_size;
|
|
||||||
int N2p1 = Param::N2p1;
|
|
||||||
|
|
||||||
SKey_base_NTRU sk_base;
|
|
||||||
KeyGen k(param);
|
|
||||||
k.get_sk_base(sk_base);
|
|
||||||
cout << "Secret key of the base scheme is generated" << endl;
|
|
||||||
assert(sk_base.sk.size() == n && sk_base.sk[0].size() == n
|
|
||||||
&& sk_base.sk_inv.size() == n && sk_base.sk_inv[0].size() == n);
|
|
||||||
for (int i = 0; i < n; i++)
|
|
||||||
for (int j = 0; j < n; j++)
|
|
||||||
{
|
|
||||||
assert((sk_base.sk[i][j]==0) || (sk_base.sk[i][j]==-1) || (sk_base.sk[i][j]==1) );
|
|
||||||
}
|
|
||||||
|
|
||||||
SKey_boot sk_boot;
|
|
||||||
k.get_sk_boot(sk_boot);
|
|
||||||
cout << "Secret key of the bootstrapping scheme is generated" << endl;
|
|
||||||
assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N);
|
|
||||||
assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1)));
|
|
||||||
for (int i = 1; i < N; i++)
|
|
||||||
{
|
|
||||||
assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) );
|
|
||||||
}
|
|
||||||
|
|
||||||
KSKey_NTRU ksk;
|
|
||||||
k.get_ksk(ksk, sk_base, sk_boot);
|
|
||||||
cout << "Key-switching key is generated" << endl;
|
|
||||||
assert(ksk.size() == Nl && ksk[0].size() == n);
|
|
||||||
for (int i = 0; i < Nl; i++)
|
|
||||||
for (int j = 0; j < n; j++)
|
|
||||||
{
|
|
||||||
//cout << ksk[i][j] << endl;
|
|
||||||
assert(ksk[i][j] <= half_q_base && ksk[i][j] >= -half_q_base);
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<int> q4_decomp;
|
|
||||||
decompose(q4_decomp, q_base/4, B_ksk, l_ksk);
|
|
||||||
vector<int> ks_res(n,0L);
|
|
||||||
for (int i = 0; i < l_ksk; i++)
|
|
||||||
{
|
|
||||||
int tmp_int = q4_decomp[i];
|
|
||||||
vector<int>& ksk_row = ksk[i];
|
|
||||||
for (int j = 0; j < n; j++)
|
|
||||||
{
|
|
||||||
ks_res[j] += ksk_row[j] * tmp_int;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
param.mod_q_base(ks_res);
|
|
||||||
int ks_int = 0;
|
|
||||||
for (int i = 0; i < n; i++)
|
|
||||||
{
|
|
||||||
ks_int += ks_res[i] * sk_base.sk[i][0];
|
|
||||||
}
|
|
||||||
ks_int = param.mod_q_base(ks_int);
|
|
||||||
ks_int = int(round(double(ks_int*4)/double(q_base)));
|
|
||||||
assert(ks_int == 1L);
|
|
||||||
|
|
||||||
// bootstrapping key test
|
|
||||||
BSKey_NTRU bsk;
|
|
||||||
k.get_bsk(bsk, sk_base, sk_boot);
|
|
||||||
cout << "Bootstrapping key is generated" << endl;
|
|
||||||
|
|
||||||
// check dimensions
|
|
||||||
assert(bsk.size() == B_bsk_size);
|
|
||||||
for (int i = 0; i < bsk.size(); i++)
|
|
||||||
{
|
|
||||||
assert(bsk[i].size() == param.bsk_partition[i]);
|
|
||||||
for (int j = 0; j < bsk[i].size(); j++)
|
|
||||||
{
|
|
||||||
assert(bsk[i][j].size() == 2);
|
|
||||||
assert(bsk[i][j][0].size() == param.l_bsk[i]);
|
|
||||||
assert(bsk[i][j][1].size() == param.l_bsk[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// convert sk_boot to FFT
|
|
||||||
vector<complex<double>> sk_boot_fft(N2p1);
|
|
||||||
fftN.to_fft(sk_boot_fft, sk_boot.sk);
|
|
||||||
|
|
||||||
int coef_counter = 0;
|
|
||||||
for (int iBase = 0; iBase < B_bsk_size; iBase++)
|
|
||||||
{
|
|
||||||
decompose(q4_decomp, q_boot/4, param.B_bsk[iBase], param.l_bsk[iBase]);
|
|
||||||
for (size_t iCoef = 0; iCoef < bsk[iBase].size(); iCoef++)
|
|
||||||
{
|
|
||||||
int sk_coef = 0;
|
|
||||||
int sk_base_coef_bits[2];
|
|
||||||
for (int iBit = 0; iBit < 2; iBit++)
|
|
||||||
{
|
|
||||||
vector<complex<double>> tmp_fft(N2p1, complex<double>(0.0,0.0));
|
|
||||||
for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++)
|
|
||||||
{
|
|
||||||
tmp_fft = tmp_fft + bsk[iBase][iCoef][iBit][iPart] * q4_decomp[iPart];
|
|
||||||
}
|
|
||||||
tmp_fft = tmp_fft * sk_boot_fft;
|
|
||||||
vector<int> tmp_int;
|
|
||||||
vector<long> tmp_long;
|
|
||||||
fftN.from_fft(tmp_long, tmp_fft);
|
|
||||||
mod_q_boot(tmp_int, tmp_long);
|
|
||||||
sk_base_coef_bits[iBit] = int(round(double(tmp_int[0]*4)/double(q_boot)));
|
|
||||||
}
|
|
||||||
if (sk_base_coef_bits[1] == 1)
|
|
||||||
sk_coef = -1;
|
|
||||||
else if (sk_base_coef_bits[0] == 1)
|
|
||||||
sk_coef = 1;
|
|
||||||
|
|
||||||
assert(sk_coef == sk_base.sk[coef_counter + iCoef][0]);
|
|
||||||
}
|
|
||||||
coef_counter += param.bsk_partition[iBase];
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "KEYGEN IS OK" << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_lwe_key_gen()
|
|
||||||
{
|
|
||||||
Param param(LWE);
|
|
||||||
int n = param.n;
|
|
||||||
int N = Param::N;
|
|
||||||
int t = Param::t;
|
|
||||||
|
|
||||||
SKey_base_LWE sk_base;
|
|
||||||
KeyGen k(param);
|
|
||||||
k.get_sk_base(sk_base);
|
|
||||||
cout << "Secret key of the base scheme is generated" << endl;
|
|
||||||
assert(sk_base.size() == n);
|
|
||||||
for (int j = 0; j < n; j++)
|
|
||||||
{
|
|
||||||
assert((sk_base[j]==0) || (sk_base[j]==1));
|
|
||||||
}
|
|
||||||
|
|
||||||
SKey_boot sk_boot;
|
|
||||||
k.get_sk_boot(sk_boot);
|
|
||||||
cout << "Secret key of the bootstrapping scheme is generated" << endl;
|
|
||||||
assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N);
|
|
||||||
assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1)));
|
|
||||||
for (int i = 1; i < N; i++)
|
|
||||||
{
|
|
||||||
assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) );
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "KEYGEN IS OK" << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_fft()
|
|
||||||
{
|
|
||||||
int N = Param::N;
|
|
||||||
int N2p1 = Param::N2p1;
|
|
||||||
|
|
||||||
FFT_engine fft_engine(N);
|
|
||||||
{
|
|
||||||
vector<int> in(N,0L);
|
|
||||||
vector<complex<double>> out(N2p1);
|
|
||||||
clock_t start = clock();
|
|
||||||
fft_engine.to_fft(out, in);
|
|
||||||
cout << "Forward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
for (size_t i = 0; i < N/2; i++)
|
|
||||||
{
|
|
||||||
if (int(round(real(out[i])))!=0 || int(round(imag(out[i])))!=0)
|
|
||||||
{
|
|
||||||
cout << i << " " << out[i] << endl;
|
|
||||||
assert(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
vector<long> out;
|
|
||||||
vector<complex<double>> in(N2p1, complex<double>(0.0,0.0));
|
|
||||||
clock_t start = clock();
|
|
||||||
fft_engine.from_fft(out, in);
|
|
||||||
cout << "Backward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
for (size_t i = 0; i < N/2; i++)
|
|
||||||
{
|
|
||||||
assert(out[i] == 0L);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
vector<int> in(N,0L);
|
|
||||||
in[0] = 1L;
|
|
||||||
vector<complex<double>> out(N2p1);
|
|
||||||
clock_t start = clock();
|
|
||||||
fft_engine.to_fft(out, in);
|
|
||||||
cout << "Forward FFT (1,0,...0): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
for (size_t i = 0; i < N/2; i++)
|
|
||||||
{
|
|
||||||
assert(int(round(real(out[i])))==1 && int(round(imag(out[i])))==0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
vector<long> out;
|
|
||||||
vector<complex<double>> in(N2p1, complex<double>(1.0,0.0));
|
|
||||||
clock_t start = clock();
|
|
||||||
fft_engine.from_fft(out, in);
|
|
||||||
cout << "Backward FFT (1,1,...1): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
assert(out[0] == 1L);
|
|
||||||
for (size_t i = 1; i < N; i++)
|
|
||||||
{
|
|
||||||
assert(out[i] == 0L);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
uniform_int_distribution<int> sampler(INT_MIN, INT_MAX);
|
|
||||||
int coef = sampler(rand_engine);
|
|
||||||
vector<long> out;
|
|
||||||
vector<complex<double>> in(N2p1, complex<double>(double(coef),0.0));
|
|
||||||
clock_t start = clock();
|
|
||||||
fft_engine.from_fft(out, in);
|
|
||||||
cout << "Backward FFT (a,a,...a): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
assert(out[0] == coef);
|
|
||||||
for (size_t i = 1; i < N; i++)
|
|
||||||
{
|
|
||||||
assert(out[i] == 0L);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
uniform_int_distribution<int> sampler(INT_MIN, INT_MAX);
|
|
||||||
vector<int> in;
|
|
||||||
for (int i = 0; i < N; i++)
|
|
||||||
in.push_back(sampler(rand_engine));
|
|
||||||
vector<complex<double>> interm(N2p1);
|
|
||||||
vector<long> out;
|
|
||||||
clock_t start = clock();
|
|
||||||
fft_engine.to_fft(interm, in);
|
|
||||||
cout << "Forward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
start = clock();
|
|
||||||
fft_engine.from_fft(out, interm);
|
|
||||||
cout << "Backward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
for (size_t i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
//cout << "i: " << i << "in[i]: " << in[i] << " out[i]: " << out[i] << endl;
|
|
||||||
assert(in[i] == out[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
uniform_int_distribution<int> sampler(-100, 100);
|
|
||||||
vector<int> in1, in2, res;
|
|
||||||
for (int i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
in1.push_back(sampler(rand_engine));
|
|
||||||
in2.push_back(sampler(rand_engine));
|
|
||||||
res.push_back(in1[i]+in2[i]);
|
|
||||||
}
|
|
||||||
vector<complex<double>> interm1(N2p1);
|
|
||||||
vector<complex<double>> interm2(N2p1);
|
|
||||||
vector<complex<double>> intermres(N2p1);
|
|
||||||
vector<long> out;
|
|
||||||
|
|
||||||
fft_engine.to_fft(interm1, in1);
|
|
||||||
fft_engine.to_fft(interm2, in2);
|
|
||||||
|
|
||||||
clock_t start = clock();
|
|
||||||
intermres = interm1 + interm2;
|
|
||||||
cout << "FFT addition: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
|
|
||||||
fft_engine.from_fft(out, intermres);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
//cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl;
|
|
||||||
assert(res[i] == out[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
uniform_int_distribution<int> sampler(-100, 100);
|
|
||||||
vector<int> in1, in2, res;
|
|
||||||
for (int i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
in1.push_back(sampler(rand_engine));
|
|
||||||
in2.push_back(sampler(rand_engine));
|
|
||||||
}
|
|
||||||
ZZX poly1, poly2, poly_res;
|
|
||||||
for (int i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
SetCoeff(poly1, i, in1[i]);
|
|
||||||
SetCoeff(poly2, i, in2[i]);
|
|
||||||
}
|
|
||||||
ZZX poly_mod;
|
|
||||||
SetCoeff(poly_mod, 0, 1);
|
|
||||||
SetCoeff(poly_mod, N, 1);
|
|
||||||
|
|
||||||
clock_t start = clock();
|
|
||||||
MulMod(poly_res, poly1, poly2, poly_mod);
|
|
||||||
cout << "NTL multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
|
|
||||||
for (int i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
res.push_back(conv<long>(poly_res[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<complex<double>> interm1(N2p1);
|
|
||||||
vector<complex<double>> interm2(N2p1);
|
|
||||||
vector<complex<double>> intermres(N2p1);
|
|
||||||
vector<long> out;
|
|
||||||
|
|
||||||
fft_engine.to_fft(interm1, in1);
|
|
||||||
fft_engine.to_fft(interm2, in2);
|
|
||||||
|
|
||||||
start = clock();
|
|
||||||
intermres = interm1 * interm2;
|
|
||||||
cout << "FFT multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
|
||||||
|
|
||||||
fft_engine.from_fft(out, intermres);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < N; i++)
|
|
||||||
{
|
|
||||||
//cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl;
|
|
||||||
assert(res[i] == out[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "FFT is OK" << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_ntruhe_encrypt()
|
|
||||||
{
|
|
||||||
SchemeNTRU s;
|
|
||||||
|
|
||||||
{
|
|
||||||
int input = 0;
|
|
||||||
Ctxt_NTRU ct;
|
|
||||||
s.encrypt(ct, input);
|
|
||||||
int output = s.decrypt(ct);
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
int input = 1;
|
|
||||||
Ctxt_NTRU ct;
|
|
||||||
s.encrypt(ct, input);
|
|
||||||
int output = s.decrypt(ct);
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
cout << "NTRU ENCRYPTION IS OK" << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
void test_lwehe_encrypt()
|
|
||||||
{
|
|
||||||
SchemeLWE s;
|
|
||||||
|
|
||||||
{
|
|
||||||
int input = 0;
|
|
||||||
Ctxt_LWE ct;
|
|
||||||
s.encrypt(ct, input);
|
|
||||||
int output = s.decrypt(ct);
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
int input = 1;
|
|
||||||
Ctxt_LWE ct;
|
|
||||||
s.encrypt(ct, input);
|
|
||||||
int output = s.decrypt(ct);
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
cout << "LWE ENCRYPTION IS OK" << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
void test_mod_switch()
|
|
||||||
{
|
|
||||||
SchemeNTRU s;
|
|
||||||
{
|
|
||||||
int input = 0;
|
|
||||||
Ctxt_NTRU ct;
|
|
||||||
s.encrypt(ct, input);
|
|
||||||
s.modulo_switch_to_base(ct.data);
|
|
||||||
int output = 0;
|
|
||||||
for (int i = 0; i < ntru_he::n; i++)
|
|
||||||
{
|
|
||||||
output += ct.data[i] * sk_base.sk[i][0];
|
|
||||||
}
|
|
||||||
output = output%Param::N2;
|
|
||||||
if (output > Param::N)
|
|
||||||
output -= Param::N2;
|
|
||||||
else if (output <= -Param::N)
|
|
||||||
output += Param::N2;
|
|
||||||
output = int(round(double(output*t)/double(Param::N2)));
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
int input = 1;
|
|
||||||
ntru_he::Ctxt ct;
|
|
||||||
ntru_he::encrypt(ct, input, sk_base);
|
|
||||||
ntru_he::modulo_switch(ct, ntru_he::q_base, Param::N2);
|
|
||||||
int output = 0;
|
|
||||||
for (int i = 0; i < ntru_he::n; i++)
|
|
||||||
{
|
|
||||||
output += ct[i] * sk_base.sk[i][0];
|
|
||||||
}
|
|
||||||
output = output%Param::N2;
|
|
||||||
if (output > Param::N)
|
|
||||||
output -= Param::N2;
|
|
||||||
else if (output <= -Param::N)
|
|
||||||
output += Param::N2;
|
|
||||||
output = int(round(double(output*t)/double(Param::N2)));
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
int input = 0;
|
|
||||||
ntru_he::Ctxt ct;
|
|
||||||
ntru_he::encrypt(ct, input, sk_base);
|
|
||||||
ntru_he::modulo_switch_to_boot(ct);
|
|
||||||
int output = 0;
|
|
||||||
for (int i = 0; i < ntru_he::n; i++)
|
|
||||||
{
|
|
||||||
output += ct[i] * sk_base.sk[i][0];
|
|
||||||
}
|
|
||||||
output = output%Param::N2;
|
|
||||||
if (output > Param::N)
|
|
||||||
output -= Param::N2;
|
|
||||||
else if (output <= -Param::N)
|
|
||||||
output += Param::N2;
|
|
||||||
output = int(round(double(output*t)/double(Param::N2)));
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
int input = 1;
|
|
||||||
ntru_he::Ctxt ct;
|
|
||||||
ntru_he::encrypt(ct, input, sk_base);
|
|
||||||
ntru_he::modulo_switch_to_boot(ct);
|
|
||||||
int output = 0;
|
|
||||||
for (int i = 0; i < ntru_he::n; i++)
|
|
||||||
{
|
|
||||||
output += ct[i] * sk_base.sk[i][0];
|
|
||||||
}
|
|
||||||
output = output%Param::N2;
|
|
||||||
if (output > Param::N)
|
|
||||||
output -= Param::N2;
|
|
||||||
else if (output <= -Param::N)
|
|
||||||
output += Param::N2;
|
|
||||||
output = int(round(double(output*t)/double(Param::N2)));
|
|
||||||
assert(output == input);
|
|
||||||
}
|
|
||||||
cout << "MODULO SWITCHING IS OK" << endl;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
void test_bootstrap()
|
|
||||||
{
|
|
||||||
SchemeNTRU s;
|
|
||||||
{
|
|
||||||
int input = 2;
|
|
||||||
Ctxt_NTRU ct;
|
|
||||||
s.encrypt(ct, input);
|
|
||||||
|
|
||||||
s.bootstrap(ct);
|
|
||||||
|
|
||||||
int output = s.decrypt(ct);
|
|
||||||
cout << "Bootstrapping output: " << output << endl;
|
|
||||||
assert(output == 1L);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
int input = 0;
|
|
||||||
Ctxt_NTRU ct;
|
|
||||||
s.encrypt(ct, input);
|
|
||||||
|
|
||||||
s.bootstrap(ct);
|
|
||||||
|
|
||||||
int output = s.decrypt(ct);
|
|
||||||
cout << "Bootstrapping output: " << output << endl;
|
|
||||||
assert(output == 0L);
|
|
||||||
}
|
|
||||||
|
|
||||||
cout << "BOOTSTRAPPING IS OK" << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
void test_nand_aux()
|
|
||||||
{
|
|
||||||
ntru_he::SKey_base sk_base;
|
|
||||||
ntru_he::get_sk_base(sk_base);
|
|
||||||
|
|
||||||
ntru_he::Ctxt ct;
|
|
||||||
ntru_he::get_nand_aux(ct, sk_base);
|
|
||||||
int output = 0;
|
|
||||||
for (int i = 0; i < ntru_he::n; i++)
|
|
||||||
{
|
|
||||||
output += ct[i] * sk_base.sk[i][0];
|
|
||||||
}
|
|
||||||
output = ntru_he::mod_q_base(output);
|
|
||||||
assert(
|
|
||||||
output == (ntru_he::nand_const-ntru_he::q_base)
|
|
||||||
|| output == (ntru_he::nand_const-ntru_he::q_base+1)
|
|
||||||
|| output == (ntru_he::nand_const-ntru_he::q_base-1)
|
|
||||||
);
|
|
||||||
cout << "NAND ENCRYPTION IS OK" << endl;
|
|
||||||
}*/
|
|
||||||
|
|
||||||
enum GateType {NAND, AND, OR};
|
enum GateType {NAND, AND, OR};
|
||||||
|
|
||||||
void test_ntruhe_gate_helper(int in1, int in2, const SchemeNTRU& s, GateType g)
|
void test_ntruhe_gate_helper(int in1, int in2, const SchemeNTRU& s, GateType g)
|
||||||
@ -940,21 +436,20 @@ void test_lwehe_or()
|
|||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
//test_params();
|
test_params();
|
||||||
//test_sampler();
|
test_sampler();
|
||||||
//test_ntru_key_gen();
|
|
||||||
//test_lwe_key_gen();
|
cout << "NTRU tests" << endl;
|
||||||
//test_fft();
|
test_ntruhe_nand();
|
||||||
//test_ntruhe_encrypt();
|
test_ntruhe_and();
|
||||||
//test_lwehe_encrypt();
|
test_ntruhe_or();
|
||||||
//test_mod_switch();
|
cout << "NTRU tests PASSED" << endl;
|
||||||
//test_bootstrap();
|
|
||||||
//test_nand_aux();
|
cout << "LWE tests" << endl;
|
||||||
//test_ntruhe_nand();
|
|
||||||
test_lwehe_nand();
|
test_lwehe_nand();
|
||||||
//test_ntruhe_and();
|
|
||||||
test_lwehe_and();
|
test_lwehe_and();
|
||||||
//test_ntruhe_or();
|
|
||||||
test_lwehe_or();
|
test_lwehe_or();
|
||||||
|
cout << "LWE tests PASSED" << endl;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
Reference in New Issue
Block a user