From f7c29bedb116d963106b5ed799ff083ce3123af5 Mon Sep 17 00:00:00 2001 From: iliailia Date: Tue, 18 Jan 2022 15:18:50 +0100 Subject: [PATCH] Add files via upload --- Makefile | 26 ++ include/fft.h | 52 +++ include/keygen.h | 125 ++++++ include/lwehe.h | 153 ++++++++ include/ntruhe.h | 170 ++++++++ include/params.h | 333 ++++++++++++++++ include/sampler.h | 111 ++++++ src/fft.cpp | 175 +++++++++ src/keygen.cpp | 443 +++++++++++++++++++++ src/lwehe.cpp | 549 ++++++++++++++++++++++++++ src/ntruhe.cpp | 416 ++++++++++++++++++++ src/sampler.cpp | 185 +++++++++ test.cpp | 960 ++++++++++++++++++++++++++++++++++++++++++++++ 13 files changed, 3698 insertions(+) create mode 100644 Makefile create mode 100644 include/fft.h create mode 100644 include/keygen.h create mode 100644 include/lwehe.h create mode 100644 include/ntruhe.h create mode 100644 include/params.h create mode 100644 include/sampler.h create mode 100644 src/fft.cpp create mode 100644 src/keygen.cpp create mode 100644 src/lwehe.cpp create mode 100644 src/ntruhe.cpp create mode 100644 src/sampler.cpp create mode 100644 test.cpp diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6fc87ab --- /dev/null +++ b/Makefile @@ -0,0 +1,26 @@ +CCX = g++ +CCXFLAGS = -O3 -funroll-loops -march=native -std=c++11 -pthread -I. -I./include +DEPS = -lntl -lgmp -lfftw3 -lm + +all: clean test + +clean: + $(RM) test test.o lwehe.o ntruhe.o fft.o sampler.o keygen.o + +test: include/params.h ntruhe.o lwehe.o keygen.o fft.o sampler.o + $(CCX) $(CCXFLAGS) -o test test.cpp ntruhe.o lwehe.o keygen.o fft.o sampler.o $(DEPS) + +ntruhe.o: include/ntruhe.h keygen.o sampler.o lwehe.o src/ntruhe.cpp + $(CCX) $(CCXFLAGS) -c src/ntruhe.cpp + +lwehe.o: include/lwehe.h keygen.o sampler.o src/lwehe.cpp + $(CCX) $(CCXFLAGS) -c src/lwehe.cpp + +keygen.o: include/keygen.h sampler.o fft.o src/keygen.cpp + $(CCX) $(CCXFLAGS) -c src/keygen.cpp + +fft.o: include/fft.h + $(CCX) $(CCXFLAGS) -c src/fft.cpp + +sampler.o: include/sampler.h include/params.h src/sampler.cpp + $(CCX) $(CCXFLAGS) -c src/sampler.cpp \ No newline at end of file diff --git a/include/fft.h b/include/fft.h new file mode 100644 index 0000000..8a77833 --- /dev/null +++ b/include/fft.h @@ -0,0 +1,52 @@ +#ifndef FFT +#define FFT + +#include +#include +#include +#include +#include "params.h" + +using namespace std; + +//TODO: description + +class FFT_engine +{ + int fft_dim; + int fft_dim2; + + fftw_plan plan_to_fft; + fftw_plan plan_from_fft; + + double* in_array; + fftw_complex* out_array; + +public: + //map> x_powers; + vector pos_powers; + vector neg_powers; + + FFT_engine() = delete; + FFT_engine(const int dim); + + void to_fft(FFTPoly& out, const ModQPoly& in) const; + void from_fft(vector& out, const FFTPoly& in) const; + + ~FFT_engine(); +}; + +FFTPoly operator *(const FFTPoly& a, const FFTPoly& b); +void operator *=(FFTPoly& a, const FFTPoly& b); +FFTPoly operator *(const FFTPoly& a, const int b); +FFTPoly operator +(const FFTPoly& a, const FFTPoly& b); +void operator +=(FFTPoly& a, const FFTPoly& b); +void operator +=(FFTPoly& a, const complex b); +FFTPoly operator -(const FFTPoly& a, const FFTPoly& b); +void operator -=(FFTPoly& a, const FFTPoly& b); + +// global FFT engine of dimension N +const FFT_engine fftN(Param::N); + + +#endif \ No newline at end of file diff --git a/include/keygen.h b/include/keygen.h new file mode 100644 index 0000000..2fe0cbb --- /dev/null +++ b/include/keygen.h @@ -0,0 +1,125 @@ +#ifndef KEYGEN +#define KEYGEN + +#include +#include +#include +#include +#include "params.h" +#include "sampler.h" + +// secret key of the bootstrapping scheme +typedef struct { + ModQPoly sk; + ModQPoly sk_inv; +} SKey_boot; + +// secret key of the NTRU base scheme +typedef struct { + ModQMatrix sk; + ModQMatrix sk_inv; +} SKey_base_NTRU; + +// secret key of the LWE base scheme +typedef std::vector SKey_base_LWE; + +/** + * Bootstrapping key. + * It consists of several sets of keys corresponding to different + * decomposition bases of the bootstrapping key B_bsk. + * The i-th set contains vectors with l_bsk[i] complex vectors. + * These complex vectors are an encryption of some bit of the secret key + * of the base scheme in the NGS form. + */ +typedef std::vector>> BSKey_NTRU; + +/** + * Bootstrapping key. + * It consists of several sets of keys corresponding to different + * decomposition bases of the bootstrapping key B_bsk. + * The i-th set contains vectors with l_bsk[i] complex vectors. + * These complex vectors are an encryption of some bit of the secret key + * of the base scheme in the NGS form. + */ +typedef std::vector> BSKey_LWE; + +// key-switching key from NTRU to NTRU +typedef ModQMatrix KSKey_NTRU; + +// key-switching key from NTRU to LWE +typedef struct{ + ModQMatrix A; + std::vector b; +} KSKey_LWE; + + +class KeyGen +{ + Param param; + Sampler sampler; + + public: + + KeyGen(Param _param): param(_param), sampler(_param) + {} + + /** + * Generate a secret key of the bootstrapping scheme. + * @param[out] sk_boot secret key of the bootstrapping scheme. + */ + void get_sk_boot(SKey_boot& sk_boot); + + /** + * Generate a secret key of the base scheme. + * @param[out] sk_base secret key of the base scheme. + */ + void get_sk_base(SKey_base_NTRU& sk_base); + + /** + * Generate a secret key of the base scheme. + * @param[out] sk_base secret key of the base scheme. + */ + void get_sk_base(SKey_base_LWE& sk_base); + + /** + * Generate a key-switching key from the bootstrapping scheme to the base scheme. + * @param[out] ksk key-switching key. + * @param[in] sk_base secret key of the base scheme. + * @param[in] sk_boot secret key of the bootstrapping scheme. + */ + void get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot); + + /** + * Generate a key-switching key from the bootstrapping scheme to the base scheme. + * @param[out] ksk key-switching key. + * @param[in] sk_base secret key of the base scheme. + * @param[in] sk_boot secret key of the bootstrapping scheme. + */ + void get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot); + + /** + * Generate a bootstrapping key + * @param[out] bsk bootstrapping key. + * @param[in] sk_base secret key of the base scheme. + * @param[in] sk_boot secret key of the bootstrapping scheme. + */ + void get_bsk(BSKey_NTRU& bsk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot); + + /** + * Generate a bootstrapping key + * @param[out] bsk bootstrapping key. + * @param[in] sk_base secret key of the base scheme. + * @param[in] sk_boot secret key of the bootstrapping scheme. + */ + void get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot); + + /** + * Generate a bootstrapping key + * @param[out] bsk bootstrapping key. + * @param[in] sk_base secret key of the base scheme. + * @param[in] sk_boot secret key of the bootstrapping scheme. + */ + void get_bsk2(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot); +}; + +#endif \ No newline at end of file diff --git a/include/lwehe.h b/include/lwehe.h new file mode 100644 index 0000000..47690c2 --- /dev/null +++ b/include/lwehe.h @@ -0,0 +1,153 @@ +#ifndef LWEHE +#define LWEHE + +#include "params.h" +#include "keygen.h" + +class Ctxt_LWE +{ + public: + std::vector a; + int b; + + Ctxt_LWE() + { + a.clear(); + a.resize(parLWE.n); + } + Ctxt_LWE(const Ctxt_LWE& ct); + Ctxt_LWE& operator=(const Ctxt_LWE& ct); + + Ctxt_LWE operator +(const Ctxt_LWE& ct) const; + Ctxt_LWE operator -(const Ctxt_LWE& ct) const; +}; + +Ctxt_LWE operator -(const int c, const Ctxt_LWE& ct); + +/** + * Switches a given ciphertext to a given modulus. + * @param[in,out] ct ciphertext + * @param[in] old_q old modulus + * @param[in] new_q old modulus + */ +inline void modulo_switch_lwe(Ctxt_LWE& ct, int old_q, int new_q) +{ + std::vector& a = ct.a; + for (size_t i = 0; i < a.size(); i++) + a[i] = int((a[i]*new_q)/old_q); + + ct.b = int((ct.b*new_q)/old_q); +} + +/** + * Switches a given polynomial from q_base to modulus 2*N. + * @param[in,out] poly polynomial + */ +inline void modulo_switch_to_boot(Ctxt_LWE& poly) +{ + modulo_switch_lwe(poly, parLWE.q_base, Param::N2); +} + +/** + * Switches a given polynomial from q_boot to q_base. + * @param[in,out] poly polynomial + */ +inline void modulo_switch_to_base_lwe(ModQPoly& poly) +{ + modulo_switch(poly, q_boot, parLWE.q_base); +} + +/** + * Computes the external product of a given polynomial ciphertext + * with an NGS ciphertext in the FFT form + * @param[in,out] poly polynomial ciphertext + * @param[in] poly_vector NGS ciphertext + * @param[in] b decomposition base, power of 2 + * @param[in] shift bit shift to divide by b + * @param[in] l decomposition length + */ +void external_product(std::vector& res, const std::vector& poly, const std::vector& poly_vector, int b, int shift, int l); + +class SchemeLWE +{ + SKey_base_LWE sk_base; + SKey_boot sk_boot; + KSKey_LWE ksk; + BSKey_LWE bk; + + public: + + SchemeLWE() + { + KeyGen keygen(parLWE); + //sampler = Sampler(param); + + keygen.get_sk_base(sk_base); + keygen.get_sk_boot(sk_boot); + keygen.get_ksk(ksk,sk_base,sk_boot); + keygen.get_bsk(bk,sk_base,sk_boot); + } + /** + * Encrypts a bit using LWE. + * @param[out] ct ciphertext encrypting the input bit + * @param[in] m bit to encrypt + */ + void encrypt(Ctxt_LWE& ct, int m) const; + + /** + * Decrypts a bit using LWE. + * @param[out] ct ciphertext encrypting a bit + * @return b bit + */ + int decrypt(const Ctxt_LWE& ct) const; + + /** + * Performs key switching of a given ciphertext from a polynomial NTRU + * to LWE + * @param[out] ct LWE ciphertext (vector of dimension n) + * @param[in] poly polynomial ciphertext (vector of dimension N) + */ + void key_switch(Ctxt_LWE& ct, const ModQPoly& poly) const; + + /** + * Bootstrapps a given ciphertext + * @param[in,out] ct ciphertext to bootstrap + */ + void bootstrap(Ctxt_LWE& ct) const; + + /** + * Bootstrapps a given ciphertext + * @param[in,out] ct ciphertext to bootstrap + */ + void bootstrap2(Ctxt_LWE& ct) const; + + /** + * Computes the NAND gate of two given ciphertexts ct1 and ct2 + * @param[out] ct_res encryptions of the outuput of the NAND gate + * @param[in] ct_1 encryption of the first input bit + * @param[in] ct_2 encryption of the second input bit + */ + void nand_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const; + + /** + * Computes the AND gate of two given ciphertexts ct1 and ct2 + * @param[out] ct_res encryptions of the outuput of the NAND gate + * @param[in] ct_1 encryption of the first input bit + * @param[in] ct_2 encryption of the second input bit + */ + void and_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const; + + /** + * Computes the OR gate of two given ciphertexts ct1 and ct2 + * @param[out] ct_res encryptions of the outuput of the NAND gate + * @param[in] ct_1 encryption of the first input bit + * @param[in] ct_2 encryption of the second input bit + */ + void or_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const; +}; + + + + + +#endif \ No newline at end of file diff --git a/include/ntruhe.h b/include/ntruhe.h new file mode 100644 index 0000000..a306ee9 --- /dev/null +++ b/include/ntruhe.h @@ -0,0 +1,170 @@ +#ifndef NTRUHE +#define NTRUHE + +#include "params.h" +#include "keygen.h" + +class Ctxt_NTRU +{ + public: + std::vector data; + + Ctxt_NTRU() + { + data.clear(); + data.resize(parNTRU.n); + } + Ctxt_NTRU(const Ctxt_NTRU& ct); + Ctxt_NTRU& operator=(const Ctxt_NTRU& ct); + + Ctxt_NTRU operator +(const Ctxt_NTRU& ct) const; + Ctxt_NTRU operator -(const Ctxt_NTRU& ct) const; + void operator -=(const Ctxt_NTRU& ct); +}; + +/** + * Switches a given ciphertext to a given modulus. + * @param[in,out] ct ciphertext + * @param[in] old_q old modulus + * @param[in] new_q old modulus + */ +inline void modulo_switch_ntru(Ctxt_NTRU& ct, int old_q, int new_q) +{ + std::vector& a = ct.data; + for (size_t i = 0; i < a.size(); i++) + a[i] = int((a[i]*new_q)/old_q); +} + +/** + * Switches a given polynomial from q_base to modulus 2*N. + * @param[in,out] poly polynomial + */ +inline void modulo_switch_to_boot(Ctxt_NTRU& poly) +{ + modulo_switch_ntru(poly, parNTRU.q_base, Param::N2); +} + +/** + * Switches a given polynomial from q_boot to q_base. + * @param[in,out] poly polynomial + */ +inline void modulo_switch_to_base_ntru(ModQPoly& poly) +{ + modulo_switch(poly, q_boot, parNTRU.q_base); +} + +/** + * Computes the external product of a given polynomial ciphertext + * with an NGS ciphertext in the FFT form + * @param[in,out] poly polynomial ciphertext + * @param[in] poly_vector NGS ciphertext + * @param[in] b decomposition base, power of 2 + * @param[in] shift bit shift to divide by b + * @param[in] l decomposition length + */ +//void external_product(std::vector& res, const std::vector& poly, const std::vector& poly_vector, const int b, const int shift, const int l); + +class SchemeNTRU +{ + SKey_base_NTRU sk_base; + SKey_boot sk_boot; + KSKey_NTRU ksk; + BSKey_NTRU bk; + + Ctxt_NTRU ct_nand_const; + Ctxt_NTRU ct_and_const; + Ctxt_NTRU ct_or_const; + + void mask_constant(Ctxt_NTRU& ct, int constant); + + inline void set_nand_const() + { + //clock_t start = clock(); + mask_constant(ct_nand_const, parNTRU.nand_const); + //cout << "Encryption of NAND: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + } + + inline void set_and_const() + { + //clock_t start = clock(); + mask_constant(ct_and_const, parNTRU.and_const); + //cout << "Encryption of AND: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + } + + inline void set_or_const() + { + //clock_t start = clock(); + mask_constant(ct_or_const, parNTRU.or_const); + //cout << "Encryption of OR: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + } + + public: + + SchemeNTRU() + { + KeyGen keygen(parNTRU); + //sampler = Sampler(param); + + keygen.get_sk_base(sk_base); + keygen.get_sk_boot(sk_boot); + keygen.get_ksk(ksk,sk_base,sk_boot); + keygen.get_bsk(bk,sk_base,sk_boot); + + set_nand_const(); + set_and_const(); + set_or_const(); + } + /** + * Encrypts a bit using matrix NTRU. + * @param[out] ct ciphertext encrypting the input bit + * @param[in] b bit to encrypt + */ + void encrypt(Ctxt_NTRU& ct, const int b) const; + + /** + * Decrypts a bit using matrix NTRU. + * @param[out] ct ciphertext encrypting a bit + * @return b bit + */ + int decrypt(const Ctxt_NTRU& ct) const; + + /** + * Performs key switching of a given ciphertext from a polynomial NTRU + * to a matrix NTRU + * @param[out] ct matrix NTRU ciphertext (vector of dimension n) + * @param[in] poly polynomial ciphertext (vector of dimension N) + */ + void key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const; + + /** + * Bootstrapps a given ciphertext + * @param[in,out] ct ciphertext to bootstrap + */ + void bootstrap(Ctxt_NTRU& ct) const; + + /** + * Computes the NAND gate of two given ciphertexts ct1 and ct2 + * @param[out] ct_res encryptions of the outuput of the NAND gate + * @param[in] ct_1 encryption of the first input bit + * @param[in] ct_2 encryption of the second input bit + */ + void nand_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const; + + /** + * Computes the AND gate of two given ciphertexts ct1 and ct2 + * @param[out] ct_res encryptions of the outuput of the NAND gate + * @param[in] ct_1 encryption of the first input bit + * @param[in] ct_2 encryption of the second input bit + */ + void and_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const; + + /** + * Computes the OR gate of two given ciphertexts ct1 and ct2 + * @param[out] ct_res encryptions of the outuput of the NAND gate + * @param[in] ct_1 encryption of the first input bit + * @param[in] ct_2 encryption of the second input bit + */ + void or_gate(Ctxt_NTRU& ct_res, const Ctxt_NTRU& ct1, const Ctxt_NTRU& ct2) const; +}; + +#endif \ No newline at end of file diff --git a/include/params.h b/include/params.h new file mode 100644 index 0000000..9cb61af --- /dev/null +++ b/include/params.h @@ -0,0 +1,333 @@ +#ifndef PARAMS +#define PARAMS + +#include +#include +#include +#include +#include + +using namespace NTL; + +enum SchemeType {NTRU, LWE}; + +// representation of a polynomial modulo some integer +typedef std::vector ModQPoly; +// matrix modulo some integer +typedef std::vector> ModQMatrix; +// representation of an FFT transformation of some poly +typedef std::vector> FFTPoly; +// NGS ciphertest in NTT form +typedef std::vector NGSFFTctxt; + +/** + * Reduction modulo q in the symmetric interval (-q/2, q/2] + * Correct only for odd q. + * @param[in] input integer to reduce. + * @param[in] q modulus + * @param[in] half_q (q-1)/2 + * @returns reduced integer in the symmetric interval (-q/2, q/2] + */ +inline long lazy_mod_q(const long input, const long q, const long half_q) +{ + int coef = input%q; + if (coef > half_q) + return coef - q; + if (coef < -half_q) + return coef + q; + return coef; +} + +inline int lazy_mod_q(const int input, const int q, const int half_q) +{ + int coef = input%q; + if (coef > half_q) + return coef - q; + if (coef < -half_q) + return coef + q; + return coef; +} + +inline void lazy_mod_q(std::vector& output, const std::vector& input, const long q, const long half_q) +{ + assert(output.size() == input.size()); + + std::vector::iterator oit = output.begin(); + for (auto iit = input.begin(); iit < input.end(); iit++, oit++) + *oit = static_cast(lazy_mod_q(*iit, q, half_q)); +} + +/** ciphertext modulus of the ring-based scheme used + * for bootstrapping keys and test vectors/lookup-table encodings + */ +const int q_boot = 912829; // ~2^19.8, prime +const long q_boot_long = long(q_boot); + +//half of the above modulus +const int half_q_boot = q_boot/2; +const long half_q_boot_long = long(half_q_boot); + +/** + * Reduction modulo q_boot in the symmetric interval (-q_boot/2, q_boot/2] + * Correct only for odd q_boot. + * @param[in] input integer to reduce. + * @returns reduced integer in the symmetric interval (-q_boot/2, q_boot/2] + */ +inline int mod_q_boot(const int input) +{ + return lazy_mod_q(input, q_boot, half_q_boot); +} + +inline int mod_q_boot(const long input) +{ + return lazy_mod_q(input, long(q_boot), long(half_q_boot)); +} + +/** + * Reduction modulo q_boot in the symmetric interval (-q_boot/2, q_boot/2] + * Correct only for odd q_boot. + * @param[in,out] input vector to reduce. + * @param[in] q integer modulus. + */ +inline void mod_q_boot(std::vector& input) +{ + for (auto it = input.begin(); it < input.end(); it++) + *it = mod_q_boot(*it); +} + +/** + * Reduction modulo q_boot in the symmetric interval (-q_boot/2, q_boot/2] + * Correct only for odd q_boot. + * @param[in,out] input vector to reduce. + * @param[in] q integer modulus. + */ +inline void mod_q_boot(std::vector& output, std::vector& input) +{ + lazy_mod_q(output, input, q_boot_long, half_q_boot_long); +} + +class Param +{ + public: + // decomposition base of key-switching keys (DO NOT CHANGE!) + const static int B_ksk = 3; + + // dimension of bootstrapping keys and test vectors/lookup-table encodings + const static int N = 1024; + // N/2 + 1, needed for FFT + const static int N2p1 = N/2+1; + // order of the cyclotomic ring, which is used for the bootstrapping scheme + const static int N2 = 2*N; + + static ZZ_pX get_def_poly() + { + ZZ_pX poly; + // polynomial modulus of the ring-based scheme + //element of Z_(q_boot) + ZZ_p coef; + coef.init(ZZ(q_boot)); + //polynomial modulus X^N+1 + coef = 1; + SetCoeff(poly, 0, coef); + SetCoeff(poly, N, coef); + + return poly; + } + + ZZ_pX xToNplus1; + + const static int B_bsk_size = 2; + + // plaintext modulus + const static int t = 4; + + // Delta scalar used in bootstrapping + const static int half_delta_boot = q_boot/(2*t); + + // standard deviation for discrete Gaussian distribution + constexpr static double e_st_dev = 4.39; + + // Type of the base scheme + SchemeType scheme_type; + + // ciphertext modulus of the base scheme used for encryption + int q_base; + //half of the above modulus + int half_q_base; + // dimension of the ciphertext space + int n; + + // decomposition base of key-switching keys + int l_ksk; + + // number of rows of key-switching key matrices + int Nl; + + // decomposition bases of bootstrapping keys + int B_bsk[B_bsk_size]; + // binary logarithms of decomposition bases + int shift_bsk[B_bsk_size]; + // partition of bootstrapping keys per decomposition base + int bsk_partition[B_bsk_size]; + + // decomposition lengths of bootstrapping keys + int l_bsk[2]; + + // Delta scalars used in encryption + int half_delta_base; + int delta_base; + + // NAND constant + int nand_const; + // AND constant + int and_const; + // OR constant + int or_const; + + void init() + { + if (scheme_type == SchemeType::NTRU) + { + q_base = 131071; // ~2^17, prime + n = 800; + + B_bsk[0] = 8; + B_bsk[1] = 16; + shift_bsk[0] = 3; + shift_bsk[1] = 4; + bsk_partition[0] = 750; + bsk_partition[1] = 50; + } + else if (scheme_type == SchemeType::LWE) + { + q_base = 92683; // ~2^16.5, prime + n = 610; + + B_bsk[0] = 8; + B_bsk[1] = 16; + shift_bsk[0] = 3; + shift_bsk[1] = 4; + bsk_partition[0] = 140; + bsk_partition[1] = 470; + } + + half_q_base = q_base/2; + l_ksk = int(ceil(log(double(q_base))/log(double(B_ksk)))); + Nl = N * l_ksk; + + for (size_t i = 0; i < 2; i++) + l_bsk[i] = int(ceil(log(double(q_boot))/log(double(B_bsk[i])))); + + half_delta_base = q_base/(2*t); + delta_base = 2*half_delta_base; + + nand_const = 5*half_delta_base; + and_const = half_delta_base; + or_const = 7*half_delta_base; + + xToNplus1 = get_def_poly(); + } + Param(){}; + Param(SchemeType _scheme_type): scheme_type(_scheme_type) + { + init(); + } + Param(const Param ¶m): scheme_type(param.scheme_type) + { + init(); + } + Param& operator=(const Param& param) + { + scheme_type = param.scheme_type; + init(); + return *this; + } + + bool operator == (const Param& param) const + { + return this->scheme_type == param.scheme_type; + } + + /** + * Reduction modulo q_base in the symmetric interval (-q_base/2, q_base/2] + * Correct only for odd q_base. + * @param[in] input integer to reduce. + * @returns reduced integer in the symmetric interval (-q_base/2, q_base/2] + */ + inline int mod_q_base(const int input) const + { + return lazy_mod_q(input, q_base, half_q_base); + } + inline int mod_q_base(const long input) const + { + return lazy_mod_q(input, long(q_base), long(half_q_base)); + } + + /** + * Reduction modulo q_base in the symmetric interval (-q_base/2, q_base/2] + * Correct only for odd q_base. + * @param[in,out] input vector to reduce. + */ + inline void mod_q_base(std::vector& input) const + { + for (auto it = input.begin(); it < input.end(); it++) + *it = mod_q_base(*it); + } + inline void mod_q_base(std::vector& output, std::vector& input) const + { + output.resize(input.size()); + for (size_t i = 0; i < input.size(); i++) + output[i] = mod_q_base(input[i]); + } +}; + +const Param parLWE(LWE); +const Param parNTRU(NTRU); + +/** + * Switches a given polynomial to a given modulus. + * @param[in,out] poly polynomial + * @param[in] old_q old modulus + * @param[in] new_q old modulus + */ +inline void modulo_switch(ModQPoly& poly, int old_q, int new_q) +{ + double ratio = double(new_q)/double(old_q); + for (auto it = poly.begin(); it < poly.end(); it++) + *it = int(round(double(*it)*ratio)); +} + +/** + * Balanced decomposition of an integer in base b. + * The length of decomposition must be l. + * @param[out] res decomposition vector + * @param[in] input integer to decompose + * @param[in] b decomposition base + * @param[in] l decomposition length + */ +inline void decompose(std::vector& res, const int input, const int b, const int l) +{ + res.clear(); + + int input_sign = (input < 0) ? -1: 1; + int input_rem = abs(input); + for (int i=0; i b) + { + res.push_back(input_sign * (digit - b)); + input_rem = (input_rem - digit)/b + 1; + } + else + { + res.push_back(input_sign * digit); + input_rem = (input_rem - digit)/b; + } + } + if (input_rem != 0) + throw std::overflow_error("Input is too big for given length\n"); +} + +#endif \ No newline at end of file diff --git a/include/sampler.h b/include/sampler.h new file mode 100644 index 0000000..e717505 --- /dev/null +++ b/include/sampler.h @@ -0,0 +1,111 @@ +#ifndef SAMPLER +#define SAMPLER + +#include +#include +#include +#include +#include "params.h" + +using namespace NTL; +using namespace std; + + +// random engine +static default_random_engine rand_engine(std::chrono::system_clock::now().time_since_epoch().count()); +// uniform distribution on the ternary set +static uniform_int_distribution ternary_sampler(-1,1); +// uniform distribution on the binary set +static uniform_int_distribution binary_sampler(0,1); + +class Sampler +{ + Param param; + uniform_int_distribution mod_q_base_sampler; + + public: + Sampler(Param _param): param(_param) + { + mod_q_base_sampler = uniform_int_distribution(-param.half_q_base, param.half_q_base); + } + + /** + * Generate a uniformly random vector modulo q_base. + * + * @param[out] vec vector with uniformly random coefficients. + */ + void get_uniform_vector(vector& vec); + + /** + * Generate a uniformly random matrix modulo q_base. + * + * @param[out] mat matrix with uniformly random coefficients. + */ + void get_uniform_matrix(vector>& mat); + + /** + * Generate a matrix of the form scale*M+shift*I + * where M has uniformly random ternary coefficients + * and I is an identity matrix. + * This matrix must be invertible + * @param[out] mat random matrix of the above form. + * @param[out] mat_inv inverse matrix. + * @param[in] scale scale in the above form. + * @param[in] shift shift in the above form + */ + void get_invertible_matrix(vector>& mat, vector>& mat_inv, int scale, int shift); + + /** + * Generate a uniformly random matrix with ternary coefficients. + * + * @param[out] mat matrix with ternary coefficients. + */ + static void get_ternary_matrix(vector>& mat); + + /** + * Generate a uniformly random vector with ternary coefficients. + * + * @param[out] vec vector with ternary coefficients. + */ + static void get_ternary_vector(vector& vec); + + /** + * Generate a uniformly random vector with binary coefficients. + * + * @param[out] vec vector with binary coefficients. + */ + static void get_binary_vector(vector& vec); + + /** + * Generate a random matrix with coefficients distributed + * according to the discrete Gaussian distribution + * with zero mean and standard deviation st_dev. + * @param[out] mat random matrix. + * @param[in] st_dev standard deviation. + */ + static void get_gaussian_matrix(vector>& mat, double st_dev); + + /** + * Generate a random vector with coefficients distributed + * according to the discrete Gaussian distribution + * with zero mean and standard deviation st_dev. + * @param[out] vec random vector. + * @param[in] st_dev standard deviation. + */ + static void get_gaussian_vector(vector& vec, double st_dev); + + /** + * Generate a vector of the form scale*v+shift + * where v has uniformly random ternary coefficients. + * The polynomial with the above vector of coefficients + * must be invertible modulo X^N+1 and q_boot. + * @param[out] vec random vector of the above form. + * @param[out] vec_inv coefficient vector of the polynomial inverse modulo X^N+1. + * @param[in] scale scale in the above form. + * @param[in] shift shift in the above form + */ + void get_invertible_vector(vector& vec, vector& vec_inv, int scale, int shift); +}; + + +#endif \ No newline at end of file diff --git a/src/fft.cpp b/src/fft.cpp new file mode 100644 index 0000000..40d9531 --- /dev/null +++ b/src/fft.cpp @@ -0,0 +1,175 @@ +#include "fft.h" +#include +#include +#include +#include + +FFT_engine::FFT_engine(const int dim): fft_dim(dim) +{ + assert(dim%2 == 0); + + fft_dim2 = (dim >> 1) + 1; + + in_array = (double*) fftw_malloc(sizeof(double) * 2*dim); + out_array = (fftw_complex*) fftw_malloc(sizeof(fftw_complex) * (dim + 2)); + plan_to_fft = fftw_plan_dft_r2c_1d(2*dim, in_array, out_array, FFTW_PATIENT); + plan_from_fft = fftw_plan_dft_c2r_1d(2*dim, out_array, in_array, FFTW_PATIENT); + + pos_powers = vector(dim,FFTPoly(fft_dim2)); + neg_powers = vector(dim,FFTPoly(fft_dim2)); + for(int i = 0; i < dim; i++) + { + ModQPoly x_power(dim,0); + //x_power[0] = -1; + x_power[i] += 1; + FFTPoly x_power_fft(fft_dim2); + to_fft(x_power_fft, x_power); + pos_powers[i] = x_power_fft; + + x_power[i] -= 2; + to_fft(x_power_fft, x_power); + neg_powers[i] = x_power_fft; + } + //x_powers.insert({{-1,neg_powers}, {1,pos_powers}}); +} + +void FFT_engine::to_fft(FFTPoly& out, const ModQPoly& in) const +{ + assert(out.size() == fft_dim2); + + double* in_arr = in_array; + fftw_complex* out_arr = out_array; + int N = fft_dim; + + for (int i = 0; i < N; ++i) + { + in_arr[i] = double(in[i]); + in_arr[i+N] = 0.0; + } + fftw_execute(plan_to_fft); + int tmp = 1; + //for (int i = 0; i < fft_dim2; i++) + for (auto it = out.begin(); it < out.end(); ++it) + { + fftw_complex& out_z = out_arr[tmp]; + complex& outi = *it; //out[i]; + outi.real(out_z[0]); + outi.imag(out_z[1]); + tmp += 2; + } +} + +void FFT_engine::from_fft(vector& out, const FFTPoly& in) const +{ + int tmp = 0; + double* in_arr = in_array; + fftw_complex* out_arr = out_array; + int N = fft_dim; + int Nd = double(N); + + //for (int i = 0; i < fft_dim2; ++i) + for (auto it = in.begin(); it < in.end(); ++it) + { + //std::cout << "i: " << i << ", number: " << in[i] << std::endl; + out_arr[tmp+1][0] = real(*it)/Nd; + out_arr[tmp+1][1] = imag(*it)/Nd; + out_arr[tmp][0] = 0.0; + out_arr[tmp][1] = 0.0; + tmp += 2; + } + fftw_execute(plan_from_fft); + out.resize(fft_dim); + for (int i = 0; i < N; ++i) + { + out[i] = long(round(in_arr[i])); + //std::cout << "i: " << i << ", number: " << out[i] << std::endl; + } +} + +FFT_engine::~FFT_engine() +{ + fftw_destroy_plan(plan_to_fft); + fftw_destroy_plan(plan_from_fft); + fftw_free(in_array); + fftw_free(out_array); +} + +FFTPoly operator +(const FFTPoly& a, const FFTPoly& b) +{ + // check that input vectors have the same size + assert(a.size() == b.size()); + + FFTPoly res(a.size()); + for (size_t i = 0; i < a.size(); ++i) + res[i] = a[i]+b[i]; + + return res; +} + +void operator +=(FFTPoly& a, const FFTPoly& b) +{ + // check that input vectors have the same size + assert(a.size() == b.size()); + + for (size_t i = 0; i < a.size(); ++i) + a[i]+=b[i]; +} + +void operator +=(FFTPoly& a, const complex b) +{ + for (size_t i = 0; i < a.size(); ++i) + a[i]+=b; +} + +FFTPoly operator -(const FFTPoly& a, const FFTPoly& b) +{ + // check that input vectors have the same size + assert(a.size() == b.size()); + + FFTPoly res(a.size()); + for (size_t i = 0; i < a.size(); ++i) + res[i] = a[i]-b[i]; + + return res; +} + +void operator -=(FFTPoly& a, const FFTPoly& b) +{ + // check that input vectors have the same size + assert(a.size() == b.size()); + + for (size_t i = 0; i < a.size(); ++i) + a[i]-=b[i]; +} + +FFTPoly operator *(const FFTPoly& a, const FFTPoly& b) +{ + // check that input vectors have the same size + assert(a.size() == b.size()); + + FFTPoly res(a.size()); + for (size_t i = 0; i < a.size(); ++i) + res[i] = a[i]*b[i]; + + return res; +} + +void operator *=(FFTPoly& a, const FFTPoly& b) +{ + // check that input vectors have the same size + assert(a.size() == b.size()); + + for (size_t i = 0; i < a.size(); ++i) + a[i]*=b[i]; +} + +// TODO: make a test +FFTPoly operator *(const FFTPoly& a, const int b) +{ + FFTPoly res(a.size()); + double bd = double(b); + for (size_t i = 0; i < a.size(); ++i) + res[i] = a[i] * bd; + + return res; +} \ No newline at end of file diff --git a/src/keygen.cpp b/src/keygen.cpp new file mode 100644 index 0000000..c5792f1 --- /dev/null +++ b/src/keygen.cpp @@ -0,0 +1,443 @@ +#include "keygen.h" +#include "params.h" +#include "fft.h" + +#include +#include +#include + +using namespace NTL; +using namespace std; + +void KeyGen::get_sk_boot(SKey_boot& sk_boot) +{ + cout << "Started generating the secret key of the bootstrapping scheme" << endl; + clock_t start = clock(); + sk_boot.sk = ModQPoly(Param::N,0); + sk_boot.sk_inv = ModQPoly(Param::N,0); + + sampler.get_invertible_vector(sk_boot.sk, sk_boot.sk_inv, Param::t, 1L); + cout << "Generation time of the secret key of the bootstrapping scheme: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void KeyGen::get_sk_base(SKey_base_NTRU& sk_base) +{ + cout << "Started generating the secret key of the base scheme" << endl; + clock_t start = clock(); + sk_base.sk = ModQMatrix(param.n, vector(param.n,0L)); + sk_base.sk_inv = ModQMatrix(param.n, vector(param.n,0L)); + + sampler.get_invertible_matrix(sk_base.sk, sk_base.sk_inv, 1L, 0L); + cout << "Generation time of the secret key of the base scheme: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void KeyGen::get_sk_base(SKey_base_LWE& sk_base) +{ + cout << "Started generating the secret key of the base scheme" << endl; + clock_t start = clock(); + sk_base.clear(); + sk_base = vector(param.n,0L); + + sampler.get_binary_vector(sk_base); + cout << "Generation time of the secret key of the base scheme: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void KeyGen::get_ksk(KSKey_NTRU& ksk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot) +{ + cout << "Started key-switching key generation" << endl; + clock_t start = clock(); + // reset key-switching key + ksk.clear(); + ksk = ModQMatrix(param.Nl, vector(param.n,0)); + vector> ksk_long(param.Nl, vector(param.n,0L)); + cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + // noise matrix G as in the paper + ModQMatrix G(param.Nl, vector(param.n,0L)); + sampler.get_ternary_matrix(G); + cout << "G gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + // matrix G + P * Phi(f) * E as in the paper + int coef_w_pwr = sk_boot.sk[0]; + for (int i = 0; i < param.l_ksk; i++) + { + G[i][0] += coef_w_pwr; + coef_w_pwr *= Param::B_ksk; + } + for (int i = 1; i < Param::N; i++) + { + coef_w_pwr = -sk_boot.sk[Param::N-i]; + for (int j = 0; j < param.l_ksk; j++) + { + G[i*param.l_ksk+j][0] += coef_w_pwr; + coef_w_pwr *= Param::B_ksk; + } + } + cout << "G+P time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + // parameters of the block optimization of matrix multiplication + int block = 4; + int blocks = (param.n/block)*block; + int rem_block = param.n%block; + // (G + P * Phi(f) * E) * F^(-1) as in the paper + for (int i = 0; i < param.Nl; i++) + { + //cout << "i: " << i << endl; + vector& k_row = ksk_long[i]; + vector& g_row = G[i]; + for (int k = 0; k < param.n; k++) + { + const vector& f_row = sk_base.sk_inv[k]; + //cout << "j: " << j << endl; + long coef = long(g_row[k]); + for (int j = 0; j < blocks; j+=block) + { + k_row[j] += (coef * f_row[j]); + k_row[j+1] += (coef * f_row[j+1]); + k_row[j+2] += (coef * f_row[j+2]); + k_row[j+3] += (coef * f_row[j+3]); + } + for (int j = 0; j < rem_block; j++) + k_row[blocks+j] += (coef * f_row[blocks+j]); + } + } + cout << "After K time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + // reduce modulo q_base + for (int i = 0; i < param.Nl; i++) + param.mod_q_base(ksk[i], ksk_long[i]); + cout << "KSKey-gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void KeyGen::get_ksk(KSKey_LWE& ksk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot) +{ + cout << "Started key-switching key generation" << endl; + clock_t start = clock(); + // reset key-switching key + ksk.A.clear(); + ksk.b.clear(); + for (int i = 0; i < param.Nl; i++) + { + vector row(param.n,0L); + ksk.A.push_back(row); + } + ksk.b = vector(param.Nl, 0L); + cout << "Reset time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + // noise matrix G as in the paper + sampler.get_uniform_matrix(ksk.A); + cout << "A gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + // matrix P * f_0 as in the paper + vector Pf0(param.Nl, 0L); + int coef_w_pwr = sk_boot.sk[0]; + for (int i = 0; i < param.l_ksk; i++) + { + Pf0[i] += coef_w_pwr; + coef_w_pwr *= Param::B_ksk; + } + for (int i = 1; i < Param::N; i++) + { + coef_w_pwr = -sk_boot.sk[Param::N-i]; + for (int j = 0; j < param.l_ksk; j++) + { + Pf0[i*param.l_ksk+j] += coef_w_pwr; + coef_w_pwr *= Param::B_ksk; + } + } + cout << "Pf0 time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + // A*s_base + e + Pf0 as in the paper + normal_distribution gaussian_sampler(0.0, Param::e_st_dev); + for (int i = 0; i < param.Nl; i++) + { + //cout << "i: " << i << endl; + vector& k_row = ksk.A[i]; + for (int k = 0; k < param.n; k++) + ksk.b[i] -= k_row[k] * sk_base[k]; + ksk.b[i] += (Pf0[i] + static_cast(round(gaussian_sampler(rand_engine)))); + param.mod_q_base(ksk.b[i]); + } + cout << "KSKey-gen time: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void KeyGen::get_bsk(BSKey_NTRU& bsk, const SKey_base_NTRU& sk_base, const SKey_boot& sk_boot) +{ + clock_t start = clock(); + + // index of a secret key coefficient of the base scheme + int coef_counter = 0; + + // reset the input + bsk.clear(); + + // transform the secret key of the bootstrapping scheme to the DFT domain + FFTPoly sk_boot_inv_fft(Param::N2p1); + fftN.to_fft(sk_boot_inv_fft, sk_boot.sk_inv); + + // index of a secret key coefficient of the base scheme + coef_counter = 0; + // vector to keep the DFT transform of a random ternary vector + FFTPoly g_fft(Param::N2p1); + // vector to keep the DFT transform of a bootstrapping key part + FFTPoly tmp_bsk_fft(Param::N2p1); + // vector to keep a bootstrapping key part + vector tmp_bsk(Param::N); + vector tmp_bsk_int(Param::N); + // precompute FFT transformed powers of decomposition bases + vector> B_bsk_pwr_poly; + for (int iBase = 0; iBase < Param::B_bsk_size; iBase++) + { + double B_bsk_double = param.B_bsk[iBase]; + vector base_row; + // FFT transform of (1,0,...,0) + FFTPoly tmp_fft(Param::N2p1,complex(1.0, 0.0)); + base_row.push_back(tmp_fft); + for (int iPart = 1; iPart < param.l_bsk[iBase]; iPart++) + { + transform(tmp_fft.begin(), tmp_fft.end(), tmp_fft.begin(), + [B_bsk_double](complex &z){ return z*B_bsk_double; }); + base_row.push_back(tmp_fft); + } + B_bsk_pwr_poly.push_back(base_row); + } + + // loop over different decomposition bases + for (int iBase = 0; iBase < Param::B_bsk_size; iBase++) + { + vector> base_row; + vector& B_bsk_pwr_poly_row = B_bsk_pwr_poly[iBase]; + for (int iCoef = coef_counter; iCoef < coef_counter+param.bsk_partition[iBase]; iCoef++) + { + vector coef_row; + int sk_base_coef = sk_base.sk[iCoef][0]; + /** + * represent coefficient of the secret key + * of the base scheme using 2 bits. + * The representation rule is as follows: + * -1 => [0,1] + * 0 => [0,0] + * 1 => [1,0] + * */ + int coef_bits[2] = {0,0}; + if (sk_base_coef == -1) + coef_bits[1] = 1; + else if (sk_base_coef == 1) + coef_bits[0] = 1; + // encrypt each bit using the NGS scheme + for (int iBit = 0; iBit < 2; iBit++) + { + NGSFFTctxt bit_row; + for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++) + { + // sample random ternary vector + ModQPoly g(Param::N,0L); + sampler.get_ternary_vector(g); + // FFT transform it + fftN.to_fft(g_fft, g); + // compute g * sk_boot^(-1) + tmp_bsk_fft = g_fft * sk_boot_inv_fft; + // compute g * sk_boot^(-1) + B^i * bit + if (coef_bits[iBit] == 1) + tmp_bsk_fft = B_bsk_pwr_poly_row[iPart] + tmp_bsk_fft; + // inverse FFT of the above result + fftN.from_fft(tmp_bsk, tmp_bsk_fft); + // reduction modulo q_boot + mod_q_boot(tmp_bsk_int, tmp_bsk); + // FFT transform for further use + fftN.to_fft(tmp_bsk_fft, tmp_bsk_int); + + bit_row.push_back(tmp_bsk_fft); + } + coef_row.push_back(bit_row); + } + base_row.push_back(coef_row); + } + bsk.push_back(base_row); + coef_counter += param.bsk_partition[iBase]; + } + + cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void KeyGen::get_bsk(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot) +{ + clock_t start = clock(); + + // index of a secret key coefficient of the base scheme + int coef_counter = 0; + + // reset the input + bsk.clear(); + + // transform the secret key of the bootstrapping scheme to the DFT domain + FFTPoly sk_boot_inv_fft(Param::N2p1); + fftN.to_fft(sk_boot_inv_fft, sk_boot.sk_inv); + + // index of a secret key coefficient of the base scheme + coef_counter = 0; + // vector to keep the DFT transform of a random ternary vector + FFTPoly g_fft(Param::N2p1); + // vector to keep the DFT transform of a bootstrapping key part + FFTPoly tmp_bsk_fft(Param::N2p1); + // vector to keep a bootstrapping key part + ModQPoly tmp_bsk(Param::N); + vector tmp_bsk_long; + // precompute FFT transformed powers of decomposition bases + vector> B_bsk_pwr_poly; + for (int iBase = 0; iBase < Param::B_bsk_size; iBase++) + { + double B_bsk_double = param.B_bsk[iBase]; + vector base_row; + // FFT transform of (1,0,...,0) + FFTPoly tmp_fft(Param::N2p1,complex(1.0, 0.0)); + base_row.push_back(tmp_fft); + for (int iPart = 1; iPart < param.l_bsk[iBase]; iPart++) + { + transform(tmp_fft.begin(), tmp_fft.end(), tmp_fft.begin(), + [B_bsk_double](complex &z){ return z*B_bsk_double; }); + base_row.push_back(tmp_fft); + } + B_bsk_pwr_poly.push_back(base_row); + } + + bsk.clear(); + bsk = vector>(Param::B_bsk_size); + for (int i = 0; i < Param::B_bsk_size; i++) + bsk[i] = vector(param.bsk_partition[i], NGSFFTctxt(param.l_bsk[i], FFTPoly(Param::N2p1))); + + ModQPoly g(Param::N,0L); + // loop over different decomposition bases + for (int iBase = 0; iBase < Param::B_bsk_size; iBase++) + { + vector base_row(param.bsk_partition[iBase], NGSFFTctxt(param.l_bsk[iBase], FFTPoly(Param::N2p1))); + vector& B_bsk_pwr_poly_row = B_bsk_pwr_poly[iBase]; + for (int iCoef = coef_counter; iCoef < coef_counter+param.bsk_partition[iBase]; iCoef++) + { + NGSFFTctxt coef_row(param.l_bsk[iBase], FFTPoly(Param::N2p1)); + int sk_base_coef = sk_base[iCoef]; + // encrypt each bit using the NGS scheme + for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++) + { + // sample random ternary vector + sampler.get_ternary_vector(g); + // FFT transform it + fftN.to_fft(g_fft, g); + // compute g * sk_boot^(-1) + tmp_bsk_fft = g_fft; + tmp_bsk_fft *= sk_boot_inv_fft; + // compute g * sk_boot^(-1) + B^i * bit + if (sk_base_coef == 1) + tmp_bsk_fft += B_bsk_pwr_poly_row[iPart]; + // inverse FFT of the above result + fftN.from_fft(tmp_bsk_long, tmp_bsk_fft); + // reduction modulo q_boot + mod_q_boot(tmp_bsk, tmp_bsk_long); + // FFT transform for further use + fftN.to_fft(tmp_bsk_fft, tmp_bsk); + + coef_row[iPart] = tmp_bsk_fft; + } + base_row[iCoef-coef_counter] = coef_row; + } + bsk[iBase] = base_row; + coef_counter += param.bsk_partition[iBase]; + } + + cout << "Bootstrapping generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void KeyGen::get_bsk2(BSKey_LWE& bsk, const SKey_base_LWE& sk_base, const SKey_boot& sk_boot) +{ + clock_t start = clock(); + + // index of a secret key coefficient of the base scheme + int coef_counter = 0; + + // reset the input + bsk.clear(); + + // transform the secret key of the bootstrapping scheme to the DFT domain + FFTPoly sk_boot_inv_fft(Param::N2p1); + fftN.to_fft(sk_boot_inv_fft, sk_boot.sk_inv); + + // index of a secret key coefficient of the base scheme + coef_counter = 0; + // vector to keep the DFT transform of a random ternary vector + FFTPoly g_fft(Param::N2p1); + // vector to keep the DFT transform of a bootstrapping key part + FFTPoly tmp_bsk_fft(Param::N2p1); + // vector to keep a bootstrapping key part + ModQPoly tmp_bsk(Param::N); + vector tmp_bsk_long; + // precompute FFT transformed powers of decomposition bases + vector> B_bsk_pwr_poly; + for (int iBase = 0; iBase < Param::B_bsk_size; iBase++) + { + double B_bsk_double = param.B_bsk[iBase]; + vector base_row; + // FFT transform of (1,0,...,0) + FFTPoly tmp_fft(Param::N2p1,complex(1.0, 0.0)); + base_row.push_back(tmp_fft); + for (int iPart = 1; iPart < param.l_bsk[iBase]; iPart++) + { + transform(tmp_fft.begin(), tmp_fft.end(), tmp_fft.begin(), + [B_bsk_double](complex &z){ return z*B_bsk_double; }); + base_row.push_back(tmp_fft); + } + B_bsk_pwr_poly.push_back(base_row); + } + + bsk.clear(); + bsk = vector>(Param::B_bsk_size); + for (int i = 0; i < Param::B_bsk_size; i++) + bsk[i] = vector(4 * (param.bsk_partition[i] >> 1), NGSFFTctxt(param.l_bsk[i], FFTPoly(Param::N2p1))); + + ModQPoly g(Param::N,0L); + // loop over different decomposition bases + int bits[4]; + for (int iBase = 0; iBase < Param::B_bsk_size; iBase++) + { + vector base_row(4 * (param.bsk_partition[iBase] >> 1), NGSFFTctxt(param.l_bsk[iBase], FFTPoly(Param::N2p1))); + vector& B_bsk_pwr_poly_row = B_bsk_pwr_poly[iBase]; + for (int iCoef = coef_counter; iCoef < coef_counter+param.bsk_partition[iBase]; iCoef+=2) + { + NGSFFTctxt coef_row(param.l_bsk[iBase], FFTPoly(Param::N2p1)); + // bits to encrypt: s[coef]*s[coef+1], s[coef]*(1-s[coef+1]), (1-s[coef])*s[coef+1] + bits[0] = sk_base[iCoef]*sk_base[iCoef+1]; + bits[1] = sk_base[iCoef]*(1-sk_base[iCoef+1]); + bits[2] = (1-sk_base[iCoef])*sk_base[iCoef+1]; + bits[3] = (1-sk_base[iCoef])*(1-sk_base[iCoef+1]); + // encrypt each bit using the NGS scheme + for (int iBit = 0; iBit < 4; iBit++) + { + for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++) + { + // sample random ternary vector + sampler.get_ternary_vector(g); + // FFT transform it + fftN.to_fft(g_fft, g); + // compute g * sk_boot^(-1) + tmp_bsk_fft = g_fft; + tmp_bsk_fft *= sk_boot_inv_fft; + // compute g * sk_boot^(-1) + B^i * bit + if (bits[iBit] == 1) + tmp_bsk_fft += B_bsk_pwr_poly_row[iPart]; + // inverse FFT of the above result + fftN.from_fft(tmp_bsk_long, tmp_bsk_fft); + // reduction modulo q_boot + mod_q_boot(tmp_bsk, tmp_bsk_long); + // FFT transform for further use + fftN.to_fft(tmp_bsk_fft, tmp_bsk); + + coef_row[iPart] = tmp_bsk_fft; + } + base_row[4*((iCoef-coef_counter) >> 1)+iBit] = coef_row; + } + } + bsk[iBase] = base_row; + coef_counter += param.bsk_partition[iBase]; + } + + cout << "Bootstrapping2 generation: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} \ No newline at end of file diff --git a/src/lwehe.cpp b/src/lwehe.cpp new file mode 100644 index 0000000..3a7804c --- /dev/null +++ b/src/lwehe.cpp @@ -0,0 +1,549 @@ +#include "lwehe.h" +#include "sampler.h" +#include "fft.h" +#include "ntruhe.h" + +#include +#include +#include + +using namespace std; + +Ctxt_LWE::Ctxt_LWE(const Ctxt_LWE& ct) +{ + a = ct.a; + b = ct.b; +} + +Ctxt_LWE& Ctxt_LWE::operator=(const Ctxt_LWE& ct) +{ + a = ct.a; + b = ct.b; + return *this; +} + +Ctxt_LWE Ctxt_LWE::operator +(const Ctxt_LWE& ct) const +{ + Ctxt_LWE res; + for (size_t i = 0; i < parLWE.n; i++) + res.a[i] = parLWE.mod_q_base(a[i] + ct.a[i]); + + res.b = parLWE.mod_q_base(b + ct.b); + + return res; +} + +Ctxt_LWE Ctxt_LWE::operator -(const Ctxt_LWE& ct) const +{ + Ctxt_LWE res; + for (size_t i = 0; i < parLWE.n; i++) + res.a[i] = parLWE.mod_q_base(a[i] - ct.a[i]); + + res.b = parLWE.mod_q_base(b + ct.b); + + return res; +} + +Ctxt_LWE operator -(const int c, const Ctxt_LWE& ct) +{ + Ctxt_LWE res; + res.a = vector(parLWE.n); + const vector& a = ct.a; + for (size_t i = 0; i < a.size(); i++) + res.a[i] = parLWE.mod_q_base(-a[i]); + + res.b = parLWE.mod_q_base(c-ct.b); + return res; +} + +void SchemeLWE::encrypt(Ctxt_LWE& ct, int m) const +{ + clock_t start = clock(); + + int n = parLWE.n; + + vector a(n,0L); + Sampler s(parLWE); + s.get_uniform_vector(a); + ct.a = a; + normal_distribution gaussian_sampler(0.0, Param::e_st_dev); + int b = parLWE.delta_base*m + static_cast(round(gaussian_sampler(rand_engine))); + for (int i = 0; i < n; i++) + { + b -= sk_base[i] * a[i]; + } + parLWE.mod_q_base(b); + ct.b = b; + + //cout << "Encryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +int SchemeLWE::decrypt(const Ctxt_LWE& ct) const +{ + clock_t start = clock(); + + int output = ct.b; + for (int i = 0; i < parLWE.n; i++) + { + output += ct.a[i] * sk_base[i]; + } + output = parLWE.mod_q_base(output); + output = int(round(double(output*Param::t)/double(parLWE.q_base))); + //cout << "Decryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + return output; +} + +void external_product(vector& res, const vector& poly, const vector& poly_vector, int b, int shift, int l) +{ + int N = Param::N; + int N2p1 = Param::N2p1; + + ModQPoly poly_sign(N); + ModQPoly poly_abs(N); + vector poly_decomp(N); + + for (int i = 0; i < N; ++i) + { + const int& polyi = poly[i]; + poly_abs[i] = abs(polyi); + poly_sign[i] = (polyi < 0)? -1 : 1; + } + FFTPoly res_fft(N2p1); + FFTPoly tmp_fft(N2p1); + int mask = b-1; + int bound = b >> 1; + int digit, sgn; + for (int j = 0; j < l; ++j) + { + for (int i = 0; i < N; ++i) + { + int& abs_val = poly_abs[i]; + digit = abs_val & mask; //poly_abs[i] % b; + if (digit > bound) + { + poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit); + abs_val >>= shift; + ++abs_val; //(abs_val - digit)/b + 1; + } + else + { + poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit; + abs_val >>= shift; //(abs_val - digit)/b; + } + } + fftN.to_fft(tmp_fft, poly_decomp); + tmp_fft *= poly_vector[j]; + res_fft += tmp_fft; + } + fftN.from_fft(res, res_fft); + //mod_q_boot(poly); +} + +void SchemeLWE::key_switch(Ctxt_LWE& ct, const ModQPoly& poly) const +{ + int N = Param::N; + int B_ksk = Param::B_ksk; + int l_ksk = parLWE.l_ksk; + int Nl = parLWE.Nl; + int n = parLWE.n; + + vector poly_decomp(Nl); + ModQPoly poly_sign(N); + ModQPoly poly_abs(N); + for (int i = 0; i < N; ++i) + { + const int& polyi = poly[i]; + poly_abs[i] = abs(polyi); + poly_sign[i] = (polyi < 0)? -1 : 1; + } + int digit; + int il = 0; + int tmp; + int sgn; + int bound = B_ksk >> 1; + for (int i = 0; i < N; ++i) + { + tmp = poly_abs[i]; + sgn = poly_sign[i]; + for (int j = 0; j < l_ksk; ++j) + { + digit = tmp % B_ksk; + if (digit > bound) + { + poly_decomp[il+j] = (sgn == 1) ? (digit - B_ksk): (B_ksk - digit); + tmp /= B_ksk; + ++tmp; + } + else + { + poly_decomp[il+j] = (sgn == 1) ? digit: - digit; + tmp /= B_ksk; + } + } + il += l_ksk; + } + vector a(n); + for (int i = 0; i < Nl; ++i) + { + long tmp_int = long(poly_decomp[i]); + const vector& ksk_row = ksk.A[i]; + for (int j = 0; j < n; ++j) + { + a[j] += long(ksk_row[j]) * tmp_int; + } + } + parLWE.mod_q_base(ct.a, a); + long b = 0L; + const vector& ksk_b = ksk.b; + for (int i = 0; i < Nl; ++i) + b += ksk_b[i] * long(poly_decomp[i]); + ct.b = parLWE.mod_q_base(b); +} + +/* +// debugger functions + void print(const vector& vec) + { + for (size_t i = 0; i < vec.size(); i++) + { + printf("[%zu] %d ", i, vec[i]); + } + cout << endl; + } + + void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk) + { + FFTPoly sk_fft; + fftN.to_fft(sk_fft, sk.sk); + FFTPoly ct_fft; + fftN.to_fft(ct_fft, ct); + + FFTPoly output_fft; + output_fft = ct_fft * sk_fft; + vector output; + vector output_int; + fftN.from_fft(output, output_fft); + parLWE.mod_q_boot(output_int, output); + print(output_int); + } + + void decrypt_poly_base_and_print(const ModQPoly& ct, const SKey_boot& sk) + { + FFTPoly sk_fft; + fftN.to_fft(sk_fft, sk.sk); + FFTPoly ct_fft; + fftN.to_fft(ct_fft, ct); + + FFTPoly output_fft; + output_fft = ct_fft * sk_fft; + ModQPoly output; + fftN.from_fft(output, output_fft); + mod_q_base(output); + print(output); + } + +void decryptN2(const Ctxt_LWE& ct, const SKey_base_LWE& sk) +{ + int output = ct.b; + for (int i = 0; i < lwe_he::n; i++) + { + output += ct.a[i] * sk[i]; + } + output = output%parLWE.N2; + if (output > parLWE.N) + output -= parLWE.N2; + if (output <= -parLWE.N) + output += parLWE.N2; + cout << output << endl; +} +// end debugger functions +*/ + +void SchemeLWE::bootstrap(Ctxt_LWE& ct) const +{ + //clock_t start = clock(); + int N = Param::N; + int N2 = Param::N2; + int N2p1 = Param::N2p1; + int B_bsk_size = Param::B_bsk_size; + int half_delta_boot = Param::half_delta_boot; + + // switch to modulus 2*N + modulo_switch_to_boot(ct); + // initialize accumulator and rotate accumulator by X^ct.b + vector acc(N, half_delta_boot); + int b_pow = (N/2 + ct.b)%N2; + if (b_pow < 0) + b_pow += N2; + int b_sign = 1; + if (b_pow >= N) + { + b_pow -= N; + b_sign = -1; + } + for (int i = 0; i < b_pow; ++i) + acc[i] = (b_sign == 1) ? -acc[i]: acc[i]; + for (int i = b_pow; i < N; ++i) + acc[i] = (b_sign == 1) ? acc[i]: -acc[i]; + + //accumulator loop + int coef_counter = 0; + vector& a = ct.a; + int coef, coef_sign, B, shift, l; + double Bd; + //auto start = clock(); + //float cmux_time = 0.0; + //float extprod_time = 0.0; + vector tmp_poly(N); + vector tmp_poly_long(N); + + const BSKey_LWE& boot_key = bk; + for (int iBase = 0; iBase < B_bsk_size; ++iBase) + { + B = parLWE.B_bsk[iBase]; + Bd = double(B); + shift = parLWE.shift_bsk[iBase]; + l = parLWE.l_bsk[iBase]; + //vector> w_powers(l); + //w_powers[0] = complex(1.0,0.0); + //for (int i = 1; i < l; i++) + // w_powers[i] = w_powers[i-1] * Bd; + const vector& bk_coef_row = boot_key[iBase]; + for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; ++iCoef) + { + //auto start = clock(); + coef = a[iCoef+coef_counter]; + if (coef == 0) continue; + coef_sign = 1; + if (coef < 0) coef += N2; + if (coef >= N) + { + coef -= N; + coef_sign = -1; + } + + // acc * (X^coef - 1) + if (coef_sign == 1) + { + for (int i = 0; i acc(N, half_delta_boot); + int b_pow = (N/2 + ct.b)%N2; + if (b_pow < 0) + b_pow += N2; + int b_sign = 1; + if (b_pow >= N) + { + b_pow -= N; + b_sign = -1; + } + for (int i = 0; i < b_pow; ++i) + acc[i] = (b_sign == 1) ? -acc[i]: acc[i]; + for (int i = b_pow; i < N; ++i) + acc[i] = (b_sign == 1) ? acc[i]: -acc[i]; + + //accumulator loop + int coef_counter = 0; + vector& a = ct.a; + int coef1, coef2, coef_sum, coef1_sign, coef2_sign, coef_sum_sign, B, shift, l; + double Bd; + //auto start = clock(); + //float cmux_time = 0.0; + //float extprod_time = 0.0; + vector tmp_poly(N); + vector tmp_poly_long(N); + + const BSKey_LWE& boot_key = bk; + for (int iBase = 0; iBase < B_bsk_size; ++iBase) + { + B = parLWE.B_bsk[iBase]; + Bd = double(B); + shift = parLWE.shift_bsk[iBase]; + l = parLWE.l_bsk[iBase]; + //vector> w_powers(l); + //w_powers[0] = complex(1.0,0.0); + //for (int i = 1; i < l; i++) + // w_powers[i] = w_powers[i-1] * Bd; + const vector& bk_coef_row = boot_key[iBase]; + vector mux_fft(l,FFTPoly(N2p1,complex(0.0,0.0))); + for (int iCoef = 0; iCoef < parLWE.bsk_partition[iBase]; iCoef+=2) + { + //auto start = clock(); + // normalize coef1 + coef1 = a[iCoef+coef_counter]; + coef1_sign = 1; + if (coef1 < 0) coef1 += N2; + if (coef1 >= N) + { + coef1 -= N; + coef1_sign = -1; + } + // normalize coef2 + coef2 = a[iCoef+coef_counter+1]; + coef2_sign = 1; + if (coef2 < 0) coef2 += N2; + if (coef2 >= N) + { + coef2 -= N; + coef2_sign = -1; + } + // normalize coef_sum + coef_sum = (a[iCoef+coef_counter] + a[iCoef+coef_counter+1]) % N2; + coef_sum_sign = 1; + if (coef_sum < 0) coef_sum += N2; + if (coef_sum >= N) + { + coef_sum -= N; + coef_sum_sign = -1; + } + + // bk_0 * X^(c_0+c_1) + bk_1 * X^(c_0) + bk_2 * X^(c_1) + bk_3 + const FFTPoly& x_sum = (coef_sum_sign == 1) ? fftN.pos_powers[coef_sum]: fftN.neg_powers[coef_sum]; + const FFTPoly& x_c1 = (coef1_sign == 1) ? fftN.pos_powers[coef1]: fftN.neg_powers[coef1]; + const FFTPoly& x_c2 = (coef2_sign == 1) ? fftN.pos_powers[coef2]: fftN.neg_powers[coef2]; + const NGSFFTctxt& bk_part_row0 = bk_coef_row[4*(iCoef >> 1)]; + const NGSFFTctxt& bk_part_row1 = bk_coef_row[4*(iCoef >> 1)+1]; + const NGSFFTctxt& bk_part_row2 = bk_coef_row[4*(iCoef >> 1)+2]; + const NGSFFTctxt& bk_part_row3 = bk_coef_row[4*(iCoef >> 1)+3]; + for (int iPart = 0; iPart < l; ++iPart) + { + mux_fft[iPart] = bk_part_row0[iPart]; + mux_fft[iPart] *= x_sum; + + if (coef1 == coef2 && coef1_sign == coef2_sign) + { + FFTPoly tmp_fft = bk_part_row1[iPart]; + tmp_fft += bk_part_row2[iPart]; + tmp_fft *= x_c1; + mux_fft[iPart] += tmp_fft; + } + else + { + FFTPoly tmp_fft = bk_part_row1[iPart]; + tmp_fft *= x_c1; + mux_fft[iPart] += tmp_fft; + + tmp_fft = bk_part_row2[iPart]; + tmp_fft *= x_c2; + mux_fft[iPart] += tmp_fft; + } + + mux_fft[iPart] += bk_part_row3[iPart]; + } + //cmux_time += float(clock()-start)/CLOCKS_PER_SEC; + + //start = clock(); + external_product(tmp_poly_long, acc, mux_fft, B, shift, l); + mod_q_boot(acc, tmp_poly_long); + //extprod_time += float(clock()-start)/CLOCKS_PER_SEC; + } + coef_counter += parLWE.bsk_partition[iBase]; + } + //cout << "Cmux: " << cmux_time << endl; + //cout << "Ext. prod: " << extprod_time << endl; + + // add floor(q_boot/(2*t)) to all coefficients of the accumulator + for (auto it = acc.begin(); it < acc.end(); ++it) + *it += half_delta_boot; + + //mod q_boot of the accumulator + mod_q_boot(acc); + + //decrypt_poly_boot_and_print(acc, sk_boot); + + //mod switch to q_base + modulo_switch_to_base_lwe(acc); + + //decrypt_poly_boot_and_print(acc, sk_boot); + + //key switch + //auto start = clock(); + key_switch(ct, acc); + //cout << "Key-switching: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + //cout << "Bootstrapping: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void SchemeLWE::nand_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const +{ + //clock_t start = clock(); + ct_res = parLWE.nand_const - (ct1 + ct2); + bootstrap(ct_res); + //cout << "NAND: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void SchemeLWE::and_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const +{ + //clock_t start = clock(); + ct_res = parLWE.and_const - (ct1 + ct2); + bootstrap(ct_res); + //cout << "AND: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +void SchemeLWE::or_gate(Ctxt_LWE& ct_res, const Ctxt_LWE& ct1, const Ctxt_LWE& ct2) const +{ + //clock_t start = clock(); + ct_res = parLWE.or_const - (ct1 + ct2); + bootstrap(ct_res); + //cout << "OR: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} \ No newline at end of file diff --git a/src/ntruhe.cpp b/src/ntruhe.cpp new file mode 100644 index 0000000..0026969 --- /dev/null +++ b/src/ntruhe.cpp @@ -0,0 +1,416 @@ +#include "ntruhe.h" +#include "sampler.h" +#include "fft.h" +#include "lwehe.h" + +#include "time.h" +#include + +Ctxt_NTRU::Ctxt_NTRU(const Ctxt_NTRU& ct) +{ + data = ct.data; +} + +Ctxt_NTRU& Ctxt_NTRU::operator=(const Ctxt_NTRU& ct) +{ + data = ct.data; + return *this; +} + +Ctxt_NTRU Ctxt_NTRU::operator +(const Ctxt_NTRU& ct) const +{ + Ctxt_NTRU res; + for (size_t i = 0; i < parNTRU.n; ++i) + res.data[i] = parNTRU.mod_q_base(data[i] + ct.data[i]); + + return res; +} + +Ctxt_NTRU Ctxt_NTRU::operator -(const Ctxt_NTRU& ct) const +{ + Ctxt_NTRU res; + for (size_t i = 0; i < parNTRU.n; ++i) + res.data[i] = parNTRU.mod_q_base(data[i] - ct.data[i]); + return res; +} + +void Ctxt_NTRU::operator -=(const Ctxt_NTRU& ct) +{ + for (size_t i = 0; i < parNTRU.n; ++i) + { + data[i] -= ct.data[i]; + data[i] = parNTRU.mod_q_base(data[i]); + } +} + +void SchemeNTRU::encrypt(Ctxt_NTRU& ct, const int b) const +{ + clock_t start = clock(); + + int n = parNTRU.n; + + vector g(n,0L); + Sampler::get_ternary_vector(g); + g[0] += b * parNTRU.delta_base; + ct.data = vector(n,0); + vector ct_long(n,0L); + for (int i = 0; i < n; i++) + { + long g_coef = long(g[i]); + const vector& sk_row = sk_base.sk_inv[i]; + for (int j = 0; j < n; j++) + ct_long[j] += long(sk_row[j]) * g_coef; + } + parNTRU.mod_q_base(ct.data, ct_long); + + //cout << "Encryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl; +} + +int SchemeNTRU::decrypt(const Ctxt_NTRU& ct) const +{ + clock_t start = clock(); + + int output = 0; + for (int i = 0; i < parNTRU.n; i++) + { + output += ct.data[i] * sk_base.sk[i][0]; + } + output = parNTRU.mod_q_base(output); + output = int(round(double(output)/double(parNTRU.q_base)*Param::t)); + //cout << "Decryption: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + return output; +} + +/* +void SchemeNTRU::external_product(vector& res, const vector& poly, const vector& poly_vector, const int b, const int shift, const int l) const +{ + int N = Param::N; + int N2p1 = Param::N2p1; + + ModQPoly poly_sign(N,0L); + ModQPoly poly_abs(N,0L); + for (int i = 0; i < N; i++) + { + const int& polyi = poly[i]; + poly_abs[i] = abs(polyi); + poly_sign[i] = (polyi < 0)? -1 : 1; + } + FFTPoly res_fft(N2p1); + FFTPoly tmp_fft(N2p1); + int mask = b-1; + int bound = b >> 1; + int digit, sgn, abs_val; + vector poly_decomp(N); + for (int j = 0; j < l; j++) + { + for (int i = 0; i < N; i++) + { + abs_val = poly_abs[i]; + digit = abs_val & mask; //poly_abs[i] % b; + if (digit > bound) + { + poly_decomp[i] = (poly_sign[i] == 1) ? (digit - b): (b - digit); + poly_abs[i] = (abs_val >> shift) + 1; //(abs_val - digit)/b + 1; + } + else + { + poly_decomp[i] = (poly_sign[i] == 1) ? digit: -digit; + poly_abs[i] = abs_val >> shift; //(abs_val - digit)/b; + } + } + fftN.to_fft(tmp_fft, poly_decomp); + tmp_fft *= poly_vector[j]; + res_fft += tmp_fft; + } + fftN.from_fft(res, res_fft); + //mod_q_boot(poly); +} +*/ + +void SchemeNTRU::key_switch(Ctxt_NTRU& ct, const ModQPoly& poly) const +{ + int N = Param::N; + int B_ksk = Param::B_ksk; + int Nl = parNTRU.Nl; + int l_ksk = parNTRU.l_ksk; + int n = parNTRU.n; + + vector poly_decomp(Nl); + ModQPoly poly_sign(N); + ModQPoly poly_abs(N); + for (int i = 0; i < N; ++i) + { + const int& polyi = poly[i]; + poly_abs[i] = abs(polyi); + poly_sign[i] = (polyi < 0)? -1 : 1; + } + int bound = B_ksk >> 1; + int il = 0; + int digit, tmp, sgn; + for (int i = 0; i < N; ++i) + { + tmp = poly_abs[i]; + sgn = poly_sign[i]; + for (int j = 0; j < l_ksk; j++) + { + int digit = tmp % B_ksk; + if (digit > bound) + { + poly_decomp[il+j] = (sgn == 1) ? (digit - B_ksk): (B_ksk - digit); + tmp /= B_ksk; + ++tmp; + } + else + { + poly_decomp[il+j] = (sgn == 1) ? digit: - digit; + tmp /= B_ksk; + } + } + il += l_ksk; + } + vector ct_long(n); + for (int i = 0; i < Nl; ++i) + { + long tmp_int = long(poly_decomp[i]); + const vector& ksk_row = ksk[i]; + for (int j = 0; j < n; ++j) + { + ct_long[j] += long(ksk_row[j]) * tmp_int; + } + } + parNTRU.mod_q_base(ct.data, ct_long); +} + +/* +// debugger functions +void print(const vector& vec) +{ + for (size_t i = 0; i < vec.size(); i++) + { + printf("[%zu] %d ", i, vec[i]); + } + cout << endl; +} +void decrypt_poly_boot_and_print(const ModQPoly& ct, const SKey_boot& sk, const Param& param) +{ + FFTPoly sk_fft(Param::N2p1); + fftN.to_fft(sk_fft, sk.sk); + FFTPoly ct_fft(Param::N2p1); + fftN.to_fft(ct_fft, ct); + + FFTPoly output_fft; + output_fft = ct_fft * sk_fft; + ModQPoly output; + vector output_long; + fftN.from_fft(output_long, output_fft); + mod_q_boot(output, output_long); + print(output); +} + +void decrypt_poly_base_and_print(const ModQPoly& ct, const Param& param, const SKey_boot& sk) +{ + FFTPoly sk_fft(Param::N2p1); + fftN.to_fft(sk_fft, sk.sk); + FFTPoly ct_fft(Param::N2p1); + fftN.to_fft(ct_fft, ct); + + FFTPoly output_fft; + output_fft = ct_fft * sk_fft; + ModQPoly output; + vector output_long; + fftN.from_fft(output_long, output_fft); + param.mod_q_base(output, output_long); + print(output); +} + +void decryptN2(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk) +{ + int N = Param::N; + int N2 = Param::N2; + int n = parNTRU.n; + + int output = 0; + for (int i = 0; i < n; i++) + { + output += ct.data[i] * sk.sk[i][0]; + } + output = output%N2; + if (output > N) + output -= N2; + if (output <= -N) + output += N2; + cout << output << endl; +} + +void decrypt_base(const Ctxt_NTRU& ct, const SKey_base_NTRU& sk) +{ + int n = parNTRU.n; + int q_base = parNTRU.q_base; + int half_q_base= parNTRU.half_q_base; + + int output = 0; + for (int i = 0; i < n; i++) + { + output += ct.data[i] * sk.sk[i][0]; + } + output = output%q_base; + if (output > half_q_base) + output -= q_base; + if (output <= -half_q_base) + output += q_base; + cout << output << endl; +} +// end debugger functions +*/ +void SchemeNTRU::mask_constant(Ctxt_NTRU& ct, int constant) +{ + int n = parNTRU.n; + + vector g(n); + Sampler::get_ternary_vector(g); + g[0] += constant; + ct.data = vector(n); + vector ct_long(n); + for (int i = 0; i < n; i++) + { + long g_coef = long(g[i]); + const vector& sk_row = sk_base.sk_inv[i]; + for (int j = 0; j < n; j++) + ct_long[j] += sk_row[j] * g_coef; + } + parNTRU.mod_q_base(ct.data, ct_long); +} + +void SchemeNTRU::bootstrap(Ctxt_NTRU& ct) const +{ + //clock_t start = clock(); + int N = Param::N; + int N2 = Param::N2; + int N2p1 = Param::N2p1; + int B_bsk_size = Param::B_bsk_size; + int half_delta_boot = Param::half_delta_boot; + + // switch to modulus 2*N + modulo_switch_to_boot(ct); + // initialize accumulator + ModQPoly acc(N, half_delta_boot); + //vector acc_long(N); + for (size_t i = 0; i < N/2; i++) + acc[i] = -acc[i]; + + //accumulator loop + int coef_counter = 0; + vector& data = ct.data; + int coef, neg_coef, coef_sign, neg_coef_sign, B, shift, l; + double Bd; + vector tmp_poly(N); + vector tmp_poly_long(N); + + const BSKey_NTRU& boot_key = bk; + for (int iBase = 0; iBase < B_bsk_size; ++iBase) + { + B = parNTRU.B_bsk[iBase]; + Bd = double(B); + shift = parNTRU.shift_bsk[iBase]; + l = parNTRU.l_bsk[iBase]; + //vector> w_power_fft(l); + //w_power_fft[0] = complex(1.0,0.0); + //for (int i = 1; i < l; i++) + // w_power_fft[i] = w_power_fft[i-1] * Bd; + const vector>& bk_coef_row = boot_key[iBase]; + vector mux_fft(l, FFTPoly(N2p1)); + for (int iCoef = 0; iCoef < parNTRU.bsk_partition[iBase]; ++iCoef) + { + coef = data[iCoef+coef_counter]; + if (coef == 0) continue; + coef_sign = 1; + if (coef < 0) coef += N2; + if (coef >= N) + { + coef -= N; + coef_sign = -1; + } + neg_coef = N - coef; + neg_coef_sign = -coef_sign; + if(neg_coef == N) + { + neg_coef = 0; + neg_coef_sign = -neg_coef_sign; + } + + // acc * (X^coef - 1) + if (coef_sign == 1) + { + for (int i = 0; i +#include +#include + +void Sampler::get_ternary_vector(vector& vec) +{ + + for(int i=0; i>& mat) +{ + + for(int i=0; i& row = mat[i]; + get_ternary_vector(row); + } +} + +void Sampler::get_binary_vector(vector& vec) +{ + for(int i=0; i& vec) +{ + for(int i=0; i>& mat) +{ + for(int i=0; i& row = mat[i]; + get_uniform_vector(row); + } +} + +void Sampler::get_gaussian_vector(vector& vec, double st_dev) +{ + normal_distribution gaussian_sampler(0.0, st_dev); + for(size_t i=0; i(round(gaussian_sampler(rand_engine))); +} + +void Sampler::get_gaussian_matrix(vector>& mat, double st_dev) +{ + for(size_t i=0; i& row = mat[i]; + get_gaussian_vector(row, st_dev); + } +} + +void Sampler::get_invertible_vector(vector& vec, vector& vec_inv, int scale, int shift) +{ + //polynomial with the coefficient vector vec (will be generated later) + ZZ_pX poly; + //element of Z_(q_boot) + ZZ_p coef; + coef.init(ZZ(q_boot)); + //the inverse of poly modulo poly_mod (will be generated later) + ZZ_pX inv_poly; + //random sampling + while (true) + { + //create the polynomial with the coefficient vector of the desired form + SetCoeff(poly, 0, ternary_sampler(rand_engine)*scale + shift); + for (size_t i = 1; i < vec.size(); i++) + { + coef = ternary_sampler(rand_engine)*scale; + SetCoeff(poly, i, coef); + } + //test invertibility + try + { + InvMod(inv_poly, poly, Param::get_def_poly()); + break; + } + catch(...) + { + cout << "Polynomial " << poly << " isn't a unit" << endl; + continue; + } + } + //cout << "Poly: " << poly << endl; + //cout << "Poly inverse: " << inv_poly << endl; + //extract the coefficient vector of poly + int tmp_coef; + for (int i = 0; i <= deg(poly); i++) + { + tmp_coef = conv(poly[i]); + if (tmp_coef > half_q_boot) + tmp_coef -= q_boot; + vec[i] = tmp_coef; + } + + for (int i = 0; i <= deg(inv_poly); i++) + { + tmp_coef = conv(inv_poly[i]); + if (tmp_coef > half_q_boot) + tmp_coef -= q_boot; + vec_inv[i] = tmp_coef; + } + + //cout << "Vector:" << vec << endl; + //cout << "Inverse vector:" << vec_inv << endl; +} + +void Sampler::get_invertible_matrix(vector>& mat, vector>& mat_inv, int scale, int shift) +{ + //check that the input matrices are squares + assert(mat[0].size() == mat.size()); + assert(mat_inv[0].size() == mat_inv.size()); + //check that both input matrices have the same dimension + assert(mat.size() == mat_inv.size()); + + //number of rows of the input matrix + int dim = mat.size(); + + //element of Z_(q_boot) + ZZ_p coef; + coef.init(ZZ(param.q_base)); + + //candidate matrix + mat_ZZ_p tmp_mat(INIT_SIZE, dim, dim); + + //candidate inverse matrix + mat_ZZ_p tmp_mat_inv(INIT_SIZE, dim, dim); + + //sampling and testing + while (true) + { + //sampling + for (int i = 0; i < dim; i++) + { + Vec& row = tmp_mat[i]; + for (int j = 0; j < dim; j++) + { + coef = ternary_sampler(rand_engine)*scale; + if (i==j) + coef += ZZ_p(shift); + row[j] = coef; + } + } + //test invertibility + try + { + inv(tmp_mat_inv, tmp_mat); + break; + } + catch(...) + { + cout << "Matrix " << tmp_mat << " is singular" << endl; + continue; + } + } + //lift mod q representation to integers + int tmp_coef; + for (int i = 0; i < dim; i++) + { + Vec& tmp_row = tmp_mat[i]; + Vec& tmp_row_inv = tmp_mat_inv[i]; + vector& row = mat[i]; + vector& row_inv = mat_inv[i]; + for (int j = 0; j < dim; j++) + { + tmp_coef = conv(tmp_row[j]); + if (tmp_coef > param.half_q_base) + tmp_coef -= param.q_base; + row[j] = tmp_coef; + + tmp_coef = conv(tmp_row_inv[j]); + if (tmp_coef > param.half_q_base) + tmp_coef -= param.q_base; + row_inv[j] = tmp_coef; + } + } +} \ No newline at end of file diff --git a/test.cpp b/test.cpp new file mode 100644 index 0000000..eee41a8 --- /dev/null +++ b/test.cpp @@ -0,0 +1,960 @@ +#include +#include +#include "params.h" +#include "sampler.h" +#include "keygen.h" +#include "fft.h" +#include "ntruhe.h" +#include "lwehe.h" + +#include +#include +#include +#include +#include + +#include + +using namespace std; +using namespace NTL; + +void test_params() +{ + { + Param param(LWE); + cout << "Ciphertext modulus of the base scheme (LWE): " << param.q_base << endl; + cout << "Dimension of the base scheme (LWE): " << param.n << endl; + cout << "Ciphertext modulus for bootstrapping (LWE): " << q_boot << endl; + cout << "Polynomial modulus (LWE): " << Param::get_def_poly() << endl; + assert(param.l_ksk == int(ceil(log(double(param.q_base))/log(double(Param::B_ksk))))); + cout << "Decomposition length for key-switching (LWE): " << param.l_ksk << endl; + cout << "Decomposition bases for key-switching (LWE): " << Param::B_ksk << endl; + cout << "Dimension for bootstrapping (LWE): " << Param::N << endl; + cout << "Decomposition bases for bootstrapping (LWE): "; + for (const auto &v: param.B_bsk) cout << v << ' '; + cout << endl; + cout << "Delta (LWE): " << param.delta_base << endl; + cout << "Half Delta (LWE): " << param.half_delta_base << endl; + } + { + Param param(NTRU); + cout << "Ciphertext modulus of the base scheme (MNTRU): " << param.q_base << endl; + cout << "Dimension of the base scheme (NTRU): " << param.n << endl; + cout << "Ciphertext modulus for bootstrapping (NTRU): " << q_boot << endl; + cout << "Polynomial modulus (NTRU): " << Param::get_def_poly() << endl; + assert(param.l_ksk == int(ceil(log(double(param.q_base))/log(double(Param::B_ksk))))); + cout << "Decomposition length for key-switching (MNTRU): " << param.l_ksk << endl; + cout << "Decomposition bases for key-switching (MNTRU): " << Param::B_ksk << endl; + cout << "Dimension for bootstrapping (MNTRU): " << Param::N << endl; + cout << "Decomposition bases for bootstrapping (MNTRU): "; + for (const auto &v: param.B_bsk) cout << v << ' '; + cout << endl; + cout << "Decomposition lengths for bootstrapping (MNTRU): "; + for (int i = 0; i < Param::B_bsk_size; i++) + { + assert(param.l_bsk[i] == int(ceil(log(double(q_boot))/log(double(param.B_bsk[i]))))); + cout << param.l_bsk[i] << ' '; + } + cout << endl; + cout << "Decomposition lengths for bootstrapping (MNTRU): "; + for (int i = 0; i < Param::B_bsk_size; i++) + { + assert(param.l_bsk[i] == int(ceil(log(double(q_boot))/log(double(param.B_bsk[i]))))); + cout << param.l_bsk[i] << ' '; + } + cout << endl; + cout << "Delta (MNTRU): " << param.delta_base << endl; + cout << "Half Delta (MNTRU): " << param.half_delta_base << endl; + + { + assert(0L == mod_q_boot(0L)); + assert(1L == mod_q_boot(1L)); + assert(0L == mod_q_boot(q_boot)); + assert(half_q_boot == mod_q_boot(half_q_boot)); + assert(-half_q_boot == mod_q_boot(-half_q_boot)); + cout << "MODULO REDUCTION IS OK" << endl; + } + } + + cout << "Plaintext modulus: " << Param::t << endl; + cout << endl; + cout << "PARAMS ARE OK" << endl; + + { + vector res; + decompose(res, 0, 2, 3); + assert(res.size() == 3); + for (auto iter=res.begin(); iter < res.end(); iter++) + assert(0L == *iter); + } + { + vector res; + decompose(res, 1, 2, 3); + assert(res.size() == 3); + assert(res[0] == 1); + for (auto iter=res.begin()+1; iter < res.end(); iter++) + assert(0L == *iter); + } + { + vector res; + decompose(res, 2, 3, 3); + assert(res.size() == 3); + assert(res[0] == -1 && res[1] == 1 && res[2] == 0); + } + { + vector res; + decompose(res, 2, 4, 3); + assert(res.size() == 3); + assert(res[0] == 2 && res[1] == 0 && res[2] == 0); + decompose(res, 3, 4, 3); + assert(res.size() == 3); + assert(res[0] == -1 && res[1] == 1 && res[2] == 0); + } + { + vector res; + try + { + decompose(res, 14, 3, 3); + assert(false); + } + catch (overflow_error) + { + assert(true); + } + } + { + vector res; + try + { + decompose(res, -14, 3, 3); + assert(false); + } + catch (overflow_error) + { + assert(true); + } + } + { + vector res; + decompose(res, 13, 3, 3); + assert(res.size() == 3); + assert(res[0] == 1 && res[1] == 1 && res[2] == 1); + decompose(res, -13, 3, 3); + assert(res.size() == 3); + assert(res[0] == -1 && res[1] == -1 && res[2] == -1); + } + + + cout << "DECOMPOSITION IS OK" << endl; + + +} + +void test_sampler() +{ + int N = Param::N; + + Param pLWE(LWE); + Param pNTRU(NTRU); + for (int run = 0; run < 1; run++) + { + //cout << "Run: " << run+1 << endl; + { + vector vec(pNTRU.n, 0L); + Sampler::get_ternary_vector(vec); + + assert(vec.size() == pNTRU.n); + for (int i = 0; i < pNTRU.n; i++) + { + assert((vec[i]==0) || (vec[i]==-1) || (vec[i]==1) ); + } + } + + { + vector vec(N,0L); + Sampler::get_ternary_vector(vec); + + assert(vec.size() == N); + for (int i = 0; i < N; i++) + { + assert((vec[i]==0) || (vec[i]==-1) || (vec[i]==1) ); + } + } + + { + vector vec(N,0L); + Sampler::get_binary_vector(vec); + + assert(vec.size() == N); + for (int i = 0; i < N; i++) + { + assert((vec[i]==0) || (vec[i]==1) ); + } + } + + { + int n = pLWE.n; + vector> mat(n, vector(N,0L)); + Sampler::get_ternary_matrix(mat); + + assert(mat.size() == n && mat[0].size() == N); + for (int i = 0; i < n; i++) + { + vector& row = mat[i]; + for (int j = 0; j < N; j++) + assert((row[j]==0) || (row[j]==-1) || (row[j]==1) ); + } + } + + { + int n = pLWE.n; + vector vec(n, 0L); + double st_dev = 4.0; + Sampler::get_gaussian_vector(vec, st_dev); + + assert(vec.size() == n); + for (int i = 0; i < n; i++) + { + assert(conv(abs(vec[i])) < 6*st_dev); + } + } + + { + int n = pNTRU.n; + vector> mat(n, vector(N,0L)); + double st_dev = 4.0; + Sampler::get_gaussian_matrix(mat, st_dev); + + assert(mat.size() == n && mat[0].size() == N); + for (int i = 0; i < n; i++) + { + vector& row = mat[i]; + for (int j = 0; j < N; j++) + assert(conv(abs(row[j])) < 6*st_dev); + } + } + + { + vector vec_inv(N,0L); + vector vec(N,0L); + Sampler s(pNTRU); + s.get_invertible_vector(vec, vec_inv, 4, 1); + + assert(vec.size() == N && vec_inv.size() == N); + assert((vec[0]==1) || (vec[0]==-3) || (vec[0]==5) ); + for (int i = 1; i < N; i++) + { + assert((vec[i]==0) || (vec[i]==-4) || (vec[i]==4)); + } + } + + { + int n = pLWE.n; + vector> mat_inv(n, vector(n,0L)); + vector> mat(n, vector(n,0L)); + Sampler s(pLWE); + s.get_invertible_matrix(mat, mat_inv, 5, 1); + + assert(mat.size() == n && mat[0].size() == n + && mat_inv.size() == n && mat_inv[0].size() == n); + for (int i = 0; i < n; i++) + assert((mat[i][i]==1) || (mat[i][i]==-4) || (mat[i][i]==6) ); + + for (int i = 0; i < n; i++) + for (int j = 0; (j < n) && (j != i); j++) + { + assert((mat[i][j]==0) || (mat[i][j]==-5) || (mat[i][j]==5) ); + } + } + } + cout << "SAMPLER IS OK" << endl; +} + +void test_ntru_key_gen() +{ + Param param(NTRU); + int n = param.n; + int Nl = param.Nl; + int half_q_base = param.half_q_base; + int q_base = param.q_base; + int l_ksk = param.l_ksk; + int N = Param::N; + int t = Param::t; + int B_ksk = Param::B_ksk; + int B_bsk_size = Param::B_bsk_size; + int N2p1 = Param::N2p1; + + SKey_base_NTRU sk_base; + KeyGen k(param); + k.get_sk_base(sk_base); + cout << "Secret key of the base scheme is generated" << endl; + assert(sk_base.sk.size() == n && sk_base.sk[0].size() == n + && sk_base.sk_inv.size() == n && sk_base.sk_inv[0].size() == n); + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + { + assert((sk_base.sk[i][j]==0) || (sk_base.sk[i][j]==-1) || (sk_base.sk[i][j]==1) ); + } + + SKey_boot sk_boot; + k.get_sk_boot(sk_boot); + cout << "Secret key of the bootstrapping scheme is generated" << endl; + assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N); + assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1))); + for (int i = 1; i < N; i++) + { + assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) ); + } + + KSKey_NTRU ksk; + k.get_ksk(ksk, sk_base, sk_boot); + cout << "Key-switching key is generated" << endl; + assert(ksk.size() == Nl && ksk[0].size() == n); + for (int i = 0; i < Nl; i++) + for (int j = 0; j < n; j++) + { + //cout << ksk[i][j] << endl; + assert(ksk[i][j] <= half_q_base && ksk[i][j] >= -half_q_base); + } + + vector q4_decomp; + decompose(q4_decomp, q_base/4, B_ksk, l_ksk); + vector ks_res(n,0L); + for (int i = 0; i < l_ksk; i++) + { + int tmp_int = q4_decomp[i]; + vector& ksk_row = ksk[i]; + for (int j = 0; j < n; j++) + { + ks_res[j] += ksk_row[j] * tmp_int; + } + } + param.mod_q_base(ks_res); + int ks_int = 0; + for (int i = 0; i < n; i++) + { + ks_int += ks_res[i] * sk_base.sk[i][0]; + } + ks_int = param.mod_q_base(ks_int); + ks_int = int(round(double(ks_int*4)/double(q_base))); + assert(ks_int == 1L); + + // bootstrapping key test + BSKey_NTRU bsk; + k.get_bsk(bsk, sk_base, sk_boot); + cout << "Bootstrapping key is generated" << endl; + + // check dimensions + assert(bsk.size() == B_bsk_size); + for (int i = 0; i < bsk.size(); i++) + { + assert(bsk[i].size() == param.bsk_partition[i]); + for (int j = 0; j < bsk[i].size(); j++) + { + assert(bsk[i][j].size() == 2); + assert(bsk[i][j][0].size() == param.l_bsk[i]); + assert(bsk[i][j][1].size() == param.l_bsk[i]); + } + } + // convert sk_boot to FFT + vector> sk_boot_fft(N2p1); + fftN.to_fft(sk_boot_fft, sk_boot.sk); + + int coef_counter = 0; + for (int iBase = 0; iBase < B_bsk_size; iBase++) + { + decompose(q4_decomp, q_boot/4, param.B_bsk[iBase], param.l_bsk[iBase]); + for (size_t iCoef = 0; iCoef < bsk[iBase].size(); iCoef++) + { + int sk_coef = 0; + int sk_base_coef_bits[2]; + for (int iBit = 0; iBit < 2; iBit++) + { + vector> tmp_fft(N2p1, complex(0.0,0.0)); + for (int iPart = 0; iPart < param.l_bsk[iBase]; iPart++) + { + tmp_fft = tmp_fft + bsk[iBase][iCoef][iBit][iPart] * q4_decomp[iPart]; + } + tmp_fft = tmp_fft * sk_boot_fft; + vector tmp_int; + vector tmp_long; + fftN.from_fft(tmp_long, tmp_fft); + mod_q_boot(tmp_int, tmp_long); + sk_base_coef_bits[iBit] = int(round(double(tmp_int[0]*4)/double(q_boot))); + } + if (sk_base_coef_bits[1] == 1) + sk_coef = -1; + else if (sk_base_coef_bits[0] == 1) + sk_coef = 1; + + assert(sk_coef == sk_base.sk[coef_counter + iCoef][0]); + } + coef_counter += param.bsk_partition[iBase]; + } + + cout << "KEYGEN IS OK" << endl; +} + +void test_lwe_key_gen() +{ + Param param(LWE); + int n = param.n; + int N = Param::N; + int t = Param::t; + + SKey_base_LWE sk_base; + KeyGen k(param); + k.get_sk_base(sk_base); + cout << "Secret key of the base scheme is generated" << endl; + assert(sk_base.size() == n); + for (int j = 0; j < n; j++) + { + assert((sk_base[j]==0) || (sk_base[j]==1)); + } + + SKey_boot sk_boot; + k.get_sk_boot(sk_boot); + cout << "Secret key of the bootstrapping scheme is generated" << endl; + assert(sk_boot.sk.size() == N && sk_boot.sk_inv.size() == N); + assert((sk_boot.sk[0]==1) || (sk_boot.sk[0]==(-t+1)) || (sk_boot.sk[0]==(t+1))); + for (int i = 1; i < N; i++) + { + assert((sk_boot.sk[i]==0) || (sk_boot.sk[i]==-t) || (sk_boot.sk[i]==t) ); + } + + cout << "KEYGEN IS OK" << endl; +} + +void test_fft() +{ + int N = Param::N; + int N2p1 = Param::N2p1; + + FFT_engine fft_engine(N); + { + vector in(N,0L); + vector> out(N2p1); + clock_t start = clock(); + fft_engine.to_fft(out, in); + cout << "Forward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl; + for (size_t i = 0; i < N/2; i++) + { + if (int(round(real(out[i])))!=0 || int(round(imag(out[i])))!=0) + { + cout << i << " " << out[i] << endl; + assert(false); + } + } + } + + { + vector out; + vector> in(N2p1, complex(0.0,0.0)); + clock_t start = clock(); + fft_engine.from_fft(out, in); + cout << "Backward FFT (zero): " << float(clock()-start)/CLOCKS_PER_SEC << endl; + for (size_t i = 0; i < N/2; i++) + { + assert(out[i] == 0L); + } + } + + { + vector in(N,0L); + in[0] = 1L; + vector> out(N2p1); + clock_t start = clock(); + fft_engine.to_fft(out, in); + cout << "Forward FFT (1,0,...0): " << float(clock()-start)/CLOCKS_PER_SEC << endl; + for (size_t i = 0; i < N/2; i++) + { + assert(int(round(real(out[i])))==1 && int(round(imag(out[i])))==0); + } + } + + { + vector out; + vector> in(N2p1, complex(1.0,0.0)); + clock_t start = clock(); + fft_engine.from_fft(out, in); + cout << "Backward FFT (1,1,...1): " << float(clock()-start)/CLOCKS_PER_SEC << endl; + assert(out[0] == 1L); + for (size_t i = 1; i < N; i++) + { + assert(out[i] == 0L); + } + } + + { + uniform_int_distribution sampler(INT_MIN, INT_MAX); + int coef = sampler(rand_engine); + vector out; + vector> in(N2p1, complex(double(coef),0.0)); + clock_t start = clock(); + fft_engine.from_fft(out, in); + cout << "Backward FFT (a,a,...a): " << float(clock()-start)/CLOCKS_PER_SEC << endl; + assert(out[0] == coef); + for (size_t i = 1; i < N; i++) + { + assert(out[i] == 0L); + } + } + + { + uniform_int_distribution sampler(INT_MIN, INT_MAX); + vector in; + for (int i = 0; i < N; i++) + in.push_back(sampler(rand_engine)); + vector> interm(N2p1); + vector out; + clock_t start = clock(); + fft_engine.to_fft(interm, in); + cout << "Forward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl; + start = clock(); + fft_engine.from_fft(out, interm); + cout << "Backward FFT (random): " << float(clock()-start)/CLOCKS_PER_SEC << endl; + for (size_t i = 0; i < N; i++) + { + //cout << "i: " << i << "in[i]: " << in[i] << " out[i]: " << out[i] << endl; + assert(in[i] == out[i]); + } + } + + { + uniform_int_distribution sampler(-100, 100); + vector in1, in2, res; + for (int i = 0; i < N; i++) + { + in1.push_back(sampler(rand_engine)); + in2.push_back(sampler(rand_engine)); + res.push_back(in1[i]+in2[i]); + } + vector> interm1(N2p1); + vector> interm2(N2p1); + vector> intermres(N2p1); + vector out; + + fft_engine.to_fft(interm1, in1); + fft_engine.to_fft(interm2, in2); + + clock_t start = clock(); + intermres = interm1 + interm2; + cout << "FFT addition: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + fft_engine.from_fft(out, intermres); + + for (size_t i = 0; i < N; i++) + { + //cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl; + assert(res[i] == out[i]); + } + } + + { + uniform_int_distribution sampler(-100, 100); + vector in1, in2, res; + for (int i = 0; i < N; i++) + { + in1.push_back(sampler(rand_engine)); + in2.push_back(sampler(rand_engine)); + } + ZZX poly1, poly2, poly_res; + for (int i = 0; i < N; i++) + { + SetCoeff(poly1, i, in1[i]); + SetCoeff(poly2, i, in2[i]); + } + ZZX poly_mod; + SetCoeff(poly_mod, 0, 1); + SetCoeff(poly_mod, N, 1); + + clock_t start = clock(); + MulMod(poly_res, poly1, poly2, poly_mod); + cout << "NTL multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + for (int i = 0; i < N; i++) + { + res.push_back(conv(poly_res[i])); + } + + vector> interm1(N2p1); + vector> interm2(N2p1); + vector> intermres(N2p1); + vector out; + + fft_engine.to_fft(interm1, in1); + fft_engine.to_fft(interm2, in2); + + start = clock(); + intermres = interm1 * interm2; + cout << "FFT multiplication: " << float(clock()-start)/CLOCKS_PER_SEC << endl; + + fft_engine.from_fft(out, intermres); + + for (size_t i = 0; i < N; i++) + { + //cout << "i: " << i << " in1[i]: " << in1[i] << " in2[i]: " << in2[i] << " res[i]: " << res[i] << " out[i]: " << out[i] << endl; + assert(res[i] == out[i]); + } + } + + cout << "FFT is OK" << endl; +} + +void test_ntruhe_encrypt() +{ + SchemeNTRU s; + + { + int input = 0; + Ctxt_NTRU ct; + s.encrypt(ct, input); + int output = s.decrypt(ct); + assert(output == input); + } + { + int input = 1; + Ctxt_NTRU ct; + s.encrypt(ct, input); + int output = s.decrypt(ct); + assert(output == input); + } + cout << "NTRU ENCRYPTION IS OK" << endl; +} + +void test_lwehe_encrypt() +{ + SchemeLWE s; + + { + int input = 0; + Ctxt_LWE ct; + s.encrypt(ct, input); + int output = s.decrypt(ct); + assert(output == input); + } + { + int input = 1; + Ctxt_LWE ct; + s.encrypt(ct, input); + int output = s.decrypt(ct); + assert(output == input); + } + cout << "LWE ENCRYPTION IS OK" << endl; +} + +/* +void test_mod_switch() +{ + SchemeNTRU s; + { + int input = 0; + Ctxt_NTRU ct; + s.encrypt(ct, input); + s.modulo_switch_to_base(ct.data); + int output = 0; + for (int i = 0; i < ntru_he::n; i++) + { + output += ct.data[i] * sk_base.sk[i][0]; + } + output = output%Param::N2; + if (output > Param::N) + output -= Param::N2; + else if (output <= -Param::N) + output += Param::N2; + output = int(round(double(output*t)/double(Param::N2))); + assert(output == input); + } + { + int input = 1; + ntru_he::Ctxt ct; + ntru_he::encrypt(ct, input, sk_base); + ntru_he::modulo_switch(ct, ntru_he::q_base, Param::N2); + int output = 0; + for (int i = 0; i < ntru_he::n; i++) + { + output += ct[i] * sk_base.sk[i][0]; + } + output = output%Param::N2; + if (output > Param::N) + output -= Param::N2; + else if (output <= -Param::N) + output += Param::N2; + output = int(round(double(output*t)/double(Param::N2))); + assert(output == input); + } + { + int input = 0; + ntru_he::Ctxt ct; + ntru_he::encrypt(ct, input, sk_base); + ntru_he::modulo_switch_to_boot(ct); + int output = 0; + for (int i = 0; i < ntru_he::n; i++) + { + output += ct[i] * sk_base.sk[i][0]; + } + output = output%Param::N2; + if (output > Param::N) + output -= Param::N2; + else if (output <= -Param::N) + output += Param::N2; + output = int(round(double(output*t)/double(Param::N2))); + assert(output == input); + } + { + int input = 1; + ntru_he::Ctxt ct; + ntru_he::encrypt(ct, input, sk_base); + ntru_he::modulo_switch_to_boot(ct); + int output = 0; + for (int i = 0; i < ntru_he::n; i++) + { + output += ct[i] * sk_base.sk[i][0]; + } + output = output%Param::N2; + if (output > Param::N) + output -= Param::N2; + else if (output <= -Param::N) + output += Param::N2; + output = int(round(double(output*t)/double(Param::N2))); + assert(output == input); + } + cout << "MODULO SWITCHING IS OK" << endl; +} +*/ + +void test_bootstrap() +{ + SchemeNTRU s; + { + int input = 2; + Ctxt_NTRU ct; + s.encrypt(ct, input); + + s.bootstrap(ct); + + int output = s.decrypt(ct); + cout << "Bootstrapping output: " << output << endl; + assert(output == 1L); + } + { + int input = 0; + Ctxt_NTRU ct; + s.encrypt(ct, input); + + s.bootstrap(ct); + + int output = s.decrypt(ct); + cout << "Bootstrapping output: " << output << endl; + assert(output == 0L); + } + + cout << "BOOTSTRAPPING IS OK" << endl; +} + +/* +void test_nand_aux() +{ + ntru_he::SKey_base sk_base; + ntru_he::get_sk_base(sk_base); + + ntru_he::Ctxt ct; + ntru_he::get_nand_aux(ct, sk_base); + int output = 0; + for (int i = 0; i < ntru_he::n; i++) + { + output += ct[i] * sk_base.sk[i][0]; + } + output = ntru_he::mod_q_base(output); + assert( + output == (ntru_he::nand_const-ntru_he::q_base) + || output == (ntru_he::nand_const-ntru_he::q_base+1) + || output == (ntru_he::nand_const-ntru_he::q_base-1) + ); + cout << "NAND ENCRYPTION IS OK" << endl; +}*/ + +enum GateType {NAND, AND, OR}; + +void test_ntruhe_gate_helper(int in1, int in2, const SchemeNTRU& s, GateType g) +{ + float avg_time = 0.0; + for (int i = 0; i < 100; i++) + { + Ctxt_NTRU ct_res, ct1, ct2, ct_nand; + s.encrypt(ct1, in1); + s.encrypt(ct2, in2); + + if (g == NAND) + { + auto start = clock(); + s.nand_gate(ct_res, ct1, ct2); + avg_time += float(clock()-start)/CLOCKS_PER_SEC; + + int output = s.decrypt(ct_res); + + //cout << "NAND output: " << output << endl; + assert(output == !(in1 & in2)); + } + else if (g == AND) { + auto start = clock(); + s.and_gate(ct_res, ct1, ct2); + avg_time += float(clock()-start)/CLOCKS_PER_SEC; + + int output = s.decrypt(ct_res); + + //cout << "AND output: " << output << endl; + assert(output == (in1 & in2)); + } + else if (g == OR) { + auto start = clock(); + s.or_gate(ct_res, ct1, ct2); + avg_time += float(clock()-start)/CLOCKS_PER_SEC; + + int output = s.decrypt(ct_res); + + //cout << "OR output: " << output << endl; + assert(output == (in1 | in2)); + } + } + cout << "Avg. time" << avg_time/100.0 << endl; +} + +void test_ntru_gate(GateType g) +{ + SchemeNTRU s; + + test_ntruhe_gate_helper(0, 0, s, g); + test_ntruhe_gate_helper(0, 1, s, g); + test_ntruhe_gate_helper(1, 0, s, g); + test_ntruhe_gate_helper(1, 1, s, g); +} + +void test_ntruhe_nand() +{ + GateType g = NAND; + + test_ntru_gate(g); + + cout << "NAND IS OK" << endl; +} + +void test_ntruhe_and() +{ + GateType g = AND; + + test_ntru_gate(g); + + cout << "AND IS OK" << endl; +} + +void test_ntruhe_or() +{ + GateType g = OR; + + test_ntru_gate(g); + + cout << "OR IS OK" << endl; +} + +void test_lwehe_gate_helper(int in1, int in2, SchemeLWE& s, GateType g) +{ + float avg_time = 0.0; + for (int i = 0; i < 100; i++) + { + Ctxt_LWE ct_res, ct1, ct2, ct_nand; + s.encrypt(ct1, in1); + s.encrypt(ct2, in2); + + if (g == NAND) + { + auto start = clock(); + s.nand_gate(ct_res, ct1, ct2); + avg_time += float(clock()-start)/CLOCKS_PER_SEC; + + int output = s.decrypt(ct_res); + + //cout << "NAND output: " << output << endl; + assert(output == !(in1 & in2)); + } + else if (g == AND) { + auto start = clock(); + s.and_gate(ct_res, ct1, ct2); + avg_time += float(clock()-start)/CLOCKS_PER_SEC; + + int output = s.decrypt(ct_res); + + //cout << "AND output: " << output << endl; + assert(output == (in1 & in2)); + } + else if (g == OR) { + auto start = clock(); + s.or_gate(ct_res, ct1, ct2); + avg_time += float(clock()-start)/CLOCKS_PER_SEC; + + int output = s.decrypt(ct_res); + + //cout << "OR output: " << output << endl; + assert(output == (in1 | in2)); + } + } + cout << "Avg. time" << avg_time/100.0 << endl; +} + +void test_lwe_gate(GateType g) +{ + SchemeLWE s; + + test_lwehe_gate_helper(0, 0, s, g); + test_lwehe_gate_helper(0, 1, s, g); + test_lwehe_gate_helper(1, 0, s, g); + test_lwehe_gate_helper(1, 1, s, g); +} + +void test_lwehe_nand() +{ + GateType g = NAND; + + test_lwe_gate(g); + + cout << "NAND IS OK" << endl; +} + +void test_lwehe_and() +{ + GateType g = AND; + + test_lwe_gate(g); + + cout << "AND IS OK" << endl; +} + +void test_lwehe_or() +{ + GateType g = OR; + + test_lwe_gate(g); + + cout << "OR IS OK" << endl; +} + +int main() +{ + //test_params(); + //test_sampler(); + //test_ntru_key_gen(); + //test_lwe_key_gen(); + //test_fft(); + //test_ntruhe_encrypt(); + //test_lwehe_encrypt(); + //test_mod_switch(); + //test_bootstrap(); + //test_nand_aux(); + //test_ntruhe_nand(); + test_lwehe_nand(); + //test_ntruhe_and(); + test_lwehe_and(); + //test_ntruhe_or(); + test_lwehe_or(); + return 0; +} \ No newline at end of file