mirror of
https://github.com/mii443/FINAL.git
synced 2025-08-22 15:05:36 +00:00
Add files via upload
This commit is contained in:
26
Makefile
Normal file
26
Makefile
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
CCX = g++
|
||||||
|
CCXFLAGS = -O3 -funroll-loops -march=native -std=c++11 -pthread -I. -I./include
|
||||||
|
DEPS = -lntl -lgmp -lfftw3 -lm
|
||||||
|
|
||||||
|
all: clean test
|
||||||
|
|
||||||
|
clean:
|
||||||
|
$(RM) test test.o lwehe.o ntruhe.o fft.o sampler.o keygen.o
|
||||||
|
|
||||||
|
test: include/params.h ntruhe.o lwehe.o keygen.o fft.o sampler.o
|
||||||
|
$(CCX) $(CCXFLAGS) -o test test.cpp ntruhe.o lwehe.o keygen.o fft.o sampler.o $(DEPS)
|
||||||
|
|
||||||
|
ntruhe.o: include/ntruhe.h keygen.o sampler.o lwehe.o src/ntruhe.cpp
|
||||||
|
$(CCX) $(CCXFLAGS) -c src/ntruhe.cpp
|
||||||
|
|
||||||
|
lwehe.o: include/lwehe.h keygen.o sampler.o src/lwehe.cpp
|
||||||
|
$(CCX) $(CCXFLAGS) -c src/lwehe.cpp
|
||||||
|
|
||||||
|
keygen.o: include/keygen.h sampler.o fft.o src/keygen.cpp
|
||||||
|
$(CCX) $(CCXFLAGS) -c src/keygen.cpp
|
||||||
|
|
||||||
|
fft.o: include/fft.h
|
||||||
|
$(CCX) $(CCXFLAGS) -c src/fft.cpp
|
||||||
|
|
||||||
|
sampler.o: include/sampler.h include/params.h src/sampler.cpp
|
||||||
|
$(CCX) $(CCXFLAGS) -c src/sampler.cpp
|
52
include/fft.h
Normal file
52
include/fft.h
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#ifndef FFT
|
||||||
|
#define FFT
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <fftw3.h>
|
||||||
|
#include <complex>
|
||||||
|
#include <map>
|
||||||
|
#include "params.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
//TODO: description
|
||||||
|
|
||||||
|
class FFT_engine
|
||||||
|
{
|
||||||
|
int fft_dim;
|
||||||
|
int fft_dim2;
|
||||||
|
|
||||||
|
fftw_plan plan_to_fft;
|
||||||
|
fftw_plan plan_from_fft;
|
||||||
|
|
||||||
|
double* in_array;
|
||||||
|
fftw_complex* out_array;
|
||||||
|
|
||||||
|
public:
|
||||||
|
//map<int, vector<FFTPoly>> x_powers;
|
||||||
|
vector<FFTPoly> pos_powers;
|
||||||
|
vector<FFTPoly> neg_powers;
|
||||||
|
|
||||||
|
FFT_engine() = delete;
|
||||||
|
FFT_engine(const int dim);
|
||||||
|
|
||||||
|
void to_fft(FFTPoly& out, const ModQPoly& in) const;
|
||||||
|
void from_fft(vector<long>& out, const FFTPoly& in) const;
|
||||||
|
|
||||||
|
~FFT_engine();
|
||||||
|
};
|
||||||
|
|
||||||
|
FFTPoly operator *(const FFTPoly& a, const FFTPoly& b);
|
||||||
|
void operator *=(FFTPoly& a, const FFTPoly& b);
|
||||||
|
FFTPoly operator *(const FFTPoly& a, const int b);
|
||||||
|
FFTPoly operator +(const FFTPoly& a, const FFTPoly& b);
|
||||||
|
void operator +=(FFTPoly& a, const FFTPoly& b);
|
||||||
|
void operator +=(FFTPoly& a, const complex<double> b);
|
||||||
|
FFTPoly operator -(const FFTPoly& a, const FFTPoly& b);
|
||||||
|
void operator -=(FFTPoly& a, const FFTPoly& b);
|
||||||
|
|
||||||
|
// global FFT engine of dimension N
|
||||||
|
const FFT_engine fftN(Param::N);
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
125
include/keygen.h
Normal file
125
include/keygen.h
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
#ifndef KEYGEN
|
||||||
|
#define KEYGEN
|
||||||
|
|
||||||
|
#include <NTL/mat_ZZ.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <fftw3.h>
|
||||||
|
#include <complex>
|
||||||
|
#include "params.h"
|
||||||
|
#include "sampler.h"
|
||||||
|
|
||||||
|
// secret key of the bootstrapping scheme
|
||||||
|
typedef struct {
|
||||||
|
ModQPoly sk;
|
||||||
|
ModQPoly sk_inv;
|
||||||
|
} SKey_boot;
|
||||||
|
|
||||||
|
// secret key of the NTRU base scheme
|
||||||
|
typedef struct {
|
||||||
|
ModQMatrix sk;
|
||||||
|
ModQMatrix sk_inv;
|
||||||
|
} SKey_base_NTRU;
|
||||||
|
|
||||||
|
// secret key of the LWE base scheme
|
||||||
|
typedef std::vector<int> SKey_base_LWE;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bootstrapping key.
|
||||||
|
* It consists of several sets of keys corresponding to different
|
||||||
|
* decomposition bases of the bootstrapping key B_bsk.
|
||||||
|
* The i-th set contains vectors with l_bsk[i] complex vectors.
|
||||||
|
* These complex vectors are an encryption of some bit of the secret key
|
||||||
|
* of the base scheme in the NGS form.
|
||||||
|
*/
|
||||||
|
typedef std::vector<std::vector<std::vector<NGSFFTctxt>>> BSKey_NTRU;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bootstrapping key.
|
||||||
|
* It consists of several sets of keys corresponding to different
|
||||||
|
* decomposition bases of the bootstrapping key B_bsk.
|
||||||
|
* The i-th set contains vectors with l_bsk[i] complex vectors.
|
||||||
|
* These complex vectors are an encryption of some bit of the secret key
|
||||||
|
* of the base scheme in the NGS form.
|
||||||
|
*/
|
||||||
|
typedef std::vector<std::vector<NGSFFTctxt>> BSKey_LWE;
|
||||||
|
|
||||||
|
// key-switching key from NTRU to NTRU
|
||||||
|
typedef ModQMatrix KSKey_NTRU;
|
||||||
|
|
||||||
|
// key-switching key from NTRU to LWE
|
||||||
|
typedef struct{
|
||||||
|
ModQMatrix A;
|
||||||
|
std::vector<int> b;
|
||||||
|
} KSKey_LWE;
|
||||||
|
|
||||||
|
|
||||||
|
class KeyGen
|
||||||
|
{
|
||||||
|
Param param;
|
||||||
|
Sampler sampler;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
KeyGen(Param _param): param(_param), sampler(_param)
|
||||||
|
{}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a secret key of the bootstrapping scheme.
|
||||||
|
* @param[out] sk_boot secret key of the bootstrapping scheme.
|
||||||
|
*/
|
||||||
|
void get_sk_boot(SKey_boot& sk_boot);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a secret key of the base scheme.
|
||||||
|
* @param[out] sk_base secret key of the base scheme.
|
||||||
|
*/
|
||||||
|
void get_sk_base(SKey_base_NTRU& sk_base);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a secret key of the base scheme.
|
||||||
|
* @param[out] sk_base secret key of the base scheme.
|
||||||
|
*/
|
||||||
|
void get_sk_base(SKey_base_LWE& sk_base);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a key-switching key from the bootstrapping scheme to the base scheme.
|
||||||
|
* @param[out] ksk key-switching key.
|
||||||
|
* @param[in] sk_base secret key of the base scheme.
|
||||||
|
* @param[in] sk_boot secret key of the bootstrapping scheme.
|
||||||
|
*/
|
||||||
|
void get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a key-switching key from the bootstrapping scheme to the base scheme.
|
||||||
|
* @param[out] ksk key-switching key.
|
||||||
|
* @param[in] sk_base secret key of the base scheme.
|
||||||
|
* @param[in] sk_boot secret key of the bootstrapping scheme.
|
||||||
|
*/
|
||||||
|
void get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a bootstrapping key
|
||||||
|
* @param[out] bsk bootstrapping key.
|
||||||
|
* @param[in] sk_base secret key of the base scheme.
|
||||||
|
* @param[in] sk_boot secret key of the bootstrapping scheme.
|
||||||
|
*/
|
||||||
|
void get_bsk(BSKey_NTRU& bsk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a bootstrapping key
|
||||||
|
* @param[out] bsk bootstrapping key.
|
||||||
|
* @param[in] sk_base secret key of the base scheme.
|
||||||
|
* @param[in] sk_boot secret key of the bootstrapping scheme.
|
||||||
|
*/
|
||||||
|
void get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a bootstrapping key
|
||||||
|
* @param[out] bsk bootstrapping key.
|
||||||
|
* @param[in] sk_base secret key of the base scheme.
|
||||||
|
* @param[in] sk_boot secret key of the bootstrapping scheme.
|
||||||
|
*/
|
||||||
|
void get_bsk2(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
153
include/lwehe.h
Normal file
153
include/lwehe.h
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
#ifndef LWEHE
|
||||||
|
#define LWEHE
|
||||||
|
|
||||||
|
#include "params.h"
|
||||||
|
#include "keygen.h"
|
||||||
|
|
||||||
|
class Ctxt_LWE
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
std::vector<int> a;
|
||||||
|
int b;
|
||||||
|
|
||||||
|
Ctxt_LWE()
|
||||||
|
{
|
||||||
|
a.clear();
|
||||||
|
a.resize(parLWE.n);
|
||||||
|
}
|
||||||
|
Ctxt_LWE(const Ctxt_LWE& ct);
|
||||||
|
Ctxt_LWE& operator=(const Ctxt_LWE& ct);
|
||||||
|
|
||||||
|
Ctxt_LWE operator +(const Ctxt_LWE& ct) const;
|
||||||
|
Ctxt_LWE operator -(const Ctxt_LWE& ct) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
Ctxt_LWE operator -(const int c, const Ctxt_LWE& ct);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Switches a given ciphertext to a given modulus.
|
||||||
|
* @param[in,out] ct ciphertext
|
||||||
|
* @param[in] old_q old modulus
|
||||||
|
* @param[in] new_q old modulus
|
||||||
|
*/
|
||||||
|
inline void modulo_switch_lwe(Ctxt_LWE& ct, int old_q, int new_q)
|
||||||
|
{
|
||||||
|
std::vector<int>& a = ct.a;
|
||||||
|
for (size_t i = 0; i < a.size(); i++)
|
||||||
|
a[i] = int((a[i]*new_q)/old_q);
|
||||||
|
|
||||||
|
ct.b = int((ct.b*new_q)/old_q);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Switches a given polynomial from q_base to modulus 2*N.
|
||||||
|
* @param[in,out] poly polynomial
|
||||||
|
*/
|
||||||
|
inline void modulo_switch_to_boot(Ctxt_LWE& poly)
|
||||||
|
{
|
||||||
|
modulo_switch_lwe(poly, parLWE.q_base, Param::N2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Switches a given polynomial from q_boot to q_base.
|
||||||
|
* @param[in,out] poly polynomial
|
||||||
|
*/
|
||||||
|
inline void modulo_switch_to_base_lwe(ModQPoly& poly)
|
||||||
|
{
|
||||||
|
modulo_switch(poly, q_boot, parLWE.q_base);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the external product of a given polynomial ciphertext
|
||||||
|
* with an NGS ciphertext in the FFT form
|
||||||
|
* @param[in,out] poly polynomial ciphertext
|
||||||
|
* @param[in] poly_vector NGS ciphertext
|
||||||
|
* @param[in] b decomposition base, power of 2
|
||||||
|
* @param[in] shift bit shift to divide by b
|
||||||
|
* @param[in] l decomposition length
|
||||||
|
*/
|
||||||
|
void external_product(std::vector<long>& res, const std::vector<int>& poly, const std::vector<FFTPoly>& poly_vector, int b, int shift, int l);
|
||||||
|
|
||||||
|
class SchemeLWE
|
||||||
|
{
|
||||||
|
SKey_base_LWE sk_base;
|
||||||
|
SKey_boot sk_boot;
|
||||||
|
KSKey_LWE ksk;
|
||||||
|
BSKey_LWE bk;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
SchemeLWE()
|
||||||
|
{
|
||||||
|
KeyGen keygen(parLWE);
|
||||||
|
//sampler = Sampler(param);
|
||||||
|
|
||||||
|
keygen.get_sk_base(sk_base);
|
||||||
|
keygen.get_sk_boot(sk_boot);
|
||||||
|
keygen.get_ksk(ksk,sk_base,sk_boot);
|
||||||
|
keygen.get_bsk(bk,sk_base,sk_boot);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Encrypts a bit using LWE.
|
||||||
|
* @param[out] ct ciphertext encrypting the input bit
|
||||||
|
* @param[in] m bit to encrypt
|
||||||
|
*/
|
||||||
|
void encrypt(Ctxt_LWE& ct, int m) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decrypts a bit using LWE.
|
||||||
|
* @param[out] ct ciphertext encrypting a bit
|
||||||
|
* @return b bit
|
||||||
|
*/
|
||||||
|
int decrypt(const Ctxt_LWE& ct) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs key switching of a given ciphertext from a polynomial NTRU
|
||||||
|
* to LWE
|
||||||
|
* @param[out] ct LWE ciphertext (vector of dimension n)
|
||||||
|
* @param[in] poly polynomial ciphertext (vector of dimension N)
|
||||||
|
*/
|
||||||
|
void key_switch(Ctxt_LWE& ct, const ModQPoly& poly) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bootstrapps a given ciphertext
|
||||||
|
* @param[in,out] ct ciphertext to bootstrap
|
||||||
|
*/
|
||||||
|
void bootstrap(Ctxt_LWE& ct) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bootstrapps a given ciphertext
|
||||||
|
* @param[in,out] ct ciphertext to bootstrap
|
||||||
|
*/
|
||||||
|
void bootstrap2(Ctxt_LWE& ct) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the NAND gate of two given ciphertexts ct1 and ct2
|
||||||
|
* @param[out] ct_res encryptions of the outuput of the NAND gate
|
||||||
|
* @param[in] ct_1 encryption of the first input bit
|
||||||
|
* @param[in] ct_2 encryption of the second input bit
|
||||||
|
*/
|
||||||
|
void nand_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the AND gate of two given ciphertexts ct1 and ct2
|
||||||
|
* @param[out] ct_res encryptions of the outuput of the NAND gate
|
||||||
|
* @param[in] ct_1 encryption of the first input bit
|
||||||
|
* @param[in] ct_2 encryption of the second input bit
|
||||||
|
*/
|
||||||
|
void and_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the OR gate of two given ciphertexts ct1 and ct2
|
||||||
|
* @param[out] ct_res encryptions of the outuput of the NAND gate
|
||||||
|
* @param[in] ct_1 encryption of the first input bit
|
||||||
|
* @param[in] ct_2 encryption of the second input bit
|
||||||
|
*/
|
||||||
|
void or_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
170
include/ntruhe.h
Normal file
170
include/ntruhe.h
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
#ifndef NTRUHE
|
||||||
|
#define NTRUHE
|
||||||
|
|
||||||
|
#include "params.h"
|
||||||
|
#include "keygen.h"
|
||||||
|
|
||||||
|
class Ctxt_NTRU
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
std::vector<int> data;
|
||||||
|
|
||||||
|
Ctxt_NTRU()
|
||||||
|
{
|
||||||
|
data.clear();
|
||||||
|
data.resize(parNTRU.n);
|
||||||
|
}
|
||||||
|
Ctxt_NTRU(const Ctxt_NTRU& ct);
|
||||||
|
Ctxt_NTRU& operator=(const Ctxt_NTRU& ct);
|
||||||
|
|
||||||
|
Ctxt_NTRU operator +(const Ctxt_NTRU& ct) const;
|
||||||
|
Ctxt_NTRU operator -(const Ctxt_NTRU& ct) const;
|
||||||
|
void operator -=(const Ctxt_NTRU& ct);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Switches a given ciphertext to a given modulus.
|
||||||
|
* @param[in,out] ct ciphertext
|
||||||
|
* @param[in] old_q old modulus
|
||||||
|
* @param[in] new_q old modulus
|
||||||
|
*/
|
||||||
|
inline void modulo_switch_ntru(Ctxt_NTRU& ct, int old_q, int new_q)
|
||||||
|
{
|
||||||
|
std::vector<int>& a = ct.data;
|
||||||
|
for (size_t i = 0; i < a.size(); i++)
|
||||||
|
a[i] = int((a[i]*new_q)/old_q);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Switches a given polynomial from q_base to modulus 2*N.
|
||||||
|
* @param[in,out] poly polynomial
|
||||||
|
*/
|
||||||
|
inline void modulo_switch_to_boot(Ctxt_NTRU& poly)
|
||||||
|
{
|
||||||
|
modulo_switch_ntru(poly, parNTRU.q_base, Param::N2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Switches a given polynomial from q_boot to q_base.
|
||||||
|
* @param[in,out] poly polynomial
|
||||||
|
*/
|
||||||
|
inline void modulo_switch_to_base_ntru(ModQPoly& poly)
|
||||||
|
{
|
||||||
|
modulo_switch(poly, q_boot, parNTRU.q_base);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the external product of a given polynomial ciphertext
|
||||||
|
* with an NGS ciphertext in the FFT form
|
||||||
|
* @param[in,out] poly polynomial ciphertext
|
||||||
|
* @param[in] poly_vector NGS ciphertext
|
||||||
|
* @param[in] b decomposition base, power of 2
|
||||||
|
* @param[in] shift bit shift to divide by b
|
||||||
|
* @param[in] l decomposition length
|
||||||
|
*/
|
||||||
|
//void external_product(std::vector<long>& res, const std::vector<int>& poly, const std::vector<FFTPoly>& poly_vector, const int b, const int shift, const int l);
|
||||||
|
|
||||||
|
class SchemeNTRU
|
||||||
|
{
|
||||||
|
SKey_base_NTRU sk_base;
|
||||||
|
SKey_boot sk_boot;
|
||||||
|
KSKey_NTRU ksk;
|
||||||
|
BSKey_NTRU bk;
|
||||||
|
|
||||||
|
Ctxt_NTRU ct_nand_const;
|
||||||
|
Ctxt_NTRU ct_and_const;
|
||||||
|
Ctxt_NTRU ct_or_const;
|
||||||
|
|
||||||
|
void mask_constant(Ctxt_NTRU& ct, int constant);
|
||||||
|
|
||||||
|
inline void set_nand_const()
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
mask_constant(ct_nand_const, parNTRU.nand_const);
|
||||||
|
//cout << "Encryption of NAND: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void set_and_const()
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
mask_constant(ct_and_const, parNTRU.and_const);
|
||||||
|
//cout << "Encryption of AND: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void set_or_const()
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
mask_constant(ct_or_const, parNTRU.or_const);
|
||||||
|
//cout << "Encryption of OR: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
SchemeNTRU()
|
||||||
|
{
|
||||||
|
KeyGen keygen(parNTRU);
|
||||||
|
//sampler = Sampler(param);
|
||||||
|
|
||||||
|
keygen.get_sk_base(sk_base);
|
||||||
|
keygen.get_sk_boot(sk_boot);
|
||||||
|
keygen.get_ksk(ksk,sk_base,sk_boot);
|
||||||
|
keygen.get_bsk(bk,sk_base,sk_boot);
|
||||||
|
|
||||||
|
set_nand_const();
|
||||||
|
set_and_const();
|
||||||
|
set_or_const();
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Encrypts a bit using matrix NTRU.
|
||||||
|
* @param[out] ct ciphertext encrypting the input bit
|
||||||
|
* @param[in] b bit to encrypt
|
||||||
|
*/
|
||||||
|
void encrypt(Ctxt_NTRU& ct, const int b) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decrypts a bit using matrix NTRU.
|
||||||
|
* @param[out] ct ciphertext encrypting a bit
|
||||||
|
* @return b bit
|
||||||
|
*/
|
||||||
|
int decrypt(const Ctxt_NTRU& ct) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs key switching of a given ciphertext from a polynomial NTRU
|
||||||
|
* to a matrix NTRU
|
||||||
|
* @param[out] ct matrix NTRU ciphertext (vector of dimension n)
|
||||||
|
* @param[in] poly polynomial ciphertext (vector of dimension N)
|
||||||
|
*/
|
||||||
|
void key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bootstrapps a given ciphertext
|
||||||
|
* @param[in,out] ct ciphertext to bootstrap
|
||||||
|
*/
|
||||||
|
void bootstrap(Ctxt_NTRU& ct) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the NAND gate of two given ciphertexts ct1 and ct2
|
||||||
|
* @param[out] ct_res encryptions of the outuput of the NAND gate
|
||||||
|
* @param[in] ct_1 encryption of the first input bit
|
||||||
|
* @param[in] ct_2 encryption of the second input bit
|
||||||
|
*/
|
||||||
|
void nand_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the AND gate of two given ciphertexts ct1 and ct2
|
||||||
|
* @param[out] ct_res encryptions of the outuput of the NAND gate
|
||||||
|
* @param[in] ct_1 encryption of the first input bit
|
||||||
|
* @param[in] ct_2 encryption of the second input bit
|
||||||
|
*/
|
||||||
|
void and_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the OR gate of two given ciphertexts ct1 and ct2
|
||||||
|
* @param[out] ct_res encryptions of the outuput of the NAND gate
|
||||||
|
* @param[in] ct_1 encryption of the first input bit
|
||||||
|
* @param[in] ct_2 encryption of the second input bit
|
||||||
|
*/
|
||||||
|
void or_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
333
include/params.h
Normal file
333
include/params.h
Normal file
@ -0,0 +1,333 @@
|
|||||||
|
#ifndef PARAMS
|
||||||
|
#define PARAMS
|
||||||
|
|
||||||
|
#include <NTL/ZZ_pX.h>
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <complex>
|
||||||
|
|
||||||
|
using namespace NTL;
|
||||||
|
|
||||||
|
enum SchemeType {NTRU, LWE};
|
||||||
|
|
||||||
|
// representation of a polynomial modulo some integer
|
||||||
|
typedef std::vector<int> ModQPoly;
|
||||||
|
// matrix modulo some integer
|
||||||
|
typedef std::vector<std::vector<int>> ModQMatrix;
|
||||||
|
// representation of an FFT transformation of some poly
|
||||||
|
typedef std::vector<std::complex<double>> FFTPoly;
|
||||||
|
// NGS ciphertest in NTT form
|
||||||
|
typedef std::vector<FFTPoly> NGSFFTctxt;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduction modulo q in the symmetric interval (-q/2, q/2]
|
||||||
|
* Correct only for odd q.
|
||||||
|
* @param[in] input integer to reduce.
|
||||||
|
* @param[in] q modulus
|
||||||
|
* @param[in] half_q (q-1)/2
|
||||||
|
* @returns reduced integer in the symmetric interval (-q/2, q/2]
|
||||||
|
*/
|
||||||
|
inline long lazy_mod_q(const long input, const long q, const long half_q)
|
||||||
|
{
|
||||||
|
int coef = input%q;
|
||||||
|
if (coef > half_q)
|
||||||
|
return coef - q;
|
||||||
|
if (coef < -half_q)
|
||||||
|
return coef + q;
|
||||||
|
return coef;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int lazy_mod_q(const int input, const int q, const int half_q)
|
||||||
|
{
|
||||||
|
int coef = input%q;
|
||||||
|
if (coef > half_q)
|
||||||
|
return coef - q;
|
||||||
|
if (coef < -half_q)
|
||||||
|
return coef + q;
|
||||||
|
return coef;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void lazy_mod_q(std::vector<int>& output, const std::vector<long>& input, const long q, const long half_q)
|
||||||
|
{
|
||||||
|
assert(output.size() == input.size());
|
||||||
|
|
||||||
|
std::vector<int>::iterator oit = output.begin();
|
||||||
|
for (auto iit = input.begin(); iit < input.end(); iit++, oit++)
|
||||||
|
*oit = static_cast<int>(lazy_mod_q(*iit, q, half_q));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** ciphertext modulus of the ring-based scheme used
|
||||||
|
* for bootstrapping keys and test vectors/lookup-table encodings
|
||||||
|
*/
|
||||||
|
const int q_boot = 912829; // ~2^19.8, prime
|
||||||
|
const long q_boot_long = long(q_boot);
|
||||||
|
|
||||||
|
//half of the above modulus
|
||||||
|
const int half_q_boot = q_boot/2;
|
||||||
|
const long half_q_boot_long = long(half_q_boot);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduction modulo q_boot in the symmetric interval (-q_boot/2, q_boot/2]
|
||||||
|
* Correct only for odd q_boot.
|
||||||
|
* @param[in] input integer to reduce.
|
||||||
|
* @returns reduced integer in the symmetric interval (-q_boot/2, q_boot/2]
|
||||||
|
*/
|
||||||
|
inline int mod_q_boot(const int input)
|
||||||
|
{
|
||||||
|
return lazy_mod_q(input, q_boot, half_q_boot);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int mod_q_boot(const long input)
|
||||||
|
{
|
||||||
|
return lazy_mod_q(input, long(q_boot), long(half_q_boot));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduction modulo q_boot in the symmetric interval (-q_boot/2, q_boot/2]
|
||||||
|
* Correct only for odd q_boot.
|
||||||
|
* @param[in,out] input vector to reduce.
|
||||||
|
* @param[in] q integer modulus.
|
||||||
|
*/
|
||||||
|
inline void mod_q_boot(std::vector<int>& input)
|
||||||
|
{
|
||||||
|
for (auto it = input.begin(); it < input.end(); it++)
|
||||||
|
*it = mod_q_boot(*it);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduction modulo q_boot in the symmetric interval (-q_boot/2, q_boot/2]
|
||||||
|
* Correct only for odd q_boot.
|
||||||
|
* @param[in,out] input vector to reduce.
|
||||||
|
* @param[in] q integer modulus.
|
||||||
|
*/
|
||||||
|
inline void mod_q_boot(std::vector<int>& output, std::vector<long>& input)
|
||||||
|
{
|
||||||
|
lazy_mod_q(output, input, q_boot_long, half_q_boot_long);
|
||||||
|
}
|
||||||
|
|
||||||
|
class Param
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
// decomposition base of key-switching keys (DO NOT CHANGE!)
|
||||||
|
const static int B_ksk = 3;
|
||||||
|
|
||||||
|
// dimension of bootstrapping keys and test vectors/lookup-table encodings
|
||||||
|
const static int N = 1024;
|
||||||
|
// N/2 + 1, needed for FFT
|
||||||
|
const static int N2p1 = N/2+1;
|
||||||
|
// order of the cyclotomic ring, which is used for the bootstrapping scheme
|
||||||
|
const static int N2 = 2*N;
|
||||||
|
|
||||||
|
static ZZ_pX get_def_poly()
|
||||||
|
{
|
||||||
|
ZZ_pX poly;
|
||||||
|
// polynomial modulus of the ring-based scheme
|
||||||
|
//element of Z_(q_boot)
|
||||||
|
ZZ_p coef;
|
||||||
|
coef.init(ZZ(q_boot));
|
||||||
|
//polynomial modulus X^N+1
|
||||||
|
coef = 1;
|
||||||
|
SetCoeff(poly, 0, coef);
|
||||||
|
SetCoeff(poly, N, coef);
|
||||||
|
|
||||||
|
return poly;
|
||||||
|
}
|
||||||
|
|
||||||
|
ZZ_pX xToNplus1;
|
||||||
|
|
||||||
|
const static int B_bsk_size = 2;
|
||||||
|
|
||||||
|
// plaintext modulus
|
||||||
|
const static int t = 4;
|
||||||
|
|
||||||
|
// Delta scalar used in bootstrapping
|
||||||
|
const static int half_delta_boot = q_boot/(2*t);
|
||||||
|
|
||||||
|
// standard deviation for discrete Gaussian distribution
|
||||||
|
constexpr static double e_st_dev = 4.39;
|
||||||
|
|
||||||
|
// Type of the base scheme
|
||||||
|
SchemeType scheme_type;
|
||||||
|
|
||||||
|
// ciphertext modulus of the base scheme used for encryption
|
||||||
|
int q_base;
|
||||||
|
//half of the above modulus
|
||||||
|
int half_q_base;
|
||||||
|
// dimension of the ciphertext space
|
||||||
|
int n;
|
||||||
|
|
||||||
|
// decomposition base of key-switching keys
|
||||||
|
int l_ksk;
|
||||||
|
|
||||||
|
// number of rows of key-switching key matrices
|
||||||
|
int Nl;
|
||||||
|
|
||||||
|
// decomposition bases of bootstrapping keys
|
||||||
|
int B_bsk[B_bsk_size];
|
||||||
|
// binary logarithms of decomposition bases
|
||||||
|
int shift_bsk[B_bsk_size];
|
||||||
|
// partition of bootstrapping keys per decomposition base
|
||||||
|
int bsk_partition[B_bsk_size];
|
||||||
|
|
||||||
|
// decomposition lengths of bootstrapping keys
|
||||||
|
int l_bsk[2];
|
||||||
|
|
||||||
|
// Delta scalars used in encryption
|
||||||
|
int half_delta_base;
|
||||||
|
int delta_base;
|
||||||
|
|
||||||
|
// NAND constant
|
||||||
|
int nand_const;
|
||||||
|
// AND constant
|
||||||
|
int and_const;
|
||||||
|
// OR constant
|
||||||
|
int or_const;
|
||||||
|
|
||||||
|
void init()
|
||||||
|
{
|
||||||
|
if (scheme_type == SchemeType::NTRU)
|
||||||
|
{
|
||||||
|
q_base = 131071; // ~2^17, prime
|
||||||
|
n = 800;
|
||||||
|
|
||||||
|
B_bsk[0] = 8;
|
||||||
|
B_bsk[1] = 16;
|
||||||
|
shift_bsk[0] = 3;
|
||||||
|
shift_bsk[1] = 4;
|
||||||
|
bsk_partition[0] = 750;
|
||||||
|
bsk_partition[1] = 50;
|
||||||
|
}
|
||||||
|
else if (scheme_type == SchemeType::LWE)
|
||||||
|
{
|
||||||
|
q_base = 92683; // ~2^16.5, prime
|
||||||
|
n = 610;
|
||||||
|
|
||||||
|
B_bsk[0] = 8;
|
||||||
|
B_bsk[1] = 16;
|
||||||
|
shift_bsk[0] = 3;
|
||||||
|
shift_bsk[1] = 4;
|
||||||
|
bsk_partition[0] = 140;
|
||||||
|
bsk_partition[1] = 470;
|
||||||
|
}
|
||||||
|
|
||||||
|
half_q_base = q_base/2;
|
||||||
|
l_ksk = int(ceil(log(double(q_base))/log(double(B_ksk))));
|
||||||
|
Nl = N * l_ksk;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < 2; i++)
|
||||||
|
l_bsk[i] = int(ceil(log(double(q_boot))/log(double(B_bsk[i]))));
|
||||||
|
|
||||||
|
half_delta_base = q_base/(2*t);
|
||||||
|
delta_base = 2*half_delta_base;
|
||||||
|
|
||||||
|
nand_const = 5*half_delta_base;
|
||||||
|
and_const = half_delta_base;
|
||||||
|
or_const = 7*half_delta_base;
|
||||||
|
|
||||||
|
xToNplus1 = get_def_poly();
|
||||||
|
}
|
||||||
|
Param(){};
|
||||||
|
Param(SchemeType _scheme_type): scheme_type(_scheme_type)
|
||||||
|
{
|
||||||
|
init();
|
||||||
|
}
|
||||||
|
Param(const Param ¶m): scheme_type(param.scheme_type)
|
||||||
|
{
|
||||||
|
init();
|
||||||
|
}
|
||||||
|
Param& operator=(const Param& param)
|
||||||
|
{
|
||||||
|
scheme_type = param.scheme_type;
|
||||||
|
init();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator == (const Param& param) const
|
||||||
|
{
|
||||||
|
return this->scheme_type == param.scheme_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduction modulo q_base in the symmetric interval (-q_base/2, q_base/2]
|
||||||
|
* Correct only for odd q_base.
|
||||||
|
* @param[in] input integer to reduce.
|
||||||
|
* @returns reduced integer in the symmetric interval (-q_base/2, q_base/2]
|
||||||
|
*/
|
||||||
|
inline int mod_q_base(const int input) const
|
||||||
|
{
|
||||||
|
return lazy_mod_q(input, q_base, half_q_base);
|
||||||
|
}
|
||||||
|
inline int mod_q_base(const long input) const
|
||||||
|
{
|
||||||
|
return lazy_mod_q(input, long(q_base), long(half_q_base));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reduction modulo q_base in the symmetric interval (-q_base/2, q_base/2]
|
||||||
|
* Correct only for odd q_base.
|
||||||
|
* @param[in,out] input vector to reduce.
|
||||||
|
*/
|
||||||
|
inline void mod_q_base(std::vector<int>& input) const
|
||||||
|
{
|
||||||
|
for (auto it = input.begin(); it < input.end(); it++)
|
||||||
|
*it = mod_q_base(*it);
|
||||||
|
}
|
||||||
|
inline void mod_q_base(std::vector<int>& output, std::vector<long>& input) const
|
||||||
|
{
|
||||||
|
output.resize(input.size());
|
||||||
|
for (size_t i = 0; i < input.size(); i++)
|
||||||
|
output[i] = mod_q_base(input[i]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const Param parLWE(LWE);
|
||||||
|
const Param parNTRU(NTRU);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Switches a given polynomial to a given modulus.
|
||||||
|
* @param[in,out] poly polynomial
|
||||||
|
* @param[in] old_q old modulus
|
||||||
|
* @param[in] new_q old modulus
|
||||||
|
*/
|
||||||
|
inline void modulo_switch(ModQPoly& poly, int old_q, int new_q)
|
||||||
|
{
|
||||||
|
double ratio = double(new_q)/double(old_q);
|
||||||
|
for (auto it = poly.begin(); it < poly.end(); it++)
|
||||||
|
*it = int(round(double(*it)*ratio));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Balanced decomposition of an integer in base b.
|
||||||
|
* The length of decomposition must be l.
|
||||||
|
* @param[out] res decomposition vector
|
||||||
|
* @param[in] input integer to decompose
|
||||||
|
* @param[in] b decomposition base
|
||||||
|
* @param[in] l decomposition length
|
||||||
|
*/
|
||||||
|
inline void decompose(std::vector<int>& res, const int input, const int b, const int l)
|
||||||
|
{
|
||||||
|
res.clear();
|
||||||
|
|
||||||
|
int input_sign = (input < 0) ? -1: 1;
|
||||||
|
int input_rem = abs(input);
|
||||||
|
for (int i=0; i<l; i++)
|
||||||
|
{
|
||||||
|
int digit = input_rem % b;
|
||||||
|
int digit2 = 2*digit;
|
||||||
|
if (digit2 > b)
|
||||||
|
{
|
||||||
|
res.push_back(input_sign * (digit - b));
|
||||||
|
input_rem = (input_rem - digit)/b + 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
res.push_back(input_sign * digit);
|
||||||
|
input_rem = (input_rem - digit)/b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (input_rem != 0)
|
||||||
|
throw std::overflow_error("Input is too big for given length\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
111
include/sampler.h
Normal file
111
include/sampler.h
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
#ifndef SAMPLER
|
||||||
|
#define SAMPLER
|
||||||
|
|
||||||
|
#include <NTL/mat_ZZ.h>
|
||||||
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
#include <chrono>
|
||||||
|
#include "params.h"
|
||||||
|
|
||||||
|
using namespace NTL;
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
// random engine
|
||||||
|
static default_random_engine rand_engine(std::chrono::system_clock::now().time_since_epoch().count());
|
||||||
|
// uniform distribution on the ternary set
|
||||||
|
static uniform_int_distribution<int> ternary_sampler(-1,1);
|
||||||
|
// uniform distribution on the binary set
|
||||||
|
static uniform_int_distribution<int> binary_sampler(0,1);
|
||||||
|
|
||||||
|
class Sampler
|
||||||
|
{
|
||||||
|
Param param;
|
||||||
|
uniform_int_distribution<int> mod_q_base_sampler;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Sampler(Param _param): param(_param)
|
||||||
|
{
|
||||||
|
mod_q_base_sampler = uniform_int_distribution<int>(-param.half_q_base, param.half_q_base);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a uniformly random vector modulo q_base.
|
||||||
|
*
|
||||||
|
* @param[out] vec vector with uniformly random coefficients.
|
||||||
|
*/
|
||||||
|
void get_uniform_vector(vector<int>& vec);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a uniformly random matrix modulo q_base.
|
||||||
|
*
|
||||||
|
* @param[out] mat matrix with uniformly random coefficients.
|
||||||
|
*/
|
||||||
|
void get_uniform_matrix(vector<vector<int>>& mat);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a matrix of the form scale*M+shift*I
|
||||||
|
* where M has uniformly random ternary coefficients
|
||||||
|
* and I is an identity matrix.
|
||||||
|
* This matrix must be invertible
|
||||||
|
* @param[out] mat random matrix of the above form.
|
||||||
|
* @param[out] mat_inv inverse matrix.
|
||||||
|
* @param[in] scale scale in the above form.
|
||||||
|
* @param[in] shift shift in the above form
|
||||||
|
*/
|
||||||
|
void get_invertible_matrix(vector<vector<int>>& mat, vector<vector<int>>& mat_inv, int scale, int shift);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a uniformly random matrix with ternary coefficients.
|
||||||
|
*
|
||||||
|
* @param[out] mat matrix with ternary coefficients.
|
||||||
|
*/
|
||||||
|
static void get_ternary_matrix(vector<vector<int>>& mat);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a uniformly random vector with ternary coefficients.
|
||||||
|
*
|
||||||
|
* @param[out] vec vector with ternary coefficients.
|
||||||
|
*/
|
||||||
|
static void get_ternary_vector(vector<int>& vec);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a uniformly random vector with binary coefficients.
|
||||||
|
*
|
||||||
|
* @param[out] vec vector with binary coefficients.
|
||||||
|
*/
|
||||||
|
static void get_binary_vector(vector<int>& vec);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a random matrix with coefficients distributed
|
||||||
|
* according to the discrete Gaussian distribution
|
||||||
|
* with zero mean and standard deviation st_dev.
|
||||||
|
* @param[out] mat random matrix.
|
||||||
|
* @param[in] st_dev standard deviation.
|
||||||
|
*/
|
||||||
|
static void get_gaussian_matrix(vector<vector<int>>& mat, double st_dev);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a random vector with coefficients distributed
|
||||||
|
* according to the discrete Gaussian distribution
|
||||||
|
* with zero mean and standard deviation st_dev.
|
||||||
|
* @param[out] vec random vector.
|
||||||
|
* @param[in] st_dev standard deviation.
|
||||||
|
*/
|
||||||
|
static void get_gaussian_vector(vector<int>& vec, double st_dev);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a vector of the form scale*v+shift
|
||||||
|
* where v has uniformly random ternary coefficients.
|
||||||
|
* The polynomial with the above vector of coefficients
|
||||||
|
* must be invertible modulo X^N+1 and q_boot.
|
||||||
|
* @param[out] vec random vector of the above form.
|
||||||
|
* @param[out] vec_inv coefficient vector of the polynomial inverse modulo X^N+1.
|
||||||
|
* @param[in] scale scale in the above form.
|
||||||
|
* @param[in] shift shift in the above form
|
||||||
|
*/
|
||||||
|
void get_invertible_vector(vector<int>& vec, vector<int>& vec_inv, int scale, int shift);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
175
src/fft.cpp
Normal file
175
src/fft.cpp
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
#include "fft.h"
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iterator>
|
||||||
|
|
||||||
|
FFT_engine::FFT_engine(const int dim): fft_dim(dim)
|
||||||
|
{
|
||||||
|
assert(dim%2 == 0);
|
||||||
|
|
||||||
|
fft_dim2 = (dim >> 1) + 1;
|
||||||
|
|
||||||
|
in_array = (double*) fftw_malloc(sizeof(double) * 2*dim);
|
||||||
|
out_array = (fftw_complex*) fftw_malloc(sizeof(fftw_complex) * (dim + 2));
|
||||||
|
plan_to_fft = fftw_plan_dft_r2c_1d(2*dim, in_array, out_array, FFTW_PATIENT);
|
||||||
|
plan_from_fft = fftw_plan_dft_c2r_1d(2*dim, out_array, in_array, FFTW_PATIENT);
|
||||||
|
|
||||||
|
pos_powers = vector<FFTPoly>(dim,FFTPoly(fft_dim2));
|
||||||
|
neg_powers = vector<FFTPoly>(dim,FFTPoly(fft_dim2));
|
||||||
|
for(int i = 0; i < dim; i++)
|
||||||
|
{
|
||||||
|
ModQPoly x_power(dim,0);
|
||||||
|
//x_power[0] = -1;
|
||||||
|
x_power[i] += 1;
|
||||||
|
FFTPoly x_power_fft(fft_dim2);
|
||||||
|
to_fft(x_power_fft, x_power);
|
||||||
|
pos_powers[i] = x_power_fft;
|
||||||
|
|
||||||
|
x_power[i] -= 2;
|
||||||
|
to_fft(x_power_fft, x_power);
|
||||||
|
neg_powers[i] = x_power_fft;
|
||||||
|
}
|
||||||
|
//x_powers.insert({{-1,neg_powers}, {1,pos_powers}});
|
||||||
|
}
|
||||||
|
|
||||||
|
void FFT_engine::to_fft(FFTPoly& out, const ModQPoly& in) const
|
||||||
|
{
|
||||||
|
assert(out.size() == fft_dim2);
|
||||||
|
|
||||||
|
double* in_arr = in_array;
|
||||||
|
fftw_complex* out_arr = out_array;
|
||||||
|
int N = fft_dim;
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
in_arr[i] = double(in[i]);
|
||||||
|
in_arr[i+N] = 0.0;
|
||||||
|
}
|
||||||
|
fftw_execute(plan_to_fft);
|
||||||
|
int tmp = 1;
|
||||||
|
//for (int i = 0; i < fft_dim2; i++)
|
||||||
|
for (auto it = out.begin(); it < out.end(); ++it)
|
||||||
|
{
|
||||||
|
fftw_complex& out_z = out_arr[tmp];
|
||||||
|
complex<double>& outi = *it; //out[i];
|
||||||
|
outi.real(out_z[0]);
|
||||||
|
outi.imag(out_z[1]);
|
||||||
|
tmp += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FFT_engine::from_fft(vector<long>& out, const FFTPoly& in) const
|
||||||
|
{
|
||||||
|
int tmp = 0;
|
||||||
|
double* in_arr = in_array;
|
||||||
|
fftw_complex* out_arr = out_array;
|
||||||
|
int N = fft_dim;
|
||||||
|
int Nd = double(N);
|
||||||
|
|
||||||
|
//for (int i = 0; i < fft_dim2; ++i)
|
||||||
|
for (auto it = in.begin(); it < in.end(); ++it)
|
||||||
|
{
|
||||||
|
//std::cout << "i: " << i << ", number: " << in[i] << std::endl;
|
||||||
|
out_arr[tmp+1][0] = real(*it)/Nd;
|
||||||
|
out_arr[tmp+1][1] = imag(*it)/Nd;
|
||||||
|
out_arr[tmp][0] = 0.0;
|
||||||
|
out_arr[tmp][1] = 0.0;
|
||||||
|
tmp += 2;
|
||||||
|
}
|
||||||
|
fftw_execute(plan_from_fft);
|
||||||
|
out.resize(fft_dim);
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
out[i] = long(round(in_arr[i]));
|
||||||
|
//std::cout << "i: " << i << ", number: " << out[i] << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FFT_engine::~FFT_engine()
|
||||||
|
{
|
||||||
|
fftw_destroy_plan(plan_to_fft);
|
||||||
|
fftw_destroy_plan(plan_from_fft);
|
||||||
|
fftw_free(in_array);
|
||||||
|
fftw_free(out_array);
|
||||||
|
}
|
||||||
|
|
||||||
|
FFTPoly operator +(const FFTPoly& a, const FFTPoly& b)
|
||||||
|
{
|
||||||
|
// check that input vectors have the same size
|
||||||
|
assert(a.size() == b.size());
|
||||||
|
|
||||||
|
FFTPoly res(a.size());
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
res[i] = a[i]+b[i];
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator +=(FFTPoly& a, const FFTPoly& b)
|
||||||
|
{
|
||||||
|
// check that input vectors have the same size
|
||||||
|
assert(a.size() == b.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
a[i]+=b[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator +=(FFTPoly& a, const complex<double> b)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
a[i]+=b;
|
||||||
|
}
|
||||||
|
|
||||||
|
FFTPoly operator -(const FFTPoly& a, const FFTPoly& b)
|
||||||
|
{
|
||||||
|
// check that input vectors have the same size
|
||||||
|
assert(a.size() == b.size());
|
||||||
|
|
||||||
|
FFTPoly res(a.size());
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
res[i] = a[i]-b[i];
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator -=(FFTPoly& a, const FFTPoly& b)
|
||||||
|
{
|
||||||
|
// check that input vectors have the same size
|
||||||
|
assert(a.size() == b.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
a[i]-=b[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
FFTPoly operator *(const FFTPoly& a, const FFTPoly& b)
|
||||||
|
{
|
||||||
|
// check that input vectors have the same size
|
||||||
|
assert(a.size() == b.size());
|
||||||
|
|
||||||
|
FFTPoly res(a.size());
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
res[i] = a[i]*b[i];
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator *=(FFTPoly& a, const FFTPoly& b)
|
||||||
|
{
|
||||||
|
// check that input vectors have the same size
|
||||||
|
assert(a.size() == b.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
a[i]*=b[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: make a test
|
||||||
|
FFTPoly operator *(const FFTPoly& a, const int b)
|
||||||
|
{
|
||||||
|
FFTPoly res(a.size());
|
||||||
|
double bd = double(b);
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
res[i] = a[i] * bd;
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
443
src/keygen.cpp
Normal file
443
src/keygen.cpp
Normal file
@ -0,0 +1,443 @@
|
|||||||
|
#include "keygen.h"
|
||||||
|
#include "params.h"
|
||||||
|
#include "fft.h"
|
||||||
|
|
||||||
|
#include<iostream>
|
||||||
|
#include<time.h>
|
||||||
|
#include<algorithm>
|
||||||
|
|
||||||
|
using namespace NTL;
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
void KeyGen::get_sk_boot(SKey_boot& sk_boot)
|
||||||
|
{
|
||||||
|
cout << "Started generating the secret key of the bootstrapping scheme" << endl;
|
||||||
|
clock_t start = clock();
|
||||||
|
sk_boot.sk = ModQPoly(Param::N,0);
|
||||||
|
sk_boot.sk_inv = ModQPoly(Param::N,0);
|
||||||
|
|
||||||
|
sampler.get_invertible_vector(sk_boot.sk, sk_boot.sk_inv, Param::t, 1L);
|
||||||
|
cout << "Generation time of the secret key of the bootstrapping scheme: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeyGen::get_sk_base(SKey_base_NTRU& sk_base)
|
||||||
|
{
|
||||||
|
cout << "Started generating the secret key of the base scheme" << endl;
|
||||||
|
clock_t start = clock();
|
||||||
|
sk_base.sk = ModQMatrix(param.n, vector<int>(param.n,0L));
|
||||||
|
sk_base.sk_inv = ModQMatrix(param.n, vector<int>(param.n,0L));
|
||||||
|
|
||||||
|
sampler.get_invertible_matrix(sk_base.sk, sk_base.sk_inv, 1L, 0L);
|
||||||
|
cout << "Generation time of the secret key of the base scheme: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeyGen::get_sk_base(SKey_base_LWE& sk_base)
|
||||||
|
{
|
||||||
|
cout << "Started generating the secret key of the base scheme" << endl;
|
||||||
|
clock_t start = clock();
|
||||||
|
sk_base.clear();
|
||||||
|
sk_base = vector<int>(param.n,0L);
|
||||||
|
|
||||||
|
sampler.get_binary_vector(sk_base);
|
||||||
|
cout << "Generation time of the secret key of the base scheme: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot)
|
||||||
|
{
|
||||||
|
cout << "Started key-switching key generation" << endl;
|
||||||
|
clock_t start = clock();
|
||||||
|
// reset key-switching key
|
||||||
|
ksk.clear();
|
||||||
|
ksk = ModQMatrix(param.Nl, vector<int>(param.n,0));
|
||||||
|
vector<vector<long>> ksk_long(param.Nl, vector<long>(param.n,0L));
|
||||||
|
cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
// noise matrix G as in the paper
|
||||||
|
ModQMatrix G(param.Nl, vector<int>(param.n,0L));
|
||||||
|
sampler.get_ternary_matrix(G);
|
||||||
|
cout << "G gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
// matrix G + P * Phi(f) * E as in the paper
|
||||||
|
int coef_w_pwr = sk_boot.sk[0];
|
||||||
|
for (int i = 0; i < param.l_ksk; i++)
|
||||||
|
{
|
||||||
|
G[i][0] += coef_w_pwr;
|
||||||
|
coef_w_pwr *= Param::B_ksk;
|
||||||
|
}
|
||||||
|
for (int i = 1; i < Param::N; i++)
|
||||||
|
{
|
||||||
|
coef_w_pwr = -sk_boot.sk[Param::N-i];
|
||||||
|
for (int j = 0; j < param.l_ksk; j++)
|
||||||
|
{
|
||||||
|
G[i*param.l_ksk+j][0] += coef_w_pwr;
|
||||||
|
coef_w_pwr *= Param::B_ksk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "G+P time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
// parameters of the block optimization of matrix multiplication
|
||||||
|
int block = 4;
|
||||||
|
int blocks = (param.n/block)*block;
|
||||||
|
int rem_block = param.n%block;
|
||||||
|
// (G + P * Phi(f) * E) * F^(-1) as in the paper
|
||||||
|
for (int i = 0; i < param.Nl; i++)
|
||||||
|
{
|
||||||
|
//cout << "i: " << i << endl;
|
||||||
|
vector<long>& k_row = ksk_long[i];
|
||||||
|
vector<int>& g_row = G[i];
|
||||||
|
for (int k = 0; k < param.n; k++)
|
||||||
|
{
|
||||||
|
const vector<int>& f_row = sk_base.sk_inv[k];
|
||||||
|
//cout << "j: " << j << endl;
|
||||||
|
long coef = long(g_row[k]);
|
||||||
|
for (int j = 0; j < blocks; j+=block)
|
||||||
|
{
|
||||||
|
k_row[j] += (coef * f_row[j]);
|
||||||
|
k_row[j+1] += (coef * f_row[j+1]);
|
||||||
|
k_row[j+2] += (coef * f_row[j+2]);
|
||||||
|
k_row[j+3] += (coef * f_row[j+3]);
|
||||||
|
}
|
||||||
|
for (int j = 0; j < rem_block; j++)
|
||||||
|
k_row[blocks+j] += (coef * f_row[blocks+j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "After K time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
// reduce modulo q_base
|
||||||
|
for (int i = 0; i < param.Nl; i++)
|
||||||
|
param.mod_q_base(ksk[i], ksk_long[i]);
|
||||||
|
cout << "KSKey-gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
||||||
|
{
|
||||||
|
cout << "Started key-switching key generation" << endl;
|
||||||
|
clock_t start = clock();
|
||||||
|
// reset key-switching key
|
||||||
|
ksk.A.clear();
|
||||||
|
ksk.b.clear();
|
||||||
|
for (int i = 0; i < param.Nl; i++)
|
||||||
|
{
|
||||||
|
vector<int> row(param.n,0L);
|
||||||
|
ksk.A.push_back(row);
|
||||||
|
}
|
||||||
|
ksk.b = vector<int>(param.Nl, 0L);
|
||||||
|
cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
// noise matrix G as in the paper
|
||||||
|
sampler.get_uniform_matrix(ksk.A);
|
||||||
|
cout << "A gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
// matrix P * f_0 as in the paper
|
||||||
|
vector<int> Pf0(param.Nl, 0L);
|
||||||
|
int coef_w_pwr = sk_boot.sk[0];
|
||||||
|
for (int i = 0; i < param.l_ksk; i++)
|
||||||
|
{
|
||||||
|
Pf0[i] += coef_w_pwr;
|
||||||
|
coef_w_pwr *= Param::B_ksk;
|
||||||
|
}
|
||||||
|
for (int i = 1; i < Param::N; i++)
|
||||||
|
{
|
||||||
|
coef_w_pwr = -sk_boot.sk[Param::N-i];
|
||||||
|
for (int j = 0; j < param.l_ksk; j++)
|
||||||
|
{
|
||||||
|
Pf0[i*param.l_ksk+j] += coef_w_pwr;
|
||||||
|
coef_w_pwr *= Param::B_ksk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "Pf0 time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
// A*s_base + e + Pf0 as in the paper
|
||||||
|
normal_distribution<double> gaussian_sampler(0.0, Param::e_st_dev);
|
||||||
|
for (int i = 0; i < param.Nl; i++)
|
||||||
|
{
|
||||||
|
//cout << "i: " << i << endl;
|
||||||
|
vector<int>& k_row = ksk.A[i];
|
||||||
|
for (int k = 0; k < param.n; k++)
|
||||||
|
ksk.b[i] -= k_row[k] * sk_base[k];
|
||||||
|
ksk.b[i] += (Pf0[i] + static_cast<int>(round(gaussian_sampler(rand_engine))));
|
||||||
|
param.mod_q_base(ksk.b[i]);
|
||||||
|
}
|
||||||
|
cout << "KSKey-gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeyGen::get_bsk(BSKey_NTRU& bsk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot)
|
||||||
|
{
|
||||||
|
clock_t start = clock();
|
||||||
|
|
||||||
|
// index of a secret key coefficient of the base scheme
|
||||||
|
int coef_counter = 0;
|
||||||
|
|
||||||
|
// reset the input
|
||||||
|
bsk.clear();
|
||||||
|
|
||||||
|
// transform the secret key of the bootstrapping scheme to the DFT domain
|
||||||
|
FFTPoly sk_boot_inv_fft(Param::N2p1);
|
||||||
|
fftN.to_fft(sk_boot_inv_fft, sk_boot.sk_inv);
|
||||||
|
|
||||||
|
// index of a secret key coefficient of the base scheme
|
||||||
|
coef_counter = 0;
|
||||||
|
// vector to keep the DFT transform of a random ternary vector
|
||||||
|
FFTPoly g_fft(Param::N2p1);
|
||||||
|
// vector to keep the DFT transform of a bootstrapping key part
|
||||||
|
FFTPoly tmp_bsk_fft(Param::N2p1);
|
||||||
|
// vector to keep a bootstrapping key part
|
||||||
|
vector<long> tmp_bsk(Param::N);
|
||||||
|
vector<int> tmp_bsk_int(Param::N);
|
||||||
|
// precompute FFT transformed powers of decomposition bases
|
||||||
|
vector<vector<FFTPoly>> B_bsk_pwr_poly;
|
||||||
|
for (int iBase = 0; iBase < Param::B_bsk_size; iBase++)
|
||||||
|
{
|
||||||
|
double B_bsk_double = param.B_bsk[iBase];
|
||||||
|
vector<FFTPoly> base_row;
|
||||||
|
// FFT transform of (1,0,...,0)
|
||||||
|
FFTPoly tmp_fft(Param::N2p1,complex<double>(1.0, 0.0));
|
||||||
|
base_row.push_back(tmp_fft);
|
||||||
|
for (int iPart = 1; iPart < param.l_bsk[iBase]; iPart++)
|
||||||
|
{
|
||||||
|
transform(tmp_fft.begin(), tmp_fft.end(), tmp_fft.begin(),
|
||||||
|
[B_bsk_double](complex<double> &z){ return z*B_bsk_double; });
|
||||||
|
base_row.push_back(tmp_fft);
|
||||||
|
}
|
||||||
|
B_bsk_pwr_poly.push_back(base_row);
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over different decomposition bases
|
||||||
|
for (int iBase = 0; iBase < Param::B_bsk_size; iBase++)
|
||||||
|
{
|
||||||
|
vector<vector<NGSFFTctxt>> base_row;
|
||||||
|
vector<FFTPoly>& B_bsk_pwr_poly_row = B_bsk_pwr_poly[iBase];
|
||||||
|
for (int iCoef = coef_counter; iCoef < coef_counter+param.bsk_partition[iBase]; iCoef++)
|
||||||
|
{
|
||||||
|
vector<NGSFFTctxt> coef_row;
|
||||||
|
int sk_base_coef = sk_base.sk[iCoef][0];
|
||||||
|
/**
|
||||||
|
* represent coefficient of the secret key
|
||||||
|
* of the base scheme using 2 bits.
|
||||||
|
* The representation rule is as follows:
|
||||||
|
* -1 => [0,1]
|
||||||
|
* 0 => [0,0]
|
||||||
|
* 1 => [1,0]
|
||||||
|
* */
|
||||||
|
int coef_bits[2] = {0,0};
|
||||||
|
if (sk_base_coef == -1)
|
||||||
|
coef_bits[1] = 1;
|
||||||
|
else if (sk_base_coef == 1)
|
||||||
|
coef_bits[0] = 1;
|
||||||
|
// encrypt each bit using the NGS scheme
|
||||||
|
for (int iBit = 0; iBit < 2; iBit++)
|
||||||
|
{
|
||||||
|
NGSFFTctxt bit_row;
|
||||||
|
for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++)
|
||||||
|
{
|
||||||
|
// sample random ternary vector
|
||||||
|
ModQPoly g(Param::N,0L);
|
||||||
|
sampler.get_ternary_vector(g);
|
||||||
|
// FFT transform it
|
||||||
|
fftN.to_fft(g_fft, g);
|
||||||
|
// compute g * sk_boot^(-1)
|
||||||
|
tmp_bsk_fft = g_fft * sk_boot_inv_fft;
|
||||||
|
// compute g * sk_boot^(-1) + B^i * bit
|
||||||
|
if (coef_bits[iBit] == 1)
|
||||||
|
tmp_bsk_fft = B_bsk_pwr_poly_row[iPart] + tmp_bsk_fft;
|
||||||
|
// inverse FFT of the above result
|
||||||
|
fftN.from_fft(tmp_bsk, tmp_bsk_fft);
|
||||||
|
// reduction modulo q_boot
|
||||||
|
mod_q_boot(tmp_bsk_int, tmp_bsk);
|
||||||
|
// FFT transform for further use
|
||||||
|
fftN.to_fft(tmp_bsk_fft, tmp_bsk_int);
|
||||||
|
|
||||||
|
bit_row.push_back(tmp_bsk_fft);
|
||||||
|
}
|
||||||
|
coef_row.push_back(bit_row);
|
||||||
|
}
|
||||||
|
base_row.push_back(coef_row);
|
||||||
|
}
|
||||||
|
bsk.push_back(base_row);
|
||||||
|
coef_counter += param.bsk_partition[iBase];
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeyGen::get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
||||||
|
{
|
||||||
|
clock_t start = clock();
|
||||||
|
|
||||||
|
// index of a secret key coefficient of the base scheme
|
||||||
|
int coef_counter = 0;
|
||||||
|
|
||||||
|
// reset the input
|
||||||
|
bsk.clear();
|
||||||
|
|
||||||
|
// transform the secret key of the bootstrapping scheme to the DFT domain
|
||||||
|
FFTPoly sk_boot_inv_fft(Param::N2p1);
|
||||||
|
fftN.to_fft(sk_boot_inv_fft, sk_boot.sk_inv);
|
||||||
|
|
||||||
|
// index of a secret key coefficient of the base scheme
|
||||||
|
coef_counter = 0;
|
||||||
|
// vector to keep the DFT transform of a random ternary vector
|
||||||
|
FFTPoly g_fft(Param::N2p1);
|
||||||
|
// vector to keep the DFT transform of a bootstrapping key part
|
||||||
|
FFTPoly tmp_bsk_fft(Param::N2p1);
|
||||||
|
// vector to keep a bootstrapping key part
|
||||||
|
ModQPoly tmp_bsk(Param::N);
|
||||||
|
vector<long> tmp_bsk_long;
|
||||||
|
// precompute FFT transformed powers of decomposition bases
|
||||||
|
vector<vector<FFTPoly>> B_bsk_pwr_poly;
|
||||||
|
for (int iBase = 0; iBase < Param::B_bsk_size; iBase++)
|
||||||
|
{
|
||||||
|
double B_bsk_double = param.B_bsk[iBase];
|
||||||
|
vector<FFTPoly> base_row;
|
||||||
|
// FFT transform of (1,0,...,0)
|
||||||
|
FFTPoly tmp_fft(Param::N2p1,complex<double>(1.0, 0.0));
|
||||||
|
base_row.push_back(tmp_fft);
|
||||||
|
for (int iPart = 1; iPart < param.l_bsk[iBase]; iPart++)
|
||||||
|
{
|
||||||
|
transform(tmp_fft.begin(), tmp_fft.end(), tmp_fft.begin(),
|
||||||
|
[B_bsk_double](complex<double> &z){ return z*B_bsk_double; });
|
||||||
|
base_row.push_back(tmp_fft);
|
||||||
|
}
|
||||||
|
B_bsk_pwr_poly.push_back(base_row);
|
||||||
|
}
|
||||||
|
|
||||||
|
bsk.clear();
|
||||||
|
bsk = vector<vector<NGSFFTctxt>>(Param::B_bsk_size);
|
||||||
|
for (int i = 0; i < Param::B_bsk_size; i++)
|
||||||
|
bsk[i] = vector<NGSFFTctxt>(param.bsk_partition[i], NGSFFTctxt(param.l_bsk[i], FFTPoly(Param::N2p1)));
|
||||||
|
|
||||||
|
ModQPoly g(Param::N,0L);
|
||||||
|
// loop over different decomposition bases
|
||||||
|
for (int iBase = 0; iBase < Param::B_bsk_size; iBase++)
|
||||||
|
{
|
||||||
|
vector<NGSFFTctxt> base_row(param.bsk_partition[iBase], NGSFFTctxt(param.l_bsk[iBase], FFTPoly(Param::N2p1)));
|
||||||
|
vector<FFTPoly>& B_bsk_pwr_poly_row = B_bsk_pwr_poly[iBase];
|
||||||
|
for (int iCoef = coef_counter; iCoef < coef_counter+param.bsk_partition[iBase]; iCoef++)
|
||||||
|
{
|
||||||
|
NGSFFTctxt coef_row(param.l_bsk[iBase], FFTPoly(Param::N2p1));
|
||||||
|
int sk_base_coef = sk_base[iCoef];
|
||||||
|
// encrypt each bit using the NGS scheme
|
||||||
|
for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++)
|
||||||
|
{
|
||||||
|
// sample random ternary vector
|
||||||
|
sampler.get_ternary_vector(g);
|
||||||
|
// FFT transform it
|
||||||
|
fftN.to_fft(g_fft, g);
|
||||||
|
// compute g * sk_boot^(-1)
|
||||||
|
tmp_bsk_fft = g_fft;
|
||||||
|
tmp_bsk_fft *= sk_boot_inv_fft;
|
||||||
|
// compute g * sk_boot^(-1) + B^i * bit
|
||||||
|
if (sk_base_coef == 1)
|
||||||
|
tmp_bsk_fft += B_bsk_pwr_poly_row[iPart];
|
||||||
|
// inverse FFT of the above result
|
||||||
|
fftN.from_fft(tmp_bsk_long, tmp_bsk_fft);
|
||||||
|
// reduction modulo q_boot
|
||||||
|
mod_q_boot(tmp_bsk, tmp_bsk_long);
|
||||||
|
// FFT transform for further use
|
||||||
|
fftN.to_fft(tmp_bsk_fft, tmp_bsk);
|
||||||
|
|
||||||
|
coef_row[iPart] = tmp_bsk_fft;
|
||||||
|
}
|
||||||
|
base_row[iCoef-coef_counter] = coef_row;
|
||||||
|
}
|
||||||
|
bsk[iBase] = base_row;
|
||||||
|
coef_counter += param.bsk_partition[iBase];
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeyGen::get_bsk2(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot)
|
||||||
|
{
|
||||||
|
clock_t start = clock();
|
||||||
|
|
||||||
|
// index of a secret key coefficient of the base scheme
|
||||||
|
int coef_counter = 0;
|
||||||
|
|
||||||
|
// reset the input
|
||||||
|
bsk.clear();
|
||||||
|
|
||||||
|
// transform the secret key of the bootstrapping scheme to the DFT domain
|
||||||
|
FFTPoly sk_boot_inv_fft(Param::N2p1);
|
||||||
|
fftN.to_fft(sk_boot_inv_fft, sk_boot.sk_inv);
|
||||||
|
|
||||||
|
// index of a secret key coefficient of the base scheme
|
||||||
|
coef_counter = 0;
|
||||||
|
// vector to keep the DFT transform of a random ternary vector
|
||||||
|
FFTPoly g_fft(Param::N2p1);
|
||||||
|
// vector to keep the DFT transform of a bootstrapping key part
|
||||||
|
FFTPoly tmp_bsk_fft(Param::N2p1);
|
||||||
|
// vector to keep a bootstrapping key part
|
||||||
|
ModQPoly tmp_bsk(Param::N);
|
||||||
|
vector<long> tmp_bsk_long;
|
||||||
|
// precompute FFT transformed powers of decomposition bases
|
||||||
|
vector<vector<FFTPoly>> B_bsk_pwr_poly;
|
||||||
|
for (int iBase = 0; iBase < Param::B_bsk_size; iBase++)
|
||||||
|
{
|
||||||
|
double B_bsk_double = param.B_bsk[iBase];
|
||||||
|
vector<FFTPoly> base_row;
|
||||||
|
// FFT transform of (1,0,...,0)
|
||||||
|
FFTPoly tmp_fft(Param::N2p1,complex<double>(1.0, 0.0));
|
||||||
|
base_row.push_back(tmp_fft);
|
||||||
|
for (int iPart = 1; iPart < param.l_bsk[iBase]; iPart++)
|
||||||
|
{
|
||||||
|
transform(tmp_fft.begin(), tmp_fft.end(), tmp_fft.begin(),
|
||||||
|
[B_bsk_double](complex<double> &z){ return z*B_bsk_double; });
|
||||||
|
base_row.push_back(tmp_fft);
|
||||||
|
}
|
||||||
|
B_bsk_pwr_poly.push_back(base_row);
|
||||||
|
}
|
||||||
|
|
||||||
|
bsk.clear();
|
||||||
|
bsk = vector<vector<NGSFFTctxt>>(Param::B_bsk_size);
|
||||||
|
for (int i = 0; i < Param::B_bsk_size; i++)
|
||||||
|
bsk[i] = vector<NGSFFTctxt>(4 * (param.bsk_partition[i] >> 1), NGSFFTctxt(param.l_bsk[i], FFTPoly(Param::N2p1)));
|
||||||
|
|
||||||
|
ModQPoly g(Param::N,0L);
|
||||||
|
// loop over different decomposition bases
|
||||||
|
int bits[4];
|
||||||
|
for (int iBase = 0; iBase < Param::B_bsk_size; iBase++)
|
||||||
|
{
|
||||||
|
vector<NGSFFTctxt> base_row(4 * (param.bsk_partition[iBase] >> 1), NGSFFTctxt(param.l_bsk[iBase], FFTPoly(Param::N2p1)));
|
||||||
|
vector<FFTPoly>& B_bsk_pwr_poly_row = B_bsk_pwr_poly[iBase];
|
||||||
|
for (int iCoef = coef_counter; iCoef < coef_counter+param.bsk_partition[iBase]; iCoef+=2)
|
||||||
|
{
|
||||||
|
NGSFFTctxt coef_row(param.l_bsk[iBase], FFTPoly(Param::N2p1));
|
||||||
|
// bits to encrypt: s[coef]*s[coef+1], s[coef]*(1-s[coef+1]), (1-s[coef])*s[coef+1]
|
||||||
|
bits[0] = sk_base[iCoef]*sk_base[iCoef+1];
|
||||||
|
bits[1] = sk_base[iCoef]*(1-sk_base[iCoef+1]);
|
||||||
|
bits[2] = (1-sk_base[iCoef])*sk_base[iCoef+1];
|
||||||
|
bits[3] = (1-sk_base[iCoef])*(1-sk_base[iCoef+1]);
|
||||||
|
// encrypt each bit using the NGS scheme
|
||||||
|
for (int iBit = 0; iBit < 4; iBit++)
|
||||||
|
{
|
||||||
|
for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++)
|
||||||
|
{
|
||||||
|
// sample random ternary vector
|
||||||
|
sampler.get_ternary_vector(g);
|
||||||
|
// FFT transform it
|
||||||
|
fftN.to_fft(g_fft, g);
|
||||||
|
// compute g * sk_boot^(-1)
|
||||||
|
tmp_bsk_fft = g_fft;
|
||||||
|
tmp_bsk_fft *= sk_boot_inv_fft;
|
||||||
|
// compute g * sk_boot^(-1) + B^i * bit
|
||||||
|
if (bits[iBit] == 1)
|
||||||
|
tmp_bsk_fft += B_bsk_pwr_poly_row[iPart];
|
||||||
|
// inverse FFT of the above result
|
||||||
|
fftN.from_fft(tmp_bsk_long, tmp_bsk_fft);
|
||||||
|
// reduction modulo q_boot
|
||||||
|
mod_q_boot(tmp_bsk, tmp_bsk_long);
|
||||||
|
// FFT transform for further use
|
||||||
|
fftN.to_fft(tmp_bsk_fft, tmp_bsk);
|
||||||
|
|
||||||
|
coef_row[iPart] = tmp_bsk_fft;
|
||||||
|
}
|
||||||
|
base_row[4*((iCoef-coef_counter) >> 1)+iBit] = coef_row;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bsk[iBase] = base_row;
|
||||||
|
coef_counter += param.bsk_partition[iBase];
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "Bootstrapping2 generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
549
src/lwehe.cpp
Normal file
549
src/lwehe.cpp
Normal file
@ -0,0 +1,549 @@
|
|||||||
|
#include "lwehe.h"
|
||||||
|
#include "sampler.h"
|
||||||
|
#include "fft.h"
|
||||||
|
#include "ntruhe.h"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <vector>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
Ctxt_LWE::Ctxt_LWE(const Ctxt_LWE& ct)
|
||||||
|
{
|
||||||
|
a = ct.a;
|
||||||
|
b = ct.b;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctxt_LWE& Ctxt_LWE::operator=(const Ctxt_LWE& ct)
|
||||||
|
{
|
||||||
|
a = ct.a;
|
||||||
|
b = ct.b;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctxt_LWE Ctxt_LWE::operator +(const Ctxt_LWE& ct) const
|
||||||
|
{
|
||||||
|
Ctxt_LWE res;
|
||||||
|
for (size_t i = 0; i < parLWE.n; i++)
|
||||||
|
res.a[i] = parLWE.mod_q_base(a[i] + ct.a[i]);
|
||||||
|
|
||||||
|
res.b = parLWE.mod_q_base(b + ct.b);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctxt_LWE Ctxt_LWE::operator -(const Ctxt_LWE& ct) const
|
||||||
|
{
|
||||||
|
Ctxt_LWE res;
|
||||||
|
for (size_t i = 0; i < parLWE.n; i++)
|
||||||
|
res.a[i] = parLWE.mod_q_base(a[i] - ct.a[i]);
|
||||||
|
|
||||||
|
res.b = parLWE.mod_q_base(b + ct.b);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctxt_LWE operator -(const int c, const Ctxt_LWE& ct)
|
||||||
|
{
|
||||||
|
Ctxt_LWE res;
|
||||||
|
res.a = vector<int>(parLWE.n);
|
||||||
|
const vector<int>& a = ct.a;
|
||||||
|
for (size_t i = 0; i < a.size(); i++)
|
||||||
|
res.a[i] = parLWE.mod_q_base(-a[i]);
|
||||||
|
|
||||||
|
res.b = parLWE.mod_q_base(c-ct.b);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeLWE::encrypt(Ctxt_LWE& ct, int m) const
|
||||||
|
{
|
||||||
|
clock_t start = clock();
|
||||||
|
|
||||||
|
int n = parLWE.n;
|
||||||
|
|
||||||
|
vector<int> a(n,0L);
|
||||||
|
Sampler s(parLWE);
|
||||||
|
s.get_uniform_vector(a);
|
||||||
|
ct.a = a;
|
||||||
|
normal_distribution<double> gaussian_sampler(0.0, Param::e_st_dev);
|
||||||
|
int b = parLWE.delta_base*m + static_cast<int>(round(gaussian_sampler(rand_engine)));
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
b -= sk_base[i] * a[i];
|
||||||
|
}
|
||||||
|
parLWE.mod_q_base(b);
|
||||||
|
ct.b = b;
|
||||||
|
|
||||||
|
//cout << "Encryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
int SchemeLWE::decrypt(const Ctxt_LWE& ct) const
|
||||||
|
{
|
||||||
|
clock_t start = clock();
|
||||||
|
|
||||||
|
int output = ct.b;
|
||||||
|
for (int i = 0; i < parLWE.n; i++)
|
||||||
|
{
|
||||||
|
output += ct.a[i] * sk_base[i];
|
||||||
|
}
|
||||||
|
output = parLWE.mod_q_base(output);
|
||||||
|
output = int(round(double(output*Param::t)/double(parLWE.q_base)));
|
||||||
|
//cout << "Decryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
void external_product(vector<long>& res, const vector<int>& poly, const vector<FFTPoly>& poly_vector, int b, int shift, int l)
|
||||||
|
{
|
||||||
|
int N = Param::N;
|
||||||
|
int N2p1 = Param::N2p1;
|
||||||
|
|
||||||
|
ModQPoly poly_sign(N);
|
||||||
|
ModQPoly poly_abs(N);
|
||||||
|
vector<int> poly_decomp(N);
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
const int& polyi = poly[i];
|
||||||
|
poly_abs[i] = abs(polyi);
|
||||||
|
poly_sign[i] = (polyi < 0)? -1 : 1;
|
||||||
|
}
|
||||||
|
FFTPoly res_fft(N2p1);
|
||||||
|
FFTPoly tmp_fft(N2p1);
|
||||||
|
int mask = b-1;
|
||||||
|
int bound = b >> 1;
|
||||||
|
int digit, sgn;
|
||||||
|
for (int j = 0; j < l; ++j)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
int& abs_val = poly_abs[i];
|
||||||
|
digit = abs_val & mask; //poly_abs[i] % b;
|
||||||
|
if (digit > bound)
|
||||||
|
{
|
||||||
|
poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit);
|
||||||
|
abs_val >>= shift;
|
||||||
|
++abs_val; //(abs_val - digit)/b + 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit;
|
||||||
|
abs_val >>= shift; //(abs_val - digit)/b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fftN.to_fft(tmp_fft, poly_decomp);
|
||||||
|
tmp_fft *= poly_vector[j];
|
||||||
|
res_fft += tmp_fft;
|
||||||
|
}
|
||||||
|
fftN.from_fft(res, res_fft);
|
||||||
|
//mod_q_boot(poly);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeLWE::key_switch(Ctxt_LWE& ct, const ModQPoly& poly) const
|
||||||
|
{
|
||||||
|
int N = Param::N;
|
||||||
|
int B_ksk = Param::B_ksk;
|
||||||
|
int l_ksk = parLWE.l_ksk;
|
||||||
|
int Nl = parLWE.Nl;
|
||||||
|
int n = parLWE.n;
|
||||||
|
|
||||||
|
vector<int> poly_decomp(Nl);
|
||||||
|
ModQPoly poly_sign(N);
|
||||||
|
ModQPoly poly_abs(N);
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
const int& polyi = poly[i];
|
||||||
|
poly_abs[i] = abs(polyi);
|
||||||
|
poly_sign[i] = (polyi < 0)? -1 : 1;
|
||||||
|
}
|
||||||
|
int digit;
|
||||||
|
int il = 0;
|
||||||
|
int tmp;
|
||||||
|
int sgn;
|
||||||
|
int bound = B_ksk >> 1;
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
tmp = poly_abs[i];
|
||||||
|
sgn = poly_sign[i];
|
||||||
|
for (int j = 0; j < l_ksk; ++j)
|
||||||
|
{
|
||||||
|
digit = tmp % B_ksk;
|
||||||
|
if (digit > bound)
|
||||||
|
{
|
||||||
|
poly_decomp[il+j] = (sgn == 1) ? (digit - B_ksk): (B_ksk - digit);
|
||||||
|
tmp /= B_ksk;
|
||||||
|
++tmp;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
poly_decomp[il+j] = (sgn == 1) ? digit: - digit;
|
||||||
|
tmp /= B_ksk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
il += l_ksk;
|
||||||
|
}
|
||||||
|
vector<long> a(n);
|
||||||
|
for (int i = 0; i < Nl; ++i)
|
||||||
|
{
|
||||||
|
long tmp_int = long(poly_decomp[i]);
|
||||||
|
const vector<int>& ksk_row = ksk.A[i];
|
||||||
|
for (int j = 0; j < n; ++j)
|
||||||
|
{
|
||||||
|
a[j] += long(ksk_row[j]) * tmp_int;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parLWE.mod_q_base(ct.a, a);
|
||||||
|
long b = 0L;
|
||||||
|
const vector<int>& ksk_b = ksk.b;
|
||||||
|
for (int i = 0; i < Nl; ++i)
|
||||||
|
b += ksk_b[i] * long(poly_decomp[i]);
|
||||||
|
ct.b = parLWE.mod_q_base(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
// debugger functions
|
||||||
|
void print(const vector<int>& vec)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < vec.size(); i++)
|
||||||
|
{
|
||||||
|
printf("[%zu] %d ", i, vec[i]);
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk)
|
||||||
|
{
|
||||||
|
FFTPoly sk_fft;
|
||||||
|
fftN.to_fft(sk_fft, sk.sk);
|
||||||
|
FFTPoly ct_fft;
|
||||||
|
fftN.to_fft(ct_fft, ct);
|
||||||
|
|
||||||
|
FFTPoly output_fft;
|
||||||
|
output_fft = ct_fft * sk_fft;
|
||||||
|
vector<long> output;
|
||||||
|
vector<int> output_int;
|
||||||
|
fftN.from_fft(output, output_fft);
|
||||||
|
parLWE.mod_q_boot(output_int, output);
|
||||||
|
print(output_int);
|
||||||
|
}
|
||||||
|
|
||||||
|
void decrypt_poly_base_and_print(const ModQPoly& ct, const SKey_boot& sk)
|
||||||
|
{
|
||||||
|
FFTPoly sk_fft;
|
||||||
|
fftN.to_fft(sk_fft, sk.sk);
|
||||||
|
FFTPoly ct_fft;
|
||||||
|
fftN.to_fft(ct_fft, ct);
|
||||||
|
|
||||||
|
FFTPoly output_fft;
|
||||||
|
output_fft = ct_fft * sk_fft;
|
||||||
|
ModQPoly output;
|
||||||
|
fftN.from_fft(output, output_fft);
|
||||||
|
mod_q_base(output);
|
||||||
|
print(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void decryptN2(const Ctxt_LWE& ct, const SKey_base_LWE& sk)
|
||||||
|
{
|
||||||
|
int output = ct.b;
|
||||||
|
for (int i = 0; i < lwe_he::n; i++)
|
||||||
|
{
|
||||||
|
output += ct.a[i] * sk[i];
|
||||||
|
}
|
||||||
|
output = output%parLWE.N2;
|
||||||
|
if (output > parLWE.N)
|
||||||
|
output -= parLWE.N2;
|
||||||
|
if (output <= -parLWE.N)
|
||||||
|
output += parLWE.N2;
|
||||||
|
cout << output << endl;
|
||||||
|
}
|
||||||
|
// end debugger functions
|
||||||
|
*/
|
||||||
|
|
||||||
|
void SchemeLWE::bootstrap(Ctxt_LWE& ct) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
int N = Param::N;
|
||||||
|
int N2 = Param::N2;
|
||||||
|
int N2p1 = Param::N2p1;
|
||||||
|
int B_bsk_size = Param::B_bsk_size;
|
||||||
|
int half_delta_boot = Param::half_delta_boot;
|
||||||
|
|
||||||
|
// switch to modulus 2*N
|
||||||
|
modulo_switch_to_boot(ct);
|
||||||
|
// initialize accumulator and rotate accumulator by X^ct.b
|
||||||
|
vector<int> acc(N, half_delta_boot);
|
||||||
|
int b_pow = (N/2 + ct.b)%N2;
|
||||||
|
if (b_pow < 0)
|
||||||
|
b_pow += N2;
|
||||||
|
int b_sign = 1;
|
||||||
|
if (b_pow >= N)
|
||||||
|
{
|
||||||
|
b_pow -= N;
|
||||||
|
b_sign = -1;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < b_pow; ++i)
|
||||||
|
acc[i] = (b_sign == 1) ? -acc[i]: acc[i];
|
||||||
|
for (int i = b_pow; i < N; ++i)
|
||||||
|
acc[i] = (b_sign == 1) ? acc[i]: -acc[i];
|
||||||
|
|
||||||
|
//accumulator loop
|
||||||
|
int coef_counter = 0;
|
||||||
|
vector<int>& a = ct.a;
|
||||||
|
int coef, coef_sign, B, shift, l;
|
||||||
|
double Bd;
|
||||||
|
//auto start = clock();
|
||||||
|
//float cmux_time = 0.0;
|
||||||
|
//float extprod_time = 0.0;
|
||||||
|
vector<int> tmp_poly(N);
|
||||||
|
vector<long> tmp_poly_long(N);
|
||||||
|
|
||||||
|
const BSKey_LWE& boot_key = bk;
|
||||||
|
for (int iBase = 0; iBase < B_bsk_size; ++iBase)
|
||||||
|
{
|
||||||
|
B = parLWE.B_bsk[iBase];
|
||||||
|
Bd = double(B);
|
||||||
|
shift = parLWE.shift_bsk[iBase];
|
||||||
|
l = parLWE.l_bsk[iBase];
|
||||||
|
//vector<complex<double>> w_powers(l);
|
||||||
|
//w_powers[0] = complex<double>(1.0,0.0);
|
||||||
|
//for (int i = 1; i < l; i++)
|
||||||
|
// w_powers[i] = w_powers[i-1] * Bd;
|
||||||
|
const vector<NGSFFTctxt>& bk_coef_row = boot_key[iBase];
|
||||||
|
for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; ++iCoef)
|
||||||
|
{
|
||||||
|
//auto start = clock();
|
||||||
|
coef = a[iCoef+coef_counter];
|
||||||
|
if (coef == 0) continue;
|
||||||
|
coef_sign = 1;
|
||||||
|
if (coef < 0) coef += N2;
|
||||||
|
if (coef >= N)
|
||||||
|
{
|
||||||
|
coef -= N;
|
||||||
|
coef_sign = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// acc * (X^coef - 1)
|
||||||
|
if (coef_sign == 1)
|
||||||
|
{
|
||||||
|
for (int i = 0; i<coef; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(-acc[i-coef+N] - acc[i]);
|
||||||
|
for (int i = coef; i < N; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(acc[i-coef] - acc[i]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (int i = 0; i<coef; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(acc[i-coef+N] - acc[i]);
|
||||||
|
for (int i = coef; i < N; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(-acc[i-coef] - acc[i]);
|
||||||
|
}
|
||||||
|
//cmux_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
//start = clock();
|
||||||
|
// acc * (X^coef - 1) x bk[i]
|
||||||
|
external_product(tmp_poly_long, tmp_poly, bk_coef_row[iCoef], B, shift, l);
|
||||||
|
mod_q_boot(tmp_poly, tmp_poly_long);
|
||||||
|
// acc * (X^coef - 1) x bk[i] + acc
|
||||||
|
for (int i = 0; i<N; ++i)
|
||||||
|
acc[i] += tmp_poly[i];
|
||||||
|
//extprod_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
}
|
||||||
|
coef_counter += parLWE.bsk_partition[iBase];
|
||||||
|
}
|
||||||
|
//cout << "Cmux: " << cmux_time << endl;
|
||||||
|
//cout << "Ext. prod: " << extprod_time << endl;
|
||||||
|
|
||||||
|
// add floor(q_boot/(2*t)) to all coefficients of the accumulator
|
||||||
|
for (auto it = acc.begin(); it < acc.end(); ++it)
|
||||||
|
*it += half_delta_boot;
|
||||||
|
|
||||||
|
//mod q_boot of the accumulator
|
||||||
|
mod_q_boot(acc);
|
||||||
|
|
||||||
|
//decrypt_poly_boot_and_print(acc, sk_boot);
|
||||||
|
|
||||||
|
//mod switch to q_base
|
||||||
|
modulo_switch_to_base_lwe(acc);
|
||||||
|
|
||||||
|
//decrypt_poly_boot_and_print(acc, sk_boot);
|
||||||
|
|
||||||
|
//key switch
|
||||||
|
//auto start = clock();
|
||||||
|
key_switch(ct, acc);
|
||||||
|
//cout << "Key-switching: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
//cout << "Bootstrapping: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeLWE::bootstrap2(Ctxt_LWE& ct) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
int N = Param::N;
|
||||||
|
int N2 = Param::N2;
|
||||||
|
int N2p1 = Param::N2p1;
|
||||||
|
int B_bsk_size = Param::B_bsk_size;
|
||||||
|
int half_delta_boot = Param::half_delta_boot;
|
||||||
|
|
||||||
|
// switch to modulus 2*N
|
||||||
|
modulo_switch_to_boot(ct);
|
||||||
|
// initialize accumulator and rotate accumulator by X^ct.b
|
||||||
|
vector<int> acc(N, half_delta_boot);
|
||||||
|
int b_pow = (N/2 + ct.b)%N2;
|
||||||
|
if (b_pow < 0)
|
||||||
|
b_pow += N2;
|
||||||
|
int b_sign = 1;
|
||||||
|
if (b_pow >= N)
|
||||||
|
{
|
||||||
|
b_pow -= N;
|
||||||
|
b_sign = -1;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < b_pow; ++i)
|
||||||
|
acc[i] = (b_sign == 1) ? -acc[i]: acc[i];
|
||||||
|
for (int i = b_pow; i < N; ++i)
|
||||||
|
acc[i] = (b_sign == 1) ? acc[i]: -acc[i];
|
||||||
|
|
||||||
|
//accumulator loop
|
||||||
|
int coef_counter = 0;
|
||||||
|
vector<int>& a = ct.a;
|
||||||
|
int coef1, coef2, coef_sum, coef1_sign, coef2_sign, coef_sum_sign, B, shift, l;
|
||||||
|
double Bd;
|
||||||
|
//auto start = clock();
|
||||||
|
//float cmux_time = 0.0;
|
||||||
|
//float extprod_time = 0.0;
|
||||||
|
vector<int> tmp_poly(N);
|
||||||
|
vector<long> tmp_poly_long(N);
|
||||||
|
|
||||||
|
const BSKey_LWE& boot_key = bk;
|
||||||
|
for (int iBase = 0; iBase < B_bsk_size; ++iBase)
|
||||||
|
{
|
||||||
|
B = parLWE.B_bsk[iBase];
|
||||||
|
Bd = double(B);
|
||||||
|
shift = parLWE.shift_bsk[iBase];
|
||||||
|
l = parLWE.l_bsk[iBase];
|
||||||
|
//vector<complex<double>> w_powers(l);
|
||||||
|
//w_powers[0] = complex<double>(1.0,0.0);
|
||||||
|
//for (int i = 1; i < l; i++)
|
||||||
|
// w_powers[i] = w_powers[i-1] * Bd;
|
||||||
|
const vector<NGSFFTctxt>& bk_coef_row = boot_key[iBase];
|
||||||
|
vector<FFTPoly> mux_fft(l,FFTPoly(N2p1,complex<double>(0.0,0.0)));
|
||||||
|
for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; iCoef+=2)
|
||||||
|
{
|
||||||
|
//auto start = clock();
|
||||||
|
// normalize coef1
|
||||||
|
coef1 = a[iCoef+coef_counter];
|
||||||
|
coef1_sign = 1;
|
||||||
|
if (coef1 < 0) coef1 += N2;
|
||||||
|
if (coef1 >= N)
|
||||||
|
{
|
||||||
|
coef1 -= N;
|
||||||
|
coef1_sign = -1;
|
||||||
|
}
|
||||||
|
// normalize coef2
|
||||||
|
coef2 = a[iCoef+coef_counter+1];
|
||||||
|
coef2_sign = 1;
|
||||||
|
if (coef2 < 0) coef2 += N2;
|
||||||
|
if (coef2 >= N)
|
||||||
|
{
|
||||||
|
coef2 -= N;
|
||||||
|
coef2_sign = -1;
|
||||||
|
}
|
||||||
|
// normalize coef_sum
|
||||||
|
coef_sum = (a[iCoef+coef_counter] + a[iCoef+coef_counter+1]) % N2;
|
||||||
|
coef_sum_sign = 1;
|
||||||
|
if (coef_sum < 0) coef_sum += N2;
|
||||||
|
if (coef_sum >= N)
|
||||||
|
{
|
||||||
|
coef_sum -= N;
|
||||||
|
coef_sum_sign = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bk_0 * X^(c_0+c_1) + bk_1 * X^(c_0) + bk_2 * X^(c_1) + bk_3
|
||||||
|
const FFTPoly& x_sum = (coef_sum_sign == 1) ? fftN.pos_powers[coef_sum]: fftN.neg_powers[coef_sum];
|
||||||
|
const FFTPoly& x_c1 = (coef1_sign == 1) ? fftN.pos_powers[coef1]: fftN.neg_powers[coef1];
|
||||||
|
const FFTPoly& x_c2 = (coef2_sign == 1) ? fftN.pos_powers[coef2]: fftN.neg_powers[coef2];
|
||||||
|
const NGSFFTctxt& bk_part_row0 = bk_coef_row[4*(iCoef >> 1)];
|
||||||
|
const NGSFFTctxt& bk_part_row1 = bk_coef_row[4*(iCoef >> 1)+1];
|
||||||
|
const NGSFFTctxt& bk_part_row2 = bk_coef_row[4*(iCoef >> 1)+2];
|
||||||
|
const NGSFFTctxt& bk_part_row3 = bk_coef_row[4*(iCoef >> 1)+3];
|
||||||
|
for (int iPart = 0; iPart < l; ++iPart)
|
||||||
|
{
|
||||||
|
mux_fft[iPart] = bk_part_row0[iPart];
|
||||||
|
mux_fft[iPart] *= x_sum;
|
||||||
|
|
||||||
|
if (coef1 == coef2 && coef1_sign == coef2_sign)
|
||||||
|
{
|
||||||
|
FFTPoly tmp_fft = bk_part_row1[iPart];
|
||||||
|
tmp_fft += bk_part_row2[iPart];
|
||||||
|
tmp_fft *= x_c1;
|
||||||
|
mux_fft[iPart] += tmp_fft;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
FFTPoly tmp_fft = bk_part_row1[iPart];
|
||||||
|
tmp_fft *= x_c1;
|
||||||
|
mux_fft[iPart] += tmp_fft;
|
||||||
|
|
||||||
|
tmp_fft = bk_part_row2[iPart];
|
||||||
|
tmp_fft *= x_c2;
|
||||||
|
mux_fft[iPart] += tmp_fft;
|
||||||
|
}
|
||||||
|
|
||||||
|
mux_fft[iPart] += bk_part_row3[iPart];
|
||||||
|
}
|
||||||
|
//cmux_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
//start = clock();
|
||||||
|
external_product(tmp_poly_long, acc, mux_fft, B, shift, l);
|
||||||
|
mod_q_boot(acc, tmp_poly_long);
|
||||||
|
//extprod_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
}
|
||||||
|
coef_counter += parLWE.bsk_partition[iBase];
|
||||||
|
}
|
||||||
|
//cout << "Cmux: " << cmux_time << endl;
|
||||||
|
//cout << "Ext. prod: " << extprod_time << endl;
|
||||||
|
|
||||||
|
// add floor(q_boot/(2*t)) to all coefficients of the accumulator
|
||||||
|
for (auto it = acc.begin(); it < acc.end(); ++it)
|
||||||
|
*it += half_delta_boot;
|
||||||
|
|
||||||
|
//mod q_boot of the accumulator
|
||||||
|
mod_q_boot(acc);
|
||||||
|
|
||||||
|
//decrypt_poly_boot_and_print(acc, sk_boot);
|
||||||
|
|
||||||
|
//mod switch to q_base
|
||||||
|
modulo_switch_to_base_lwe(acc);
|
||||||
|
|
||||||
|
//decrypt_poly_boot_and_print(acc, sk_boot);
|
||||||
|
|
||||||
|
//key switch
|
||||||
|
//auto start = clock();
|
||||||
|
key_switch(ct, acc);
|
||||||
|
//cout << "Key-switching: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
//cout << "Bootstrapping: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeLWE::nand_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
ct_res = parLWE.nand_const - (ct1 + ct2);
|
||||||
|
bootstrap(ct_res);
|
||||||
|
//cout << "NAND: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeLWE::and_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
ct_res = parLWE.and_const - (ct1 + ct2);
|
||||||
|
bootstrap(ct_res);
|
||||||
|
//cout << "AND: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeLWE::or_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
ct_res = parLWE.or_const - (ct1 + ct2);
|
||||||
|
bootstrap(ct_res);
|
||||||
|
//cout << "OR: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
416
src/ntruhe.cpp
Normal file
416
src/ntruhe.cpp
Normal file
@ -0,0 +1,416 @@
|
|||||||
|
#include "ntruhe.h"
|
||||||
|
#include "sampler.h"
|
||||||
|
#include "fft.h"
|
||||||
|
#include "lwehe.h"
|
||||||
|
|
||||||
|
#include "time.h"
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
Ctxt_NTRU::Ctxt_NTRU(const Ctxt_NTRU& ct)
|
||||||
|
{
|
||||||
|
data = ct.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctxt_NTRU& Ctxt_NTRU::operator=(const Ctxt_NTRU& ct)
|
||||||
|
{
|
||||||
|
data = ct.data;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctxt_NTRU Ctxt_NTRU::operator +(const Ctxt_NTRU& ct) const
|
||||||
|
{
|
||||||
|
Ctxt_NTRU res;
|
||||||
|
for (size_t i = 0; i < parNTRU.n; ++i)
|
||||||
|
res.data[i] = parNTRU.mod_q_base(data[i] + ct.data[i]);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ctxt_NTRU Ctxt_NTRU::operator -(const Ctxt_NTRU& ct) const
|
||||||
|
{
|
||||||
|
Ctxt_NTRU res;
|
||||||
|
for (size_t i = 0; i < parNTRU.n; ++i)
|
||||||
|
res.data[i] = parNTRU.mod_q_base(data[i] - ct.data[i]);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ctxt_NTRU::operator -=(const Ctxt_NTRU& ct)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < parNTRU.n; ++i)
|
||||||
|
{
|
||||||
|
data[i] -= ct.data[i];
|
||||||
|
data[i] = parNTRU.mod_q_base(data[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeNTRU::encrypt(Ctxt_NTRU& ct, const int b) const
|
||||||
|
{
|
||||||
|
clock_t start = clock();
|
||||||
|
|
||||||
|
int n = parNTRU.n;
|
||||||
|
|
||||||
|
vector<int> g(n,0L);
|
||||||
|
Sampler::get_ternary_vector(g);
|
||||||
|
g[0] += b * parNTRU.delta_base;
|
||||||
|
ct.data = vector<int>(n,0);
|
||||||
|
vector<long> ct_long(n,0L);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
long g_coef = long(g[i]);
|
||||||
|
const vector<int>& sk_row = sk_base.sk_inv[i];
|
||||||
|
for (int j = 0; j < n; j++)
|
||||||
|
ct_long[j] += long(sk_row[j]) * g_coef;
|
||||||
|
}
|
||||||
|
parNTRU.mod_q_base(ct.data, ct_long);
|
||||||
|
|
||||||
|
//cout << "Encryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
int SchemeNTRU::decrypt(const Ctxt_NTRU& ct) const
|
||||||
|
{
|
||||||
|
clock_t start = clock();
|
||||||
|
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < parNTRU.n; i++)
|
||||||
|
{
|
||||||
|
output += ct.data[i] * sk_base.sk[i][0];
|
||||||
|
}
|
||||||
|
output = parNTRU.mod_q_base(output);
|
||||||
|
output = int(round(double(output)/double(parNTRU.q_base)*Param::t));
|
||||||
|
//cout << "Decryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
void SchemeNTRU::external_product(vector<long>& res, const vector<int>& poly, const vector<FFTPoly>& poly_vector, const int b, const int shift, const int l) const
|
||||||
|
{
|
||||||
|
int N = Param::N;
|
||||||
|
int N2p1 = Param::N2p1;
|
||||||
|
|
||||||
|
ModQPoly poly_sign(N,0L);
|
||||||
|
ModQPoly poly_abs(N,0L);
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
const int& polyi = poly[i];
|
||||||
|
poly_abs[i] = abs(polyi);
|
||||||
|
poly_sign[i] = (polyi < 0)? -1 : 1;
|
||||||
|
}
|
||||||
|
FFTPoly res_fft(N2p1);
|
||||||
|
FFTPoly tmp_fft(N2p1);
|
||||||
|
int mask = b-1;
|
||||||
|
int bound = b >> 1;
|
||||||
|
int digit, sgn, abs_val;
|
||||||
|
vector<int> poly_decomp(N);
|
||||||
|
for (int j = 0; j < l; j++)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
abs_val = poly_abs[i];
|
||||||
|
digit = abs_val & mask; //poly_abs[i] % b;
|
||||||
|
if (digit > bound)
|
||||||
|
{
|
||||||
|
poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit);
|
||||||
|
poly_abs[i] = (abs_val >> shift) + 1; //(abs_val - digit)/b + 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit;
|
||||||
|
poly_abs[i] = abs_val >> shift; //(abs_val - digit)/b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fftN.to_fft(tmp_fft, poly_decomp);
|
||||||
|
tmp_fft *= poly_vector[j];
|
||||||
|
res_fft += tmp_fft;
|
||||||
|
}
|
||||||
|
fftN.from_fft(res, res_fft);
|
||||||
|
//mod_q_boot(poly);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
void SchemeNTRU::key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const
|
||||||
|
{
|
||||||
|
int N = Param::N;
|
||||||
|
int B_ksk = Param::B_ksk;
|
||||||
|
int Nl = parNTRU.Nl;
|
||||||
|
int l_ksk = parNTRU.l_ksk;
|
||||||
|
int n = parNTRU.n;
|
||||||
|
|
||||||
|
vector<int> poly_decomp(Nl);
|
||||||
|
ModQPoly poly_sign(N);
|
||||||
|
ModQPoly poly_abs(N);
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
const int& polyi = poly[i];
|
||||||
|
poly_abs[i] = abs(polyi);
|
||||||
|
poly_sign[i] = (polyi < 0)? -1 : 1;
|
||||||
|
}
|
||||||
|
int bound = B_ksk >> 1;
|
||||||
|
int il = 0;
|
||||||
|
int digit, tmp, sgn;
|
||||||
|
for (int i = 0; i < N; ++i)
|
||||||
|
{
|
||||||
|
tmp = poly_abs[i];
|
||||||
|
sgn = poly_sign[i];
|
||||||
|
for (int j = 0; j < l_ksk; j++)
|
||||||
|
{
|
||||||
|
int digit = tmp % B_ksk;
|
||||||
|
if (digit > bound)
|
||||||
|
{
|
||||||
|
poly_decomp[il+j] = (sgn == 1) ? (digit - B_ksk): (B_ksk - digit);
|
||||||
|
tmp /= B_ksk;
|
||||||
|
++tmp;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
poly_decomp[il+j] = (sgn == 1) ? digit: - digit;
|
||||||
|
tmp /= B_ksk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
il += l_ksk;
|
||||||
|
}
|
||||||
|
vector<long> ct_long(n);
|
||||||
|
for (int i = 0; i < Nl; ++i)
|
||||||
|
{
|
||||||
|
long tmp_int = long(poly_decomp[i]);
|
||||||
|
const vector<int>& ksk_row = ksk[i];
|
||||||
|
for (int j = 0; j < n; ++j)
|
||||||
|
{
|
||||||
|
ct_long[j] += long(ksk_row[j]) * tmp_int;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parNTRU.mod_q_base(ct.data, ct_long);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
// debugger functions
|
||||||
|
void print(const vector<int>& vec)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < vec.size(); i++)
|
||||||
|
{
|
||||||
|
printf("[%zu] %d ", i, vec[i]);
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk, const Param& param)
|
||||||
|
{
|
||||||
|
FFTPoly sk_fft(Param::N2p1);
|
||||||
|
fftN.to_fft(sk_fft, sk.sk);
|
||||||
|
FFTPoly ct_fft(Param::N2p1);
|
||||||
|
fftN.to_fft(ct_fft, ct);
|
||||||
|
|
||||||
|
FFTPoly output_fft;
|
||||||
|
output_fft = ct_fft * sk_fft;
|
||||||
|
ModQPoly output;
|
||||||
|
vector<long> output_long;
|
||||||
|
fftN.from_fft(output_long, output_fft);
|
||||||
|
mod_q_boot(output, output_long);
|
||||||
|
print(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void decrypt_poly_base_and_print(const ModQPoly& ct, const Param& param, const SKey_boot& sk)
|
||||||
|
{
|
||||||
|
FFTPoly sk_fft(Param::N2p1);
|
||||||
|
fftN.to_fft(sk_fft, sk.sk);
|
||||||
|
FFTPoly ct_fft(Param::N2p1);
|
||||||
|
fftN.to_fft(ct_fft, ct);
|
||||||
|
|
||||||
|
FFTPoly output_fft;
|
||||||
|
output_fft = ct_fft * sk_fft;
|
||||||
|
ModQPoly output;
|
||||||
|
vector<long> output_long;
|
||||||
|
fftN.from_fft(output_long, output_fft);
|
||||||
|
param.mod_q_base(output, output_long);
|
||||||
|
print(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
void decryptN2(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk)
|
||||||
|
{
|
||||||
|
int N = Param::N;
|
||||||
|
int N2 = Param::N2;
|
||||||
|
int n = parNTRU.n;
|
||||||
|
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
output += ct.data[i] * sk.sk[i][0];
|
||||||
|
}
|
||||||
|
output = output%N2;
|
||||||
|
if (output > N)
|
||||||
|
output -= N2;
|
||||||
|
if (output <= -N)
|
||||||
|
output += N2;
|
||||||
|
cout << output << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void decrypt_base(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk)
|
||||||
|
{
|
||||||
|
int n = parNTRU.n;
|
||||||
|
int q_base = parNTRU.q_base;
|
||||||
|
int half_q_base= parNTRU.half_q_base;
|
||||||
|
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
output += ct.data[i] * sk.sk[i][0];
|
||||||
|
}
|
||||||
|
output = output%q_base;
|
||||||
|
if (output > half_q_base)
|
||||||
|
output -= q_base;
|
||||||
|
if (output <= -half_q_base)
|
||||||
|
output += q_base;
|
||||||
|
cout << output << endl;
|
||||||
|
}
|
||||||
|
// end debugger functions
|
||||||
|
*/
|
||||||
|
void SchemeNTRU::mask_constant(Ctxt_NTRU& ct, int constant)
|
||||||
|
{
|
||||||
|
int n = parNTRU.n;
|
||||||
|
|
||||||
|
vector<int> g(n);
|
||||||
|
Sampler::get_ternary_vector(g);
|
||||||
|
g[0] += constant;
|
||||||
|
ct.data = vector<int>(n);
|
||||||
|
vector<long> ct_long(n);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
long g_coef = long(g[i]);
|
||||||
|
const vector<int>& sk_row = sk_base.sk_inv[i];
|
||||||
|
for (int j = 0; j < n; j++)
|
||||||
|
ct_long[j] += sk_row[j] * g_coef;
|
||||||
|
}
|
||||||
|
parNTRU.mod_q_base(ct.data, ct_long);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeNTRU::bootstrap(Ctxt_NTRU& ct) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
int N = Param::N;
|
||||||
|
int N2 = Param::N2;
|
||||||
|
int N2p1 = Param::N2p1;
|
||||||
|
int B_bsk_size = Param::B_bsk_size;
|
||||||
|
int half_delta_boot = Param::half_delta_boot;
|
||||||
|
|
||||||
|
// switch to modulus 2*N
|
||||||
|
modulo_switch_to_boot(ct);
|
||||||
|
// initialize accumulator
|
||||||
|
ModQPoly acc(N, half_delta_boot);
|
||||||
|
//vector<long> acc_long(N);
|
||||||
|
for (size_t i = 0; i < N/2; i++)
|
||||||
|
acc[i] = -acc[i];
|
||||||
|
|
||||||
|
//accumulator loop
|
||||||
|
int coef_counter = 0;
|
||||||
|
vector<int>& data = ct.data;
|
||||||
|
int coef, neg_coef, coef_sign, neg_coef_sign, B, shift, l;
|
||||||
|
double Bd;
|
||||||
|
vector<int> tmp_poly(N);
|
||||||
|
vector<long> tmp_poly_long(N);
|
||||||
|
|
||||||
|
const BSKey_NTRU& boot_key = bk;
|
||||||
|
for (int iBase = 0; iBase < B_bsk_size; ++iBase)
|
||||||
|
{
|
||||||
|
B = parNTRU.B_bsk[iBase];
|
||||||
|
Bd = double(B);
|
||||||
|
shift = parNTRU.shift_bsk[iBase];
|
||||||
|
l = parNTRU.l_bsk[iBase];
|
||||||
|
//vector<complex<double>> w_power_fft(l);
|
||||||
|
//w_power_fft[0] = complex<double>(1.0,0.0);
|
||||||
|
//for (int i = 1; i < l; i++)
|
||||||
|
// w_power_fft[i] = w_power_fft[i-1] * Bd;
|
||||||
|
const vector<vector<NGSFFTctxt>>& bk_coef_row = boot_key[iBase];
|
||||||
|
vector<FFTPoly> mux_fft(l, FFTPoly(N2p1));
|
||||||
|
for (int iCoef = 0; iCoef < parNTRU.bsk_partition[iBase]; ++iCoef)
|
||||||
|
{
|
||||||
|
coef = data[iCoef+coef_counter];
|
||||||
|
if (coef == 0) continue;
|
||||||
|
coef_sign = 1;
|
||||||
|
if (coef < 0) coef += N2;
|
||||||
|
if (coef >= N)
|
||||||
|
{
|
||||||
|
coef -= N;
|
||||||
|
coef_sign = -1;
|
||||||
|
}
|
||||||
|
neg_coef = N - coef;
|
||||||
|
neg_coef_sign = -coef_sign;
|
||||||
|
if(neg_coef == N)
|
||||||
|
{
|
||||||
|
neg_coef = 0;
|
||||||
|
neg_coef_sign = -neg_coef_sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
// acc * (X^coef - 1)
|
||||||
|
if (coef_sign == 1)
|
||||||
|
{
|
||||||
|
for (int i = 0; i<coef; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(-acc[i-coef+N] - acc[i]);
|
||||||
|
for (int i = coef; i < N; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(acc[i-coef] - acc[i]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (int i = 0; i<coef; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(acc[i-coef+N] - acc[i]);
|
||||||
|
for (int i = coef; i < N; ++i)
|
||||||
|
tmp_poly[i] = mod_q_boot(-acc[i-coef] - acc[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// bk_0[i] - bk_1[i]*X^-coef
|
||||||
|
const FFTPoly& x_power_neg = (neg_coef_sign == 1) ? fftN.pos_powers[neg_coef]: fftN.neg_powers[neg_coef];
|
||||||
|
const NGSFFTctxt& bk_part_row0 = bk_coef_row[iCoef][0];
|
||||||
|
const NGSFFTctxt& bk_part_row1 = bk_coef_row[iCoef][1];
|
||||||
|
for (int iPart = 0; iPart < l; ++iPart)
|
||||||
|
{
|
||||||
|
FFTPoly tmp_fft = bk_part_row1[iPart];
|
||||||
|
tmp_fft *= x_power_neg;
|
||||||
|
mux_fft[iPart] = bk_part_row0[iPart];
|
||||||
|
mux_fft[iPart] -= tmp_fft;
|
||||||
|
}
|
||||||
|
// acc * (X^coef - 1) x (bk_0[i] - bk_1[i]*X^-coef)
|
||||||
|
external_product(tmp_poly_long, tmp_poly, mux_fft, B, shift, l);
|
||||||
|
mod_q_boot(tmp_poly, tmp_poly_long);
|
||||||
|
// acc * (X^coef - 1) x (bk_0[i] - bk_1[i]*X^-coef) + acc
|
||||||
|
for (int i = 0; i<N; ++i)
|
||||||
|
acc[i] += tmp_poly[i];
|
||||||
|
}
|
||||||
|
coef_counter += parNTRU.bsk_partition[iBase];
|
||||||
|
}
|
||||||
|
|
||||||
|
// add floor(q_boot/(2*t)) to all coefficients of the accumulator
|
||||||
|
for (auto it = acc.begin(); it < acc.end(); ++it)
|
||||||
|
*it += half_delta_boot;
|
||||||
|
|
||||||
|
//mod q_boot of the accumulator
|
||||||
|
mod_q_boot(acc);
|
||||||
|
|
||||||
|
//mod switch to q_base
|
||||||
|
modulo_switch_to_base_ntru(acc);
|
||||||
|
|
||||||
|
//key switch
|
||||||
|
key_switch(ct, acc);
|
||||||
|
|
||||||
|
//cout << "Bootstrapping: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeNTRU::nand_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
ct_res = ct_nand_const - ct1 - ct2;
|
||||||
|
bootstrap(ct_res);
|
||||||
|
//cout << "NAND: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeNTRU::and_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
ct_res = ct_and_const - ct1 - ct2;
|
||||||
|
bootstrap(ct_res);
|
||||||
|
//cout << "AND: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SchemeNTRU::or_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const
|
||||||
|
{
|
||||||
|
//clock_t start = clock();
|
||||||
|
ct_res = ct_or_const - ct1 - ct2;
|
||||||
|
bootstrap(ct_res);
|
||||||
|
//cout << "OR: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
}
|
185
src/sampler.cpp
Normal file
185
src/sampler.cpp
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
#include "sampler.h"
|
||||||
|
#include "params.h"
|
||||||
|
#include <NTL/ZZ_pX.h>
|
||||||
|
#include <NTL/mat_ZZ_p.h>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
void Sampler::get_ternary_vector(vector<int>& vec)
|
||||||
|
{
|
||||||
|
|
||||||
|
for(int i=0; i<vec.size(); i++)
|
||||||
|
vec[i] = ternary_sampler(rand_engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_ternary_matrix(vector<vector<int>>& mat)
|
||||||
|
{
|
||||||
|
|
||||||
|
for(int i=0; i<mat.size(); i++)
|
||||||
|
{
|
||||||
|
vector<int>& row = mat[i];
|
||||||
|
get_ternary_vector(row);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_binary_vector(vector<int>& vec)
|
||||||
|
{
|
||||||
|
for(int i=0; i<vec.size(); i++)
|
||||||
|
vec[i] = binary_sampler(rand_engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_uniform_vector(vector<int>& vec)
|
||||||
|
{
|
||||||
|
for(int i=0; i<vec.size(); i++)
|
||||||
|
vec[i] = mod_q_base_sampler(rand_engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_uniform_matrix(vector<vector<int>>& mat)
|
||||||
|
{
|
||||||
|
for(int i=0; i<mat.size(); i++)
|
||||||
|
{
|
||||||
|
vector<int>& row = mat[i];
|
||||||
|
get_uniform_vector(row);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_gaussian_vector(vector<int>& vec, double st_dev)
|
||||||
|
{
|
||||||
|
normal_distribution<double> gaussian_sampler(0.0, st_dev);
|
||||||
|
for(size_t i=0; i<vec.size(); i++)
|
||||||
|
vec[i] = static_cast<int>(round(gaussian_sampler(rand_engine)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_gaussian_matrix(vector<vector<int>>& mat, double st_dev)
|
||||||
|
{
|
||||||
|
for(size_t i=0; i<mat.size(); i++)
|
||||||
|
{
|
||||||
|
vector<int>& row = mat[i];
|
||||||
|
get_gaussian_vector(row, st_dev);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_invertible_vector(vector<int>& vec, vector<int>& vec_inv, int scale, int shift)
|
||||||
|
{
|
||||||
|
//polynomial with the coefficient vector vec (will be generated later)
|
||||||
|
ZZ_pX poly;
|
||||||
|
//element of Z_(q_boot)
|
||||||
|
ZZ_p coef;
|
||||||
|
coef.init(ZZ(q_boot));
|
||||||
|
//the inverse of poly modulo poly_mod (will be generated later)
|
||||||
|
ZZ_pX inv_poly;
|
||||||
|
//random sampling
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
//create the polynomial with the coefficient vector of the desired form
|
||||||
|
SetCoeff(poly, 0, ternary_sampler(rand_engine)*scale + shift);
|
||||||
|
for (size_t i = 1; i < vec.size(); i++)
|
||||||
|
{
|
||||||
|
coef = ternary_sampler(rand_engine)*scale;
|
||||||
|
SetCoeff(poly, i, coef);
|
||||||
|
}
|
||||||
|
//test invertibility
|
||||||
|
try
|
||||||
|
{
|
||||||
|
InvMod(inv_poly, poly, Param::get_def_poly());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
cout << "Polynomial " << poly << " isn't a unit" << endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//cout << "Poly: " << poly << endl;
|
||||||
|
//cout << "Poly inverse: " << inv_poly << endl;
|
||||||
|
//extract the coefficient vector of poly
|
||||||
|
int tmp_coef;
|
||||||
|
for (int i = 0; i <= deg(poly); i++)
|
||||||
|
{
|
||||||
|
tmp_coef = conv<long>(poly[i]);
|
||||||
|
if (tmp_coef > half_q_boot)
|
||||||
|
tmp_coef -= q_boot;
|
||||||
|
vec[i] = tmp_coef;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i <= deg(inv_poly); i++)
|
||||||
|
{
|
||||||
|
tmp_coef = conv<long>(inv_poly[i]);
|
||||||
|
if (tmp_coef > half_q_boot)
|
||||||
|
tmp_coef -= q_boot;
|
||||||
|
vec_inv[i] = tmp_coef;
|
||||||
|
}
|
||||||
|
|
||||||
|
//cout << "Vector:" << vec << endl;
|
||||||
|
//cout << "Inverse vector:" << vec_inv << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Sampler::get_invertible_matrix(vector<vector<int>>& mat, vector<vector<int>>& mat_inv, int scale, int shift)
|
||||||
|
{
|
||||||
|
//check that the input matrices are squares
|
||||||
|
assert(mat[0].size() == mat.size());
|
||||||
|
assert(mat_inv[0].size() == mat_inv.size());
|
||||||
|
//check that both input matrices have the same dimension
|
||||||
|
assert(mat.size() == mat_inv.size());
|
||||||
|
|
||||||
|
//number of rows of the input matrix
|
||||||
|
int dim = mat.size();
|
||||||
|
|
||||||
|
//element of Z_(q_boot)
|
||||||
|
ZZ_p coef;
|
||||||
|
coef.init(ZZ(param.q_base));
|
||||||
|
|
||||||
|
//candidate matrix
|
||||||
|
mat_ZZ_p tmp_mat(INIT_SIZE, dim, dim);
|
||||||
|
|
||||||
|
//candidate inverse matrix
|
||||||
|
mat_ZZ_p tmp_mat_inv(INIT_SIZE, dim, dim);
|
||||||
|
|
||||||
|
//sampling and testing
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
//sampling
|
||||||
|
for (int i = 0; i < dim; i++)
|
||||||
|
{
|
||||||
|
Vec<ZZ_p>& row = tmp_mat[i];
|
||||||
|
for (int j = 0; j < dim; j++)
|
||||||
|
{
|
||||||
|
coef = ternary_sampler(rand_engine)*scale;
|
||||||
|
if (i==j)
|
||||||
|
coef += ZZ_p(shift);
|
||||||
|
row[j] = coef;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//test invertibility
|
||||||
|
try
|
||||||
|
{
|
||||||
|
inv(tmp_mat_inv, tmp_mat);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
cout << "Matrix " << tmp_mat << " is singular" << endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//lift mod q representation to integers
|
||||||
|
int tmp_coef;
|
||||||
|
for (int i = 0; i < dim; i++)
|
||||||
|
{
|
||||||
|
Vec<ZZ_p>& tmp_row = tmp_mat[i];
|
||||||
|
Vec<ZZ_p>& tmp_row_inv = tmp_mat_inv[i];
|
||||||
|
vector<int>& row = mat[i];
|
||||||
|
vector<int>& row_inv = mat_inv[i];
|
||||||
|
for (int j = 0; j < dim; j++)
|
||||||
|
{
|
||||||
|
tmp_coef = conv<long>(tmp_row[j]);
|
||||||
|
if (tmp_coef > param.half_q_base)
|
||||||
|
tmp_coef -= param.q_base;
|
||||||
|
row[j] = tmp_coef;
|
||||||
|
|
||||||
|
tmp_coef = conv<long>(tmp_row_inv[j]);
|
||||||
|
if (tmp_coef > param.half_q_base)
|
||||||
|
tmp_coef -= param.q_base;
|
||||||
|
row_inv[j] = tmp_coef;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
960
test.cpp
Normal file
960
test.cpp
Normal file
@ -0,0 +1,960 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <cassert>
|
||||||
|
#include "params.h"
|
||||||
|
#include "sampler.h"
|
||||||
|
#include "keygen.h"
|
||||||
|
#include "fft.h"
|
||||||
|
#include "ntruhe.h"
|
||||||
|
#include "lwehe.h"
|
||||||
|
|
||||||
|
#include <time.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <chrono>
|
||||||
|
#include <limits.h>
|
||||||
|
|
||||||
|
#include <NTL/ZZX.h>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace NTL;
|
||||||
|
|
||||||
|
void test_params()
|
||||||
|
{
|
||||||
|
{
|
||||||
|
Param param(LWE);
|
||||||
|
cout << "Ciphertext modulus of the base scheme (LWE): " << param.q_base << endl;
|
||||||
|
cout << "Dimension of the base scheme (LWE): " << param.n << endl;
|
||||||
|
cout << "Ciphertext modulus for bootstrapping (LWE): " << q_boot << endl;
|
||||||
|
cout << "Polynomial modulus (LWE): " << Param::get_def_poly() << endl;
|
||||||
|
assert(param.l_ksk == int(ceil(log(double(param.q_base))/log(double(Param::B_ksk)))));
|
||||||
|
cout << "Decomposition length for key-switching (LWE): " << param.l_ksk << endl;
|
||||||
|
cout << "Decomposition bases for key-switching (LWE): " << Param::B_ksk << endl;
|
||||||
|
cout << "Dimension for bootstrapping (LWE): " << Param::N << endl;
|
||||||
|
cout << "Decomposition bases for bootstrapping (LWE): ";
|
||||||
|
for (const auto &v: param.B_bsk) cout << v << ' ';
|
||||||
|
cout << endl;
|
||||||
|
cout << "Delta (LWE): " << param.delta_base << endl;
|
||||||
|
cout << "Half Delta (LWE): " << param.half_delta_base << endl;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Param param(NTRU);
|
||||||
|
cout << "Ciphertext modulus of the base scheme (MNTRU): " << param.q_base << endl;
|
||||||
|
cout << "Dimension of the base scheme (NTRU): " << param.n << endl;
|
||||||
|
cout << "Ciphertext modulus for bootstrapping (NTRU): " << q_boot << endl;
|
||||||
|
cout << "Polynomial modulus (NTRU): " << Param::get_def_poly() << endl;
|
||||||
|
assert(param.l_ksk == int(ceil(log(double(param.q_base))/log(double(Param::B_ksk)))));
|
||||||
|
cout << "Decomposition length for key-switching (MNTRU): " << param.l_ksk << endl;
|
||||||
|
cout << "Decomposition bases for key-switching (MNTRU): " << Param::B_ksk << endl;
|
||||||
|
cout << "Dimension for bootstrapping (MNTRU): " << Param::N << endl;
|
||||||
|
cout << "Decomposition bases for bootstrapping (MNTRU): ";
|
||||||
|
for (const auto &v: param.B_bsk) cout << v << ' ';
|
||||||
|
cout << endl;
|
||||||
|
cout << "Decomposition lengths for bootstrapping (MNTRU): ";
|
||||||
|
for (int i = 0; i < Param::B_bsk_size; i++)
|
||||||
|
{
|
||||||
|
assert(param.l_bsk[i] == int(ceil(log(double(q_boot))/log(double(param.B_bsk[i])))));
|
||||||
|
cout << param.l_bsk[i] << ' ';
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
cout << "Decomposition lengths for bootstrapping (MNTRU): ";
|
||||||
|
for (int i = 0; i < Param::B_bsk_size; i++)
|
||||||
|
{
|
||||||
|
assert(param.l_bsk[i] == int(ceil(log(double(q_boot))/log(double(param.B_bsk[i])))));
|
||||||
|
cout << param.l_bsk[i] << ' ';
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
cout << "Delta (MNTRU): " << param.delta_base << endl;
|
||||||
|
cout << "Half Delta (MNTRU): " << param.half_delta_base << endl;
|
||||||
|
|
||||||
|
{
|
||||||
|
assert(0L == mod_q_boot(0L));
|
||||||
|
assert(1L == mod_q_boot(1L));
|
||||||
|
assert(0L == mod_q_boot(q_boot));
|
||||||
|
assert(half_q_boot == mod_q_boot(half_q_boot));
|
||||||
|
assert(-half_q_boot == mod_q_boot(-half_q_boot));
|
||||||
|
cout << "MODULO REDUCTION IS OK" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "Plaintext modulus: " << Param::t << endl;
|
||||||
|
cout << endl;
|
||||||
|
cout << "PARAMS ARE OK" << endl;
|
||||||
|
|
||||||
|
{
|
||||||
|
vector<int> res;
|
||||||
|
decompose(res, 0, 2, 3);
|
||||||
|
assert(res.size() == 3);
|
||||||
|
for (auto iter=res.begin(); iter < res.end(); iter++)
|
||||||
|
assert(0L == *iter);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
vector<int> res;
|
||||||
|
decompose(res, 1, 2, 3);
|
||||||
|
assert(res.size() == 3);
|
||||||
|
assert(res[0] == 1);
|
||||||
|
for (auto iter=res.begin()+1; iter < res.end(); iter++)
|
||||||
|
assert(0L == *iter);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
vector<int> res;
|
||||||
|
decompose(res, 2, 3, 3);
|
||||||
|
assert(res.size() == 3);
|
||||||
|
assert(res[0] == -1 && res[1] == 1 && res[2] == 0);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
vector<int> res;
|
||||||
|
decompose(res, 2, 4, 3);
|
||||||
|
assert(res.size() == 3);
|
||||||
|
assert(res[0] == 2 && res[1] == 0 && res[2] == 0);
|
||||||
|
decompose(res, 3, 4, 3);
|
||||||
|
assert(res.size() == 3);
|
||||||
|
assert(res[0] == -1 && res[1] == 1 && res[2] == 0);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
vector<int> res;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
decompose(res, 14, 3, 3);
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
catch (overflow_error)
|
||||||
|
{
|
||||||
|
assert(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
vector<int> res;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
decompose(res, -14, 3, 3);
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
catch (overflow_error)
|
||||||
|
{
|
||||||
|
assert(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
vector<int> res;
|
||||||
|
decompose(res, 13, 3, 3);
|
||||||
|
assert(res.size() == 3);
|
||||||
|
assert(res[0] == 1 && res[1] == 1 && res[2] == 1);
|
||||||
|
decompose(res, -13, 3, 3);
|
||||||
|
assert(res.size() == 3);
|
||||||
|
assert(res[0] == -1 && res[1] == -1 && res[2] == -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
cout << "DECOMPOSITION IS OK" << endl;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_sampler()
|
||||||
|
{
|
||||||
|
int N = Param::N;
|
||||||
|
|
||||||
|
Param pLWE(LWE);
|
||||||
|
Param pNTRU(NTRU);
|
||||||
|
for (int run = 0; run < 1; run++)
|
||||||
|
{
|
||||||
|
//cout << "Run: " << run+1 << endl;
|
||||||
|
{
|
||||||
|
vector<int> vec(pNTRU.n, 0L);
|
||||||
|
Sampler::get_ternary_vector(vec);
|
||||||
|
|
||||||
|
assert(vec.size() == pNTRU.n);
|
||||||
|
for (int i = 0; i < pNTRU.n; i++)
|
||||||
|
{
|
||||||
|
assert((vec[i]==0) || (vec[i]==-1) || (vec[i]==1) );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
vector<int> vec(N,0L);
|
||||||
|
Sampler::get_ternary_vector(vec);
|
||||||
|
|
||||||
|
assert(vec.size() == N);
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
assert((vec[i]==0) || (vec[i]==-1) || (vec[i]==1) );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
vector<int> vec(N,0L);
|
||||||
|
Sampler::get_binary_vector(vec);
|
||||||
|
|
||||||
|
assert(vec.size() == N);
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
assert((vec[i]==0) || (vec[i]==1) );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
int n = pLWE.n;
|
||||||
|
vector<vector<int>> mat(n, vector<int>(N,0L));
|
||||||
|
Sampler::get_ternary_matrix(mat);
|
||||||
|
|
||||||
|
assert(mat.size() == n && mat[0].size() == N);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
vector<int>& row = mat[i];
|
||||||
|
for (int j = 0; j < N; j++)
|
||||||
|
assert((row[j]==0) || (row[j]==-1) || (row[j]==1) );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
int n = pLWE.n;
|
||||||
|
vector<int> vec(n, 0L);
|
||||||
|
double st_dev = 4.0;
|
||||||
|
Sampler::get_gaussian_vector(vec, st_dev);
|
||||||
|
|
||||||
|
assert(vec.size() == n);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
assert(conv<double>(abs(vec[i])) < 6*st_dev);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
int n = pNTRU.n;
|
||||||
|
vector<vector<int>> mat(n, vector<int>(N,0L));
|
||||||
|
double st_dev = 4.0;
|
||||||
|
Sampler::get_gaussian_matrix(mat, st_dev);
|
||||||
|
|
||||||
|
assert(mat.size() == n && mat[0].size() == N);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
vector<int>& row = mat[i];
|
||||||
|
for (int j = 0; j < N; j++)
|
||||||
|
assert(conv<double>(abs(row[j])) < 6*st_dev);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
vector<int> vec_inv(N,0L);
|
||||||
|
vector<int> vec(N,0L);
|
||||||
|
Sampler s(pNTRU);
|
||||||
|
s.get_invertible_vector(vec, vec_inv, 4, 1);
|
||||||
|
|
||||||
|
assert(vec.size() == N && vec_inv.size() == N);
|
||||||
|
assert((vec[0]==1) || (vec[0]==-3) || (vec[0]==5) );
|
||||||
|
for (int i = 1; i < N; i++)
|
||||||
|
{
|
||||||
|
assert((vec[i]==0) || (vec[i]==-4) || (vec[i]==4));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
int n = pLWE.n;
|
||||||
|
vector<vector<int>> mat_inv(n, vector<int>(n,0L));
|
||||||
|
vector<vector<int>> mat(n, vector<int>(n,0L));
|
||||||
|
Sampler s(pLWE);
|
||||||
|
s.get_invertible_matrix(mat, mat_inv, 5, 1);
|
||||||
|
|
||||||
|
assert(mat.size() == n && mat[0].size() == n
|
||||||
|
&& mat_inv.size() == n && mat_inv[0].size() == n);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
assert((mat[i][i]==1) || (mat[i][i]==-4) || (mat[i][i]==6) );
|
||||||
|
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
for (int j = 0; (j < n) && (j != i); j++)
|
||||||
|
{
|
||||||
|
assert((mat[i][j]==0) || (mat[i][j]==-5) || (mat[i][j]==5) );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "SAMPLER IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ntru_key_gen()
|
||||||
|
{
|
||||||
|
Param param(NTRU);
|
||||||
|
int n = param.n;
|
||||||
|
int Nl = param.Nl;
|
||||||
|
int half_q_base = param.half_q_base;
|
||||||
|
int q_base = param.q_base;
|
||||||
|
int l_ksk = param.l_ksk;
|
||||||
|
int N = Param::N;
|
||||||
|
int t = Param::t;
|
||||||
|
int B_ksk = Param::B_ksk;
|
||||||
|
int B_bsk_size = Param::B_bsk_size;
|
||||||
|
int N2p1 = Param::N2p1;
|
||||||
|
|
||||||
|
SKey_base_NTRU sk_base;
|
||||||
|
KeyGen k(param);
|
||||||
|
k.get_sk_base(sk_base);
|
||||||
|
cout << "Secret key of the base scheme is generated" << endl;
|
||||||
|
assert(sk_base.sk.size() == n && sk_base.sk[0].size() == n
|
||||||
|
&& sk_base.sk_inv.size() == n && sk_base.sk_inv[0].size() == n);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
for (int j = 0; j < n; j++)
|
||||||
|
{
|
||||||
|
assert((sk_base.sk[i][j]==0) || (sk_base.sk[i][j]==-1) || (sk_base.sk[i][j]==1) );
|
||||||
|
}
|
||||||
|
|
||||||
|
SKey_boot sk_boot;
|
||||||
|
k.get_sk_boot(sk_boot);
|
||||||
|
cout << "Secret key of the bootstrapping scheme is generated" << endl;
|
||||||
|
assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N);
|
||||||
|
assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1)));
|
||||||
|
for (int i = 1; i < N; i++)
|
||||||
|
{
|
||||||
|
assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) );
|
||||||
|
}
|
||||||
|
|
||||||
|
KSKey_NTRU ksk;
|
||||||
|
k.get_ksk(ksk, sk_base, sk_boot);
|
||||||
|
cout << "Key-switching key is generated" << endl;
|
||||||
|
assert(ksk.size() == Nl && ksk[0].size() == n);
|
||||||
|
for (int i = 0; i < Nl; i++)
|
||||||
|
for (int j = 0; j < n; j++)
|
||||||
|
{
|
||||||
|
//cout << ksk[i][j] << endl;
|
||||||
|
assert(ksk[i][j] <= half_q_base && ksk[i][j] >= -half_q_base);
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> q4_decomp;
|
||||||
|
decompose(q4_decomp, q_base/4, B_ksk, l_ksk);
|
||||||
|
vector<int> ks_res(n,0L);
|
||||||
|
for (int i = 0; i < l_ksk; i++)
|
||||||
|
{
|
||||||
|
int tmp_int = q4_decomp[i];
|
||||||
|
vector<int>& ksk_row = ksk[i];
|
||||||
|
for (int j = 0; j < n; j++)
|
||||||
|
{
|
||||||
|
ks_res[j] += ksk_row[j] * tmp_int;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
param.mod_q_base(ks_res);
|
||||||
|
int ks_int = 0;
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
ks_int += ks_res[i] * sk_base.sk[i][0];
|
||||||
|
}
|
||||||
|
ks_int = param.mod_q_base(ks_int);
|
||||||
|
ks_int = int(round(double(ks_int*4)/double(q_base)));
|
||||||
|
assert(ks_int == 1L);
|
||||||
|
|
||||||
|
// bootstrapping key test
|
||||||
|
BSKey_NTRU bsk;
|
||||||
|
k.get_bsk(bsk, sk_base, sk_boot);
|
||||||
|
cout << "Bootstrapping key is generated" << endl;
|
||||||
|
|
||||||
|
// check dimensions
|
||||||
|
assert(bsk.size() == B_bsk_size);
|
||||||
|
for (int i = 0; i < bsk.size(); i++)
|
||||||
|
{
|
||||||
|
assert(bsk[i].size() == param.bsk_partition[i]);
|
||||||
|
for (int j = 0; j < bsk[i].size(); j++)
|
||||||
|
{
|
||||||
|
assert(bsk[i][j].size() == 2);
|
||||||
|
assert(bsk[i][j][0].size() == param.l_bsk[i]);
|
||||||
|
assert(bsk[i][j][1].size() == param.l_bsk[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// convert sk_boot to FFT
|
||||||
|
vector<complex<double>> sk_boot_fft(N2p1);
|
||||||
|
fftN.to_fft(sk_boot_fft, sk_boot.sk);
|
||||||
|
|
||||||
|
int coef_counter = 0;
|
||||||
|
for (int iBase = 0; iBase < B_bsk_size; iBase++)
|
||||||
|
{
|
||||||
|
decompose(q4_decomp, q_boot/4, param.B_bsk[iBase], param.l_bsk[iBase]);
|
||||||
|
for (size_t iCoef = 0; iCoef < bsk[iBase].size(); iCoef++)
|
||||||
|
{
|
||||||
|
int sk_coef = 0;
|
||||||
|
int sk_base_coef_bits[2];
|
||||||
|
for (int iBit = 0; iBit < 2; iBit++)
|
||||||
|
{
|
||||||
|
vector<complex<double>> tmp_fft(N2p1, complex<double>(0.0,0.0));
|
||||||
|
for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++)
|
||||||
|
{
|
||||||
|
tmp_fft = tmp_fft + bsk[iBase][iCoef][iBit][iPart] * q4_decomp[iPart];
|
||||||
|
}
|
||||||
|
tmp_fft = tmp_fft * sk_boot_fft;
|
||||||
|
vector<int> tmp_int;
|
||||||
|
vector<long> tmp_long;
|
||||||
|
fftN.from_fft(tmp_long, tmp_fft);
|
||||||
|
mod_q_boot(tmp_int, tmp_long);
|
||||||
|
sk_base_coef_bits[iBit] = int(round(double(tmp_int[0]*4)/double(q_boot)));
|
||||||
|
}
|
||||||
|
if (sk_base_coef_bits[1] == 1)
|
||||||
|
sk_coef = -1;
|
||||||
|
else if (sk_base_coef_bits[0] == 1)
|
||||||
|
sk_coef = 1;
|
||||||
|
|
||||||
|
assert(sk_coef == sk_base.sk[coef_counter + iCoef][0]);
|
||||||
|
}
|
||||||
|
coef_counter += param.bsk_partition[iBase];
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "KEYGEN IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lwe_key_gen()
|
||||||
|
{
|
||||||
|
Param param(LWE);
|
||||||
|
int n = param.n;
|
||||||
|
int N = Param::N;
|
||||||
|
int t = Param::t;
|
||||||
|
|
||||||
|
SKey_base_LWE sk_base;
|
||||||
|
KeyGen k(param);
|
||||||
|
k.get_sk_base(sk_base);
|
||||||
|
cout << "Secret key of the base scheme is generated" << endl;
|
||||||
|
assert(sk_base.size() == n);
|
||||||
|
for (int j = 0; j < n; j++)
|
||||||
|
{
|
||||||
|
assert((sk_base[j]==0) || (sk_base[j]==1));
|
||||||
|
}
|
||||||
|
|
||||||
|
SKey_boot sk_boot;
|
||||||
|
k.get_sk_boot(sk_boot);
|
||||||
|
cout << "Secret key of the bootstrapping scheme is generated" << endl;
|
||||||
|
assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N);
|
||||||
|
assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1)));
|
||||||
|
for (int i = 1; i < N; i++)
|
||||||
|
{
|
||||||
|
assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) );
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "KEYGEN IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_fft()
|
||||||
|
{
|
||||||
|
int N = Param::N;
|
||||||
|
int N2p1 = Param::N2p1;
|
||||||
|
|
||||||
|
FFT_engine fft_engine(N);
|
||||||
|
{
|
||||||
|
vector<int> in(N,0L);
|
||||||
|
vector<complex<double>> out(N2p1);
|
||||||
|
clock_t start = clock();
|
||||||
|
fft_engine.to_fft(out, in);
|
||||||
|
cout << "Forward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
for (size_t i = 0; i < N/2; i++)
|
||||||
|
{
|
||||||
|
if (int(round(real(out[i])))!=0 || int(round(imag(out[i])))!=0)
|
||||||
|
{
|
||||||
|
cout << i << " " << out[i] << endl;
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
vector<long> out;
|
||||||
|
vector<complex<double>> in(N2p1, complex<double>(0.0,0.0));
|
||||||
|
clock_t start = clock();
|
||||||
|
fft_engine.from_fft(out, in);
|
||||||
|
cout << "Backward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
for (size_t i = 0; i < N/2; i++)
|
||||||
|
{
|
||||||
|
assert(out[i] == 0L);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
vector<int> in(N,0L);
|
||||||
|
in[0] = 1L;
|
||||||
|
vector<complex<double>> out(N2p1);
|
||||||
|
clock_t start = clock();
|
||||||
|
fft_engine.to_fft(out, in);
|
||||||
|
cout << "Forward FFT (1,0,...0): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
for (size_t i = 0; i < N/2; i++)
|
||||||
|
{
|
||||||
|
assert(int(round(real(out[i])))==1 && int(round(imag(out[i])))==0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
vector<long> out;
|
||||||
|
vector<complex<double>> in(N2p1, complex<double>(1.0,0.0));
|
||||||
|
clock_t start = clock();
|
||||||
|
fft_engine.from_fft(out, in);
|
||||||
|
cout << "Backward FFT (1,1,...1): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
assert(out[0] == 1L);
|
||||||
|
for (size_t i = 1; i < N; i++)
|
||||||
|
{
|
||||||
|
assert(out[i] == 0L);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
uniform_int_distribution<int> sampler(INT_MIN, INT_MAX);
|
||||||
|
int coef = sampler(rand_engine);
|
||||||
|
vector<long> out;
|
||||||
|
vector<complex<double>> in(N2p1, complex<double>(double(coef),0.0));
|
||||||
|
clock_t start = clock();
|
||||||
|
fft_engine.from_fft(out, in);
|
||||||
|
cout << "Backward FFT (a,a,...a): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
assert(out[0] == coef);
|
||||||
|
for (size_t i = 1; i < N; i++)
|
||||||
|
{
|
||||||
|
assert(out[i] == 0L);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
uniform_int_distribution<int> sampler(INT_MIN, INT_MAX);
|
||||||
|
vector<int> in;
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
in.push_back(sampler(rand_engine));
|
||||||
|
vector<complex<double>> interm(N2p1);
|
||||||
|
vector<long> out;
|
||||||
|
clock_t start = clock();
|
||||||
|
fft_engine.to_fft(interm, in);
|
||||||
|
cout << "Forward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
start = clock();
|
||||||
|
fft_engine.from_fft(out, interm);
|
||||||
|
cout << "Backward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
for (size_t i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
//cout << "i: " << i << "in[i]: " << in[i] << " out[i]: " << out[i] << endl;
|
||||||
|
assert(in[i] == out[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
uniform_int_distribution<int> sampler(-100, 100);
|
||||||
|
vector<int> in1, in2, res;
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
in1.push_back(sampler(rand_engine));
|
||||||
|
in2.push_back(sampler(rand_engine));
|
||||||
|
res.push_back(in1[i]+in2[i]);
|
||||||
|
}
|
||||||
|
vector<complex<double>> interm1(N2p1);
|
||||||
|
vector<complex<double>> interm2(N2p1);
|
||||||
|
vector<complex<double>> intermres(N2p1);
|
||||||
|
vector<long> out;
|
||||||
|
|
||||||
|
fft_engine.to_fft(interm1, in1);
|
||||||
|
fft_engine.to_fft(interm2, in2);
|
||||||
|
|
||||||
|
clock_t start = clock();
|
||||||
|
intermres = interm1 + interm2;
|
||||||
|
cout << "FFT addition: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
fft_engine.from_fft(out, intermres);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
//cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl;
|
||||||
|
assert(res[i] == out[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
uniform_int_distribution<int> sampler(-100, 100);
|
||||||
|
vector<int> in1, in2, res;
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
in1.push_back(sampler(rand_engine));
|
||||||
|
in2.push_back(sampler(rand_engine));
|
||||||
|
}
|
||||||
|
ZZX poly1, poly2, poly_res;
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
SetCoeff(poly1, i, in1[i]);
|
||||||
|
SetCoeff(poly2, i, in2[i]);
|
||||||
|
}
|
||||||
|
ZZX poly_mod;
|
||||||
|
SetCoeff(poly_mod, 0, 1);
|
||||||
|
SetCoeff(poly_mod, N, 1);
|
||||||
|
|
||||||
|
clock_t start = clock();
|
||||||
|
MulMod(poly_res, poly1, poly2, poly_mod);
|
||||||
|
cout << "NTL multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
res.push_back(conv<long>(poly_res[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<complex<double>> interm1(N2p1);
|
||||||
|
vector<complex<double>> interm2(N2p1);
|
||||||
|
vector<complex<double>> intermres(N2p1);
|
||||||
|
vector<long> out;
|
||||||
|
|
||||||
|
fft_engine.to_fft(interm1, in1);
|
||||||
|
fft_engine.to_fft(interm2, in2);
|
||||||
|
|
||||||
|
start = clock();
|
||||||
|
intermres = interm1 * interm2;
|
||||||
|
cout << "FFT multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl;
|
||||||
|
|
||||||
|
fft_engine.from_fft(out, intermres);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < N; i++)
|
||||||
|
{
|
||||||
|
//cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl;
|
||||||
|
assert(res[i] == out[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "FFT is OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ntruhe_encrypt()
|
||||||
|
{
|
||||||
|
SchemeNTRU s;
|
||||||
|
|
||||||
|
{
|
||||||
|
int input = 0;
|
||||||
|
Ctxt_NTRU ct;
|
||||||
|
s.encrypt(ct, input);
|
||||||
|
int output = s.decrypt(ct);
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
int input = 1;
|
||||||
|
Ctxt_NTRU ct;
|
||||||
|
s.encrypt(ct, input);
|
||||||
|
int output = s.decrypt(ct);
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
cout << "NTRU ENCRYPTION IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lwehe_encrypt()
|
||||||
|
{
|
||||||
|
SchemeLWE s;
|
||||||
|
|
||||||
|
{
|
||||||
|
int input = 0;
|
||||||
|
Ctxt_LWE ct;
|
||||||
|
s.encrypt(ct, input);
|
||||||
|
int output = s.decrypt(ct);
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
int input = 1;
|
||||||
|
Ctxt_LWE ct;
|
||||||
|
s.encrypt(ct, input);
|
||||||
|
int output = s.decrypt(ct);
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
cout << "LWE ENCRYPTION IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
void test_mod_switch()
|
||||||
|
{
|
||||||
|
SchemeNTRU s;
|
||||||
|
{
|
||||||
|
int input = 0;
|
||||||
|
Ctxt_NTRU ct;
|
||||||
|
s.encrypt(ct, input);
|
||||||
|
s.modulo_switch_to_base(ct.data);
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < ntru_he::n; i++)
|
||||||
|
{
|
||||||
|
output += ct.data[i] * sk_base.sk[i][0];
|
||||||
|
}
|
||||||
|
output = output%Param::N2;
|
||||||
|
if (output > Param::N)
|
||||||
|
output -= Param::N2;
|
||||||
|
else if (output <= -Param::N)
|
||||||
|
output += Param::N2;
|
||||||
|
output = int(round(double(output*t)/double(Param::N2)));
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
int input = 1;
|
||||||
|
ntru_he::Ctxt ct;
|
||||||
|
ntru_he::encrypt(ct, input, sk_base);
|
||||||
|
ntru_he::modulo_switch(ct, ntru_he::q_base, Param::N2);
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < ntru_he::n; i++)
|
||||||
|
{
|
||||||
|
output += ct[i] * sk_base.sk[i][0];
|
||||||
|
}
|
||||||
|
output = output%Param::N2;
|
||||||
|
if (output > Param::N)
|
||||||
|
output -= Param::N2;
|
||||||
|
else if (output <= -Param::N)
|
||||||
|
output += Param::N2;
|
||||||
|
output = int(round(double(output*t)/double(Param::N2)));
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
int input = 0;
|
||||||
|
ntru_he::Ctxt ct;
|
||||||
|
ntru_he::encrypt(ct, input, sk_base);
|
||||||
|
ntru_he::modulo_switch_to_boot(ct);
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < ntru_he::n; i++)
|
||||||
|
{
|
||||||
|
output += ct[i] * sk_base.sk[i][0];
|
||||||
|
}
|
||||||
|
output = output%Param::N2;
|
||||||
|
if (output > Param::N)
|
||||||
|
output -= Param::N2;
|
||||||
|
else if (output <= -Param::N)
|
||||||
|
output += Param::N2;
|
||||||
|
output = int(round(double(output*t)/double(Param::N2)));
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
int input = 1;
|
||||||
|
ntru_he::Ctxt ct;
|
||||||
|
ntru_he::encrypt(ct, input, sk_base);
|
||||||
|
ntru_he::modulo_switch_to_boot(ct);
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < ntru_he::n; i++)
|
||||||
|
{
|
||||||
|
output += ct[i] * sk_base.sk[i][0];
|
||||||
|
}
|
||||||
|
output = output%Param::N2;
|
||||||
|
if (output > Param::N)
|
||||||
|
output -= Param::N2;
|
||||||
|
else if (output <= -Param::N)
|
||||||
|
output += Param::N2;
|
||||||
|
output = int(round(double(output*t)/double(Param::N2)));
|
||||||
|
assert(output == input);
|
||||||
|
}
|
||||||
|
cout << "MODULO SWITCHING IS OK" << endl;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
void test_bootstrap()
|
||||||
|
{
|
||||||
|
SchemeNTRU s;
|
||||||
|
{
|
||||||
|
int input = 2;
|
||||||
|
Ctxt_NTRU ct;
|
||||||
|
s.encrypt(ct, input);
|
||||||
|
|
||||||
|
s.bootstrap(ct);
|
||||||
|
|
||||||
|
int output = s.decrypt(ct);
|
||||||
|
cout << "Bootstrapping output: " << output << endl;
|
||||||
|
assert(output == 1L);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
int input = 0;
|
||||||
|
Ctxt_NTRU ct;
|
||||||
|
s.encrypt(ct, input);
|
||||||
|
|
||||||
|
s.bootstrap(ct);
|
||||||
|
|
||||||
|
int output = s.decrypt(ct);
|
||||||
|
cout << "Bootstrapping output: " << output << endl;
|
||||||
|
assert(output == 0L);
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << "BOOTSTRAPPING IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
void test_nand_aux()
|
||||||
|
{
|
||||||
|
ntru_he::SKey_base sk_base;
|
||||||
|
ntru_he::get_sk_base(sk_base);
|
||||||
|
|
||||||
|
ntru_he::Ctxt ct;
|
||||||
|
ntru_he::get_nand_aux(ct, sk_base);
|
||||||
|
int output = 0;
|
||||||
|
for (int i = 0; i < ntru_he::n; i++)
|
||||||
|
{
|
||||||
|
output += ct[i] * sk_base.sk[i][0];
|
||||||
|
}
|
||||||
|
output = ntru_he::mod_q_base(output);
|
||||||
|
assert(
|
||||||
|
output == (ntru_he::nand_const-ntru_he::q_base)
|
||||||
|
|| output == (ntru_he::nand_const-ntru_he::q_base+1)
|
||||||
|
|| output == (ntru_he::nand_const-ntru_he::q_base-1)
|
||||||
|
);
|
||||||
|
cout << "NAND ENCRYPTION IS OK" << endl;
|
||||||
|
}*/
|
||||||
|
|
||||||
|
enum GateType {NAND, AND, OR};
|
||||||
|
|
||||||
|
void test_ntruhe_gate_helper(int in1, int in2, const SchemeNTRU& s, GateType g)
|
||||||
|
{
|
||||||
|
float avg_time = 0.0;
|
||||||
|
for (int i = 0; i < 100; i++)
|
||||||
|
{
|
||||||
|
Ctxt_NTRU ct_res, ct1, ct2, ct_nand;
|
||||||
|
s.encrypt(ct1, in1);
|
||||||
|
s.encrypt(ct2, in2);
|
||||||
|
|
||||||
|
if (g == NAND)
|
||||||
|
{
|
||||||
|
auto start = clock();
|
||||||
|
s.nand_gate(ct_res, ct1, ct2);
|
||||||
|
avg_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
int output = s.decrypt(ct_res);
|
||||||
|
|
||||||
|
//cout << "NAND output: " << output << endl;
|
||||||
|
assert(output == !(in1 & in2));
|
||||||
|
}
|
||||||
|
else if (g == AND) {
|
||||||
|
auto start = clock();
|
||||||
|
s.and_gate(ct_res, ct1, ct2);
|
||||||
|
avg_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
int output = s.decrypt(ct_res);
|
||||||
|
|
||||||
|
//cout << "AND output: " << output << endl;
|
||||||
|
assert(output == (in1 & in2));
|
||||||
|
}
|
||||||
|
else if (g == OR) {
|
||||||
|
auto start = clock();
|
||||||
|
s.or_gate(ct_res, ct1, ct2);
|
||||||
|
avg_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
int output = s.decrypt(ct_res);
|
||||||
|
|
||||||
|
//cout << "OR output: " << output << endl;
|
||||||
|
assert(output == (in1 | in2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "Avg. time" << avg_time/100.0 << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ntru_gate(GateType g)
|
||||||
|
{
|
||||||
|
SchemeNTRU s;
|
||||||
|
|
||||||
|
test_ntruhe_gate_helper(0, 0, s, g);
|
||||||
|
test_ntruhe_gate_helper(0, 1, s, g);
|
||||||
|
test_ntruhe_gate_helper(1, 0, s, g);
|
||||||
|
test_ntruhe_gate_helper(1, 1, s, g);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ntruhe_nand()
|
||||||
|
{
|
||||||
|
GateType g = NAND;
|
||||||
|
|
||||||
|
test_ntru_gate(g);
|
||||||
|
|
||||||
|
cout << "NAND IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ntruhe_and()
|
||||||
|
{
|
||||||
|
GateType g = AND;
|
||||||
|
|
||||||
|
test_ntru_gate(g);
|
||||||
|
|
||||||
|
cout << "AND IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ntruhe_or()
|
||||||
|
{
|
||||||
|
GateType g = OR;
|
||||||
|
|
||||||
|
test_ntru_gate(g);
|
||||||
|
|
||||||
|
cout << "OR IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lwehe_gate_helper(int in1, int in2, SchemeLWE& s, GateType g)
|
||||||
|
{
|
||||||
|
float avg_time = 0.0;
|
||||||
|
for (int i = 0; i < 100; i++)
|
||||||
|
{
|
||||||
|
Ctxt_LWE ct_res, ct1, ct2, ct_nand;
|
||||||
|
s.encrypt(ct1, in1);
|
||||||
|
s.encrypt(ct2, in2);
|
||||||
|
|
||||||
|
if (g == NAND)
|
||||||
|
{
|
||||||
|
auto start = clock();
|
||||||
|
s.nand_gate(ct_res, ct1, ct2);
|
||||||
|
avg_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
int output = s.decrypt(ct_res);
|
||||||
|
|
||||||
|
//cout << "NAND output: " << output << endl;
|
||||||
|
assert(output == !(in1 & in2));
|
||||||
|
}
|
||||||
|
else if (g == AND) {
|
||||||
|
auto start = clock();
|
||||||
|
s.and_gate(ct_res, ct1, ct2);
|
||||||
|
avg_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
int output = s.decrypt(ct_res);
|
||||||
|
|
||||||
|
//cout << "AND output: " << output << endl;
|
||||||
|
assert(output == (in1 & in2));
|
||||||
|
}
|
||||||
|
else if (g == OR) {
|
||||||
|
auto start = clock();
|
||||||
|
s.or_gate(ct_res, ct1, ct2);
|
||||||
|
avg_time += float(clock()-start)/CLOCKS_PER_SEC;
|
||||||
|
|
||||||
|
int output = s.decrypt(ct_res);
|
||||||
|
|
||||||
|
//cout << "OR output: " << output << endl;
|
||||||
|
assert(output == (in1 | in2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "Avg. time" << avg_time/100.0 << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lwe_gate(GateType g)
|
||||||
|
{
|
||||||
|
SchemeLWE s;
|
||||||
|
|
||||||
|
test_lwehe_gate_helper(0, 0, s, g);
|
||||||
|
test_lwehe_gate_helper(0, 1, s, g);
|
||||||
|
test_lwehe_gate_helper(1, 0, s, g);
|
||||||
|
test_lwehe_gate_helper(1, 1, s, g);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lwehe_nand()
|
||||||
|
{
|
||||||
|
GateType g = NAND;
|
||||||
|
|
||||||
|
test_lwe_gate(g);
|
||||||
|
|
||||||
|
cout << "NAND IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lwehe_and()
|
||||||
|
{
|
||||||
|
GateType g = AND;
|
||||||
|
|
||||||
|
test_lwe_gate(g);
|
||||||
|
|
||||||
|
cout << "AND IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lwehe_or()
|
||||||
|
{
|
||||||
|
GateType g = OR;
|
||||||
|
|
||||||
|
test_lwe_gate(g);
|
||||||
|
|
||||||
|
cout << "OR IS OK" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
//test_params();
|
||||||
|
//test_sampler();
|
||||||
|
//test_ntru_key_gen();
|
||||||
|
//test_lwe_key_gen();
|
||||||
|
//test_fft();
|
||||||
|
//test_ntruhe_encrypt();
|
||||||
|
//test_lwehe_encrypt();
|
||||||
|
//test_mod_switch();
|
||||||
|
//test_bootstrap();
|
||||||
|
//test_nand_aux();
|
||||||
|
//test_ntruhe_nand();
|
||||||
|
test_lwehe_nand();
|
||||||
|
//test_ntruhe_and();
|
||||||
|
test_lwehe_and();
|
||||||
|
//test_ntruhe_or();
|
||||||
|
test_lwehe_or();
|
||||||
|
return 0;
|
||||||
|
}
|
Reference in New Issue
Block a user