Commit fea9d6a1 authored by Amr Aboelkher's avatar Amr Aboelkher Committed by Commit Bot

Roll shell-encryption e2a4af88a0e..4e0598f826

This CL is doing the following:
- Update the BUILD.gn correspondingly with the latest changes
- Adding new patch for using absl optional instead of std
- Update the Support-SHELL-in-chromium patch

Bug: 1072286
Change-Id: Id3c4bc2da5183f92345440eb887f5391f8d7b2e8
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2153443Reviewed-by: default avatarNico Weber <thakis@chromium.org>
Reviewed-by: default avatarAmr Aboelkher <amraboelkher@google.com>
Commit-Queue: Amr Aboelkher <amraboelkher@google.com>
Cr-Commit-Position: refs/heads/master@{#760979}
parent a4d01f33
...@@ -46,6 +46,7 @@ source_set("shell_encryption") { ...@@ -46,6 +46,7 @@ source_set("shell_encryption") {
"glog/logging.h", "glog/logging.h",
"src/bits_util.h", "src/bits_util.h",
"src/constants.h", "src/constants.h",
"src/context.h",
"src/error_params.h", "src/error_params.h",
"src/galois_key.h", "src/galois_key.h",
"src/int256.h", "src/int256.h",
...@@ -68,6 +69,7 @@ source_set("shell_encryption") { ...@@ -68,6 +69,7 @@ source_set("shell_encryption") {
] ]
sources = [ sources = [
"src/int256.cc", "src/int256.cc",
"src/montgomery.cc",
"src/ntt_parameters.cc", "src/ntt_parameters.cc",
"src/prng/chacha_prng.cc", "src/prng/chacha_prng.cc",
"src/prng/chacha_prng_util.cc", "src/prng/chacha_prng_util.cc",
...@@ -96,6 +98,7 @@ test("shell_encryption_test") { ...@@ -96,6 +98,7 @@ test("shell_encryption_test") {
public = [ public = [
"src/testing/coefficient_polynomial.h", "src/testing/coefficient_polynomial.h",
"src/testing/coefficient_polynomial_ciphertext.h", "src/testing/coefficient_polynomial_ciphertext.h",
"src/testing/parameter.h",
"src/testing/protobuf_matchers.h", "src/testing/protobuf_matchers.h",
"src/testing/status_matchers.h", "src/testing/status_matchers.h",
"src/testing/status_testing.h", "src/testing/status_testing.h",
...@@ -104,6 +107,7 @@ test("shell_encryption_test") { ...@@ -104,6 +107,7 @@ test("shell_encryption_test") {
] ]
sources = [ sources = [
"src/bits_util_test.cc", "src/bits_util_test.cc",
"src/context_test.cc",
"src/error_params_test.cc", "src/error_params_test.cc",
"src/galois_key_test.cc", "src/galois_key_test.cc",
"src/int256_test.cc", "src/int256_test.cc",
......
diff --git a/montgomery.h b/montgomery.h diff --git a/montgomery.cc b/montgomery.cc
index 50fb08a..336e925 100644 index d5221f4..9fbc5da 100644
--- a/montgomery.h --- a/montgomery.cc
+++ b/montgomery.h +++ b/montgomery.cc
@@ -105,8 +105,8 @@ struct MontgomeryIntParams { @@ -22,8 +22,8 @@ template <typename T>
static rlwe::StatusOr<std::unique_ptr<MontgomeryIntParams>> Create( rlwe::StatusOr<std::unique_ptr<const MontgomeryIntParams<T>>>
Int modulus) { MontgomeryIntParams<T>::Create(Int modulus) {
// Check that the modulus is smaller than max(Int) / 4. // Check that the modulus is smaller than max(Int) / 4.
- if (Int most_significant_bit = modulus >> (bitsize_int - 2); - if (Int most_significant_bit = modulus >> (bitsize_int - 2);
- most_significant_bit != 0) { - most_significant_bit != 0) {
+ Int most_significant_bit = modulus >> (bitsize_int - 2); + Int most_significant_bit = modulus >> (bitsize_int - 2);
+ if (most_significant_bit != 0) { + if (most_significant_bit != 0) {
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"The modulus should be less than 2^", (bitsize_int - 2), ".")); "The modulus should be less than 2^", (bitsize_int - 2), "."));
} }
diff --git a/ntt_parameters.h b/ntt_parameters.h diff --git a/ntt_parameters.h b/ntt_parameters.h
index da03bfe..270a5a2 100644 index 56e1871..c3da197 100644
--- a/ntt_parameters.h --- a/ntt_parameters.h
+++ b/ntt_parameters.h +++ b/ntt_parameters.h
@@ -103,7 +103,8 @@ static void BitrevHelper(const std::vector<unsigned int>& bitrevs, @@ -103,7 +103,8 @@ static void BitrevHelper(const std::vector<unsigned int>& bitrevs,
...@@ -28,12 +28,12 @@ index da03bfe..270a5a2 100644 ...@@ -28,12 +28,12 @@ index da03bfe..270a5a2 100644
} }
} }
diff --git a/polynomial.h b/polynomial.h diff --git a/polynomial.h b/polynomial.h
index 0b73dfa..fc0e38a 100644 index 07843b2..3cf0c77 100644
--- a/polynomial.h --- a/polynomial.h
+++ b/polynomial.h +++ b/polynomial.h
@@ -80,7 +80,8 @@ class Polynomial { @@ -80,7 +80,8 @@ class Polynomial {
const NttParameters<ModularInt>& ntt_params, const NttParameters<ModularInt>* ntt_params,
ModularIntParams* modular_params) { const ModularIntParams* modular_params) {
// Check to ensure that the coefficient vector is of the correct length. // Check to ensure that the coefficient vector is of the correct length.
- if (int len = poly_coeffs.size(); len <= 0 || (len & (len - 1)) != 0) { - if (int len = poly_coeffs.size(); len <= 0 || (len & (len - 1)) != 0) {
+ int len = poly_coeffs.size(); + int len = poly_coeffs.size();
...@@ -42,10 +42,10 @@ index 0b73dfa..fc0e38a 100644 ...@@ -42,10 +42,10 @@ index 0b73dfa..fc0e38a 100644
return Polynomial(); return Polynomial();
} }
diff --git a/prng/chacha_prng_util.cc b/prng/chacha_prng_util.cc diff --git a/prng/chacha_prng_util.cc b/prng/chacha_prng_util.cc
index bd78ab8..39099e1 100644 index dfab1d9..c49c82d 100644
--- a/prng/chacha_prng_util.cc --- a/prng/chacha_prng_util.cc
+++ b/prng/chacha_prng_util.cc +++ b/prng/chacha_prng_util.cc
@@ -10,7 +10,8 @@ @@ -24,7 +24,8 @@
#include <openssl/rand.h> #include <openssl/rand.h>
#include "status_macros.h" #include "status_macros.h"
...@@ -55,7 +55,7 @@ index bd78ab8..39099e1 100644 ...@@ -55,7 +55,7 @@ index bd78ab8..39099e1 100644
absl::Status ChaChaPrngResalt(absl::string_view key, int buffer_size, absl::Status ChaChaPrngResalt(absl::string_view key, int buffer_size,
int* salt_counter, int* position_in_buffer, int* salt_counter, int* position_in_buffer,
@@ -71,4 +72,5 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key, @@ -85,4 +86,5 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key,
return rand64; return rand64;
} }
...@@ -63,10 +63,10 @@ index bd78ab8..39099e1 100644 ...@@ -63,10 +63,10 @@ index bd78ab8..39099e1 100644
+} // namespace internal +} // namespace internal
+} // namespace rlwe +} // namespace rlwe
diff --git a/prng/chacha_prng_util.h b/prng/chacha_prng_util.h diff --git a/prng/chacha_prng_util.h b/prng/chacha_prng_util.h
index 2bb2fcf..5490505 100644 index 32cac5b..8eb8118 100644
--- a/prng/chacha_prng_util.h --- a/prng/chacha_prng_util.h
+++ b/prng/chacha_prng_util.h +++ b/prng/chacha_prng_util.h
@@ -12,7 +12,8 @@ @@ -28,7 +28,8 @@
#include "integral_types.h" #include "integral_types.h"
#include "statusor.h" #include "statusor.h"
...@@ -76,7 +76,7 @@ index 2bb2fcf..5490505 100644 ...@@ -76,7 +76,7 @@ index 2bb2fcf..5490505 100644
const int kChaChaKeyBytesSize = 32; const int kChaChaKeyBytesSize = 32;
const int kChaChaNonceSize = 12; const int kChaChaNonceSize = 12;
@@ -43,6 +44,7 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key, @@ -59,6 +60,7 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key,
int* salt_counter, int* salt_counter,
std::vector<Uint8>* buffer); std::vector<Uint8>* buffer);
...@@ -86,10 +86,10 @@ index 2bb2fcf..5490505 100644 ...@@ -86,10 +86,10 @@ index 2bb2fcf..5490505 100644
#endif // RLWE_CHACHA_PRNG_UTIL_H_ #endif // RLWE_CHACHA_PRNG_UTIL_H_
diff --git a/statusor.h b/statusor.h diff --git a/statusor.h b/statusor.h
index 4fdeade..42761e6 100644 index d8addb5..200f62d 100644
--- a/statusor.h --- a/statusor.h
+++ b/statusor.h +++ b/statusor.h
@@ -74,7 +74,7 @@ class StatusOr { @@ -96,7 +96,7 @@ class StatusOr {
operator absl::Status() const { return status(); } operator absl::Status() const { return status(); }
...@@ -99,11 +99,11 @@ index 4fdeade..42761e6 100644 ...@@ -99,11 +99,11 @@ index 4fdeade..42761e6 100644
if (value_) { if (value_) {
return OtherStatusOrType<T>(std::move(value_.value())); return OtherStatusOrType<T>(std::move(value_.value()));
diff --git a/symmetric_encryption.h b/symmetric_encryption.h diff --git a/symmetric_encryption.h b/symmetric_encryption.h
index d4ad730..0223149 100644 index e120b18..987e86f 100644
--- a/symmetric_encryption.h --- a/symmetric_encryption.h
+++ b/symmetric_encryption.h +++ b/symmetric_encryption.h
@@ -584,8 +584,8 @@ class SymmetricRlweKey { @@ -571,8 +571,8 @@ class SymmetricRlweKey {
typename ModularIntQ::Params* modulus_params_q, const typename ModularIntQ::Params* modulus_params_q,
const NttParameters<ModularIntQ>* ntt_params_q) const { const NttParameters<ModularIntQ>* ntt_params_q) const {
// Configuration failure. // Configuration failure.
- if (Int t = (modulus_params_q->One() << log_t_) + modulus_params_q->One(); - if (Int t = (modulus_params_q->One() << log_t_) + modulus_params_q->One();
......
diff --git a/ntt_parameters.h b/ntt_parameters.h
index c3da197..55671ec 100644
--- a/ntt_parameters.h
+++ b/ntt_parameters.h
@@ -168,7 +168,7 @@ struct NttParameters {
~NttParameters() = default;
int number_coeffs;
- std::optional<ModularInt> n_inv_ptr;
+ absl::optional<ModularInt> n_inv_ptr;
std::vector<ModularInt> psis_bitrev;
std::vector<ModularInt> psis_inv_bitrev;
std::vector<unsigned int> bitrevs;
This diff is collapsed.
...@@ -8,10 +8,6 @@ both add and multiply encrypted data. It uses modulus-switching to enable ...@@ -8,10 +8,6 @@ both add and multiply encrypted data. It uses modulus-switching to enable
arbitrary-depth homomorphic encryption (provided sufficiently large parameters arbitrary-depth homomorphic encryption (provided sufficiently large parameters
are set). RLWE is also believed to be secure in the face of quantum computers. are set). RLWE is also believed to be secure in the face of quantum computers.
This library is designed to be compact and readable. The library includes just
1500 lines of heavily-commented source (700 lines of source, 600 lines of
comments and 200 blank lines) and 900 lines of unit tests.
We intend this project to be both a useful experimental library and a We intend this project to be both a useful experimental library and a
comprehensible tool for learning about and extending RLWE. comprehensible tool for learning about and extending RLWE.
...@@ -42,8 +38,7 @@ above, Alice can securely offload computation to another entity without worrying ...@@ -42,8 +38,7 @@ above, Alice can securely offload computation to another entity without worrying
that doing so will reveal any of her private information. Among many other that doing so will reveal any of her private information. Among many other
applications, it enables *private information retrieval* (PIR) - databases that applications, it enables *private information retrieval* (PIR) - databases that
can serve user requests without learning which pieces of data the users can serve user requests without learning which pieces of data the users
requested. (For more information on PIR, see [XPIR: Private Information requested. (For more information on PIR, see [Communication--Computation Trade-offs in PIR](https://eprint.iacr.org/2019/1483.pdf).)
Retrieval for Everyone](https://eprint.iacr.org/2014/1025.pdf).)
## Ring Learning with Errors ## Ring Learning with Errors
...@@ -56,7 +51,7 @@ hardness. ...@@ -56,7 +51,7 @@ hardness.
The cryptosystem implemented in this library is from [Fully Homomorphic The cryptosystem implemented in this library is from [Fully Homomorphic
Encryption from Ring-LWE and Security for Key Dependent Encryption from Ring-LWE and Security for Key Dependent
Messages](http://www.wisdom.weizmann.ac.il/~zvikab/localpapers/IdealHom.pdf). Messages](http://www.wisdom.weizmann.ac.il/~zvikab/localpapers/IdealHom.pdf).
The cryptosystem works as follows: The cryptosystem works as follows.
### Preliminaries ### Preliminaries
...@@ -72,7 +67,7 @@ We also need a modulus *t* that is much smaller than *q*. *log(t)* is the number ...@@ -72,7 +67,7 @@ We also need a modulus *t* that is much smaller than *q*. *log(t)* is the number
of bits of plaintext information we are able to fit into each coefficient of a of bits of plaintext information we are able to fit into each coefficient of a
ciphertext polynomial. The importance of *t* will become apparent soon. ciphertext polynomial. The importance of *t* will become apparent soon.
Finally, we need two other components: a Gaussian distribution *Y* with mean 0 Finally, we need two other components: a binomial distribution *Y* with mean 0
and standard deviation *w*, where *w* is a parameter of the cryptosystem. The and standard deviation *w*, where *w* is a parameter of the cryptosystem. The
importance of this distribution will become apparent soon. importance of this distribution will become apparent soon.
...@@ -230,10 +225,10 @@ This library consists of four major components that form a RLWE stack. ...@@ -230,10 +225,10 @@ This library consists of four major components that form a RLWE stack.
### Montgomery Integers ### Montgomery Integers
This library is implemented in `montgomery.h`. At the lowest level is a library This library is implemented in `montgomery.(cch|h)`. At the lowest level is a
that represents modular integers in Montgomery form, which speeds up the library that represents modular integers in Montgomery form, which speeds up the
repeated use of the modulo operator. This library supports 64-bit integers, repeated use of the modulo operator. This library supports 128-bit integers,
meaning that it can support a modulus of up to 30-bits. For larger modulus meaning that it can support a modulus of up to 126 bits. For larger modulus
sizes, the higher levels of the stack can be parameterized with a different type sizes, the higher levels of the stack can be parameterized with a different type
(such as a BigInteger). Montgomery integers require several parameters in (such as a BigInteger). Montgomery integers require several parameters in
addition to the modulus to perform the modular operations efficiently. These addition to the modulus to perform the modular operations efficiently. These
...@@ -306,10 +301,10 @@ cd shell-encryption ...@@ -306,10 +301,10 @@ cd shell-encryption
bazel build :all --cxxopt='-std=c++17' bazel build :all --cxxopt='-std=c++17'
``` ```
You may also run all tests using the following command: You may also run all tests (recursively) using the following command:
```bash ```bash
bazel test :all --cxxopt='-std=c++17' bazel test ... --cxxopt='-std=c++17'
``` ```
If you get an error, you may need to build/test with the following flags: If you get an error, you may need to build/test with the following flags:
......
...@@ -12,63 +12,62 @@ http_archive( ...@@ -12,63 +12,62 @@ http_archive(
], ],
) )
load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")
rules_cc_dependencies()
# rules_proto defines abstract rules for building Protocol Buffers. # rules_proto defines abstract rules for building Protocol Buffers.
# https://github.com/bazelbuild/rules_proto
http_archive( http_archive(
name = "rules_proto", name = "rules_proto",
sha256 = "57001a3b33ec690a175cdf0698243431ef27233017b9bed23f96d44b9c98242f", sha256 = "602e7161d9195e50246177e7c55b2f39950a9cf7366f74ed5f22fd45750cd208",
strip_prefix = "rules_proto-9cd4f8f1ede19d81c6d48910429fe96776e567b1", strip_prefix = "rules_proto-97d8af4dc474595af3900dd85cb3a29ad28cc313",
urls = [ urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/9cd4f8f1ede19d81c6d48910429fe96776e567b1.tar.gz", "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
"https://github.com/bazelbuild/rules_proto/archive/9cd4f8f1ede19d81c6d48910429fe96776e567b1.tar.gz", "https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
], ],
) )
load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")
rules_cc_dependencies()
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains") load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
rules_proto_dependencies() rules_proto_dependencies()
rules_proto_toolchains() rules_proto_toolchains()
# Install gtest. # Install gtest.
http_archive(
name = "com_google_googletest",
urls = [
"https://github.com/google/googletest/archive/release-1.10.0.zip",
],
sha256 = "94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91",
strip_prefix = "googletest-release-1.10.0",
)
# abseil-cpp
git_repository( git_repository(
name = "com_google_absl", name = "com_github_google_googletest",
remote = "https://github.com/abseil/abseil-cpp.git", commit = "703bd9caab50b139428cea1aaff9974ebee5742e", # tag = "release-1.10.0"
commit = "0d5ce2797eb695aee7e019e25323251ef6ffc277", remote = "https://github.com/google/googletest.git",
shallow_since = "1570114335 -0400",
) )
# BoringSSL # abseil-cpp
http_archive( http_archive(
name = "boringssl", name = "com_google_absl",
sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353",
strip_prefix = "abseil-cpp-20200225",
urls = [ urls = [
"https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", "https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz",
"https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz",
], ],
sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3", )
strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778",
# BoringSSL
git_repository(
name = "boringssl",
commit = "67ffb9606462a1897d3a5edf5c06d329878ba600", # https://boringssl.googlesource.com/boringssl/+/refs/heads/master-with-bazel
remote = "https://boringssl.googlesource.com/boringssl",
shallow_since = "1585767053 +0000"
) )
# Logging # Logging
http_archive( http_archive(
name = "com_github_glog_glog", name = "com_github_google_glog",
build_file = "@//glog:BUILD", urls = ["https://github.com/google/glog/archive/96a2f23dca4cc7180821ca5f32e526314395d26a.zip"],
urls = ["https://github.com/google/glog/archive/v0.3.5.tar.gz"], strip_prefix = "glog-96a2f23dca4cc7180821ca5f32e526314395d26a",
strip_prefix = "glog-0.3.5", sha256 = "6281aa4eeecb9e932d7091f99872e7b26fa6aacece49c15ce5b14af2b7ec050f",
) )
# gflags, needed for glog # gflags, needed for glog
git_repository( http_archive(
name = "com_github_gflags_gflags", name = "com_github_gflags_gflags",
remote = "https://github.com/gflags/gflags.git", urls = ["https://github.com/gflags/gflags/archive/v2.2.2.tar.gz"],
tag = "v2.2.2", sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
strip_prefix = "gflags-2.2.2",
) )
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_BITS_UTIL_H_ #ifndef RLWE_BITS_UTIL_H_
#define RLWE_BITS_UTIL_H_ #define RLWE_BITS_UTIL_H_
...@@ -83,6 +99,10 @@ inline unsigned int CountLeadingZeros128(absl::uint128 x) { ...@@ -83,6 +99,10 @@ inline unsigned int CountLeadingZeros128(absl::uint128 x) {
return CountLeadingZeros64(absl::Uint128Low64(x)) + 64; return CountLeadingZeros64(absl::Uint128Low64(x)) + 64;
} }
inline unsigned int BitLength(absl::uint128 x) {
return 128 - CountLeadingZeros128(x);
}
} // namespace internal } // namespace internal
} // namespace rlwe } // namespace rlwe
......
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bits_util.h" #include "bits_util.h"
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "absl/numeric/int128.h"
using ::testing::Eq; using ::testing::Eq;
...@@ -33,4 +48,12 @@ TEST(BitsUtilTest, CountLeadingZeros64Works) { ...@@ -33,4 +48,12 @@ TEST(BitsUtilTest, CountLeadingZeros64Works) {
} }
} }
TEST(BitsUtilTest, BitLengthWorks) {
absl::uint128 value = absl::MakeUint128(0x8000000000000000, 0);
for (int i = 0; i <= 128; i++) {
EXPECT_THAT(rlwe::internal::BitLength(value), Eq(128 - i));
value >>= 1;
}
}
} // namespace } // namespace
/*
* Copyright 2020 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_CONTEXT_H_
#define RLWE_CONTEXT_H_
#include <memory>
#include "absl/memory/memory.h"
#include "error_params.h"
#include "ntt_parameters.h"
#include "status_macros.h"
#include "statusor.h"
namespace rlwe {
// Defines and holds the context of the RLWE encryption scheme.
// Thread safe..
template <typename ModularInt>
class RlweContext {
using Int = typename ModularInt::Int;
using ModulusParams = typename ModularInt::Params;
public:
// Structure to hold parameters for the RLWE encryption scheme. The parameters
// include:
// - modulus, an Int which needs to be congruent to 1 modulo 2 * (1 << log_n);
// - log_n: the logarithm of the number of coefficients of the polynomials;
// - log_t: the number of bits of the plaintext space, which will be equal to
// (1 << log_t) + 1;
// - variance, the error variance to use when sampling noises or secrets.
struct Parameters {
Int modulus;
size_t log_n;
size_t log_t;
size_t variance;
};
// Factory function to create a context from a context_params.
static rlwe::StatusOr<std::unique_ptr<const RlweContext>> Create(
Parameters context_params) {
// Create the modulus parameters.
RLWE_ASSIGN_OR_RETURN(
std::unique_ptr<const ModulusParams> modulus_parameters,
ModulusParams::Create(context_params.modulus));
// Create the NTT parameters.
RLWE_ASSIGN_OR_RETURN(NttParameters<ModularInt> ntt_params_temp,
InitializeNttParameters<ModularInt>(
context_params.log_n, modulus_parameters.get()));
std::unique_ptr<const NttParameters<ModularInt>> ntt_params =
std::make_unique<const NttParameters<ModularInt>>(
std::move(ntt_params_temp));
// Create the error parameters.
RLWE_ASSIGN_OR_RETURN(ErrorParams<ModularInt> error_params_temp,
ErrorParams<ModularInt>::Create(
context_params.log_t, context_params.variance,
modulus_parameters.get(), ntt_params.get()));
std::unique_ptr<const ErrorParams<ModularInt>> error_params =
std::make_unique<const ErrorParams<ModularInt>>(
std::move(error_params_temp));
return absl::WrapUnique<const RlweContext>(
new RlweContext(std::move(modulus_parameters), std::move(ntt_params),
std::move(error_params), std::move(context_params)));
}
// Disallow copy and copy-assign, allow move and move-assign.
RlweContext(const RlweContext&) = delete;
RlweContext& operator=(const RlweContext&) = delete;
RlweContext(RlweContext&&) = default;
RlweContext& operator=(RlweContext&&) = default;
~RlweContext() = default;
// Getters.
const ModulusParams* GetModulusParams() const {
return modulus_parameters_.get();
}
const NttParameters<ModularInt>* GetNttParams() const {
return ntt_parameters_.get();
}
const ErrorParams<ModularInt>* GetErrorParams() const {
return error_parameters_.get();
}
const Int GetModulus() const { return context_params_.modulus; }
const size_t GetLogN() const { return context_params_.log_n; }
const size_t GetN() const { return 1 << context_params_.log_n; }
const size_t GetLogT() const { return context_params_.log_t; }
const Int GetT() const {
return (static_cast<Int>(1) << context_params_.log_t) + static_cast<Int>(1);
}
const size_t GetVariance() const { return context_params_.variance; }
private:
RlweContext(std::unique_ptr<const ModulusParams> modulus_parameters,
std::unique_ptr<const NttParameters<ModularInt>> ntt_parameters,
std::unique_ptr<const ErrorParams<ModularInt>> error_parameters,
Parameters context_params)
: modulus_parameters_(std::move(modulus_parameters)),
ntt_parameters_(std::move(ntt_parameters)),
error_parameters_(std::move(error_parameters)),
context_params_(std::move(context_params)) {}
const std::unique_ptr<const ModulusParams> modulus_parameters_;
const std::unique_ptr<const NttParameters<ModularInt>> ntt_parameters_;
const std::unique_ptr<const ErrorParams<ModularInt>> error_parameters_;
const Parameters context_params_;
};
} // namespace rlwe
#endif // RLWE_CONTEXT_H_
/*
* Copyright 2020 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "context.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/numeric/int128.h"
#include "constants.h"
#include "integral_types.h"
#include "montgomery.h"
#include "status_macros.h"
#include "testing/parameters.h"
#include "testing/status_testing.h"
namespace {
template <typename ModularInt>
class ContextTest : public ::testing::Test {};
TYPED_TEST_SUITE(ContextTest, rlwe::testing::ModularIntTypes);
TYPED_TEST(ContextTest, CreateWorks) {
for (const auto& params :
rlwe::testing::ContextParameters<TypeParam>::value) {
ASSERT_OK_AND_ASSIGN(auto context,
rlwe::RlweContext<TypeParam>::Create(params));
}
}
TYPED_TEST(ContextTest, ParametersMatch) {
for (const auto& params :
rlwe::testing::ContextParameters<TypeParam>::value) {
ASSERT_OK_AND_ASSIGN(auto context,
rlwe::RlweContext<TypeParam>::Create(params));
ASSERT_EQ(context->GetLogN(), params.log_n);
ASSERT_EQ(context->GetN(), context->GetNttParams()->number_coeffs);
ASSERT_EQ(context->GetLogT(), params.log_t);
ASSERT_EQ(context->GetModulus(), params.modulus);
ASSERT_EQ(context->GetModulus(), context->GetModulusParams()->modulus);
ASSERT_EQ(context->GetVariance(), params.variance);
}
}
} // namespace
...@@ -37,7 +37,8 @@ template <typename ModularInt> ...@@ -37,7 +37,8 @@ template <typename ModularInt>
class ErrorParams { class ErrorParams {
public: public:
static rlwe::StatusOr<ErrorParams> Create( static rlwe::StatusOr<ErrorParams> Create(
const int log_t, Uint64 variance, typename ModularInt::Params* params, const int log_t, Uint64 variance,
const typename ModularInt::Params* params,
const rlwe::NttParameters<ModularInt>* ntt_params) { const rlwe::NttParameters<ModularInt>* ntt_params) {
if (log_t > params->log_modulus - 1) { if (log_t > params->log_modulus - 1) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
...@@ -74,7 +75,7 @@ class ErrorParams { ...@@ -74,7 +75,7 @@ class ErrorParams {
private: private:
// Constructor to set up the params. // Constructor to set up the params.
ErrorParams(const int log_t, Uint64 variance, ErrorParams(const int log_t, Uint64 variance,
typename ModularInt::Params* params, const typename ModularInt::Params* params,
const rlwe::NttParameters<ModularInt>* ntt_params) const rlwe::NttParameters<ModularInt>* ntt_params)
: t_(params->One()) { : t_(params->One()) {
t_ = (params->One() << log_t) + params->One(); t_ = (params->One() << log_t) + params->One();
......
...@@ -90,7 +90,7 @@ class GaloisKey { ...@@ -90,7 +90,7 @@ class GaloisKey {
// 2^{log_decomposition_modulus}. Crashes for non-valid input parameters. // 2^{log_decomposition_modulus}. Crashes for non-valid input parameters.
static rlwe::StatusOr<GaloisKey> Deserialize( static rlwe::StatusOr<GaloisKey> Deserialize(
const SerializedGaloisKey& serialized, const SerializedGaloisKey& serialized,
typename ModularInt::Params* modulus_params, const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) { const NttParameters<ModularInt>* ntt_params) {
RLWE_ASSIGN_OR_RETURN(RelinearizationKey<ModularInt> key, RLWE_ASSIGN_OR_RETURN(RelinearizationKey<ModularInt> key,
RelinearizationKey<ModularInt>::Deserialize( RelinearizationKey<ModularInt>::Deserialize(
......
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
licenses(["notice"])
exports_files(["LICENSE"])
cc_library( cc_library(
name = "glog", name = "glog",
srcs = [ srcs = [
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "glog/logging.h" #include <glog/logging.h>
#include "absl/numeric/int128.h" #include "absl/numeric/int128.h"
namespace rlwe { namespace rlwe {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include "glog/logging.h" #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "absl/container/fixed_array.h" #include "absl/container/fixed_array.h"
#include "absl/numeric/int128.h" #include "absl/numeric/int128.h"
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_INTEGRAL_TYPES_H_ #ifndef RLWE_INTEGRAL_TYPES_H_
#define RLWE_INTEGRAL_TYPES_H_ #define RLWE_INTEGRAL_TYPES_H_
......
This diff is collapsed.
...@@ -32,7 +32,7 @@ namespace internal { ...@@ -32,7 +32,7 @@ namespace internal {
template <typename ModularInt> template <typename ModularInt>
void FillWithEveryPower(const ModularInt& base, unsigned int n, void FillWithEveryPower(const ModularInt& base, unsigned int n,
std::vector<ModularInt>* row, std::vector<ModularInt>* row,
typename ModularInt::Params* params) { const typename ModularInt::Params* params) {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
(*row)[i].AddInPlace(base.ModExp(i, params), params); (*row)[i].AddInPlace(base.ModExp(i, params), params);
} }
...@@ -40,7 +40,7 @@ void FillWithEveryPower(const ModularInt& base, unsigned int n, ...@@ -40,7 +40,7 @@ void FillWithEveryPower(const ModularInt& base, unsigned int n,
template <typename ModularInt> template <typename ModularInt>
rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity( rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity(
unsigned int log_n, typename ModularInt::Params* params) { unsigned int log_n, const typename ModularInt::Params* params) {
typename ModularInt::Int n = params->One() << log_n; typename ModularInt::Int n = params->One() << log_n;
typename ModularInt::Int half_n = n >> 1; typename ModularInt::Int half_n = n >> 1;
...@@ -78,7 +78,7 @@ rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity( ...@@ -78,7 +78,7 @@ rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity(
// Each item of the vector is in modular integer representation. // Each item of the vector is in modular integer representation.
template <typename ModularInt> template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsis( rlwe::StatusOr<std::vector<ModularInt>> NttPsis(
unsigned int log_n, typename ModularInt::Params* params) { unsigned int log_n, const typename ModularInt::Params* params) {
// Obtain psi, a primitive 2n-th root of unity (hence log_n + 1). // Obtain psi, a primitive 2n-th root of unity (hence log_n + 1).
RLWE_ASSIGN_OR_RETURN( RLWE_ASSIGN_OR_RETURN(
ModularInt psi, ModularInt psi,
...@@ -116,7 +116,7 @@ static void BitrevHelper(const std::vector<unsigned int>& bitrevs, ...@@ -116,7 +116,7 @@ static void BitrevHelper(const std::vector<unsigned int>& bitrevs,
// bitreversed powers of the primitive 2n-th root of unity. // bitreversed powers of the primitive 2n-th root of unity.
template <typename ModularInt> template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev( rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev(
unsigned int log_n, typename ModularInt::Params* params) { unsigned int log_n, const typename ModularInt::Params* params) {
// Retrieve the table for the forward transformation. // Retrieve the table for the forward transformation.
RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> psis, RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> psis,
internal::NttPsis<ModularInt>(log_n, params)); internal::NttPsis<ModularInt>(log_n, params));
...@@ -129,7 +129,7 @@ rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev( ...@@ -129,7 +129,7 @@ rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev(
// of the bitreversed powers of the primitive 2n-th root of unity plus 1. // of the bitreversed powers of the primitive 2n-th root of unity plus 1.
template <typename ModularInt> template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsisInvBitrev( rlwe::StatusOr<std::vector<ModularInt>> NttPsisInvBitrev(
unsigned int log_n, typename ModularInt::Params* params) { unsigned int log_n, const typename ModularInt::Params* params) {
// Retrieve the table for the forward transformation. // Retrieve the table for the forward transformation.
RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> row, RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> row,
internal::NttPsis<ModularInt>(log_n, params)); internal::NttPsis<ModularInt>(log_n, params));
...@@ -168,7 +168,7 @@ struct NttParameters { ...@@ -168,7 +168,7 @@ struct NttParameters {
~NttParameters() = default; ~NttParameters() = default;
int number_coeffs; int number_coeffs;
std::unique_ptr<ModularInt> n_inv_ptr; absl::optional<ModularInt> n_inv_ptr;
std::vector<ModularInt> psis_bitrev; std::vector<ModularInt> psis_bitrev;
std::vector<ModularInt> psis_inv_bitrev; std::vector<ModularInt> psis_inv_bitrev;
std::vector<unsigned int> bitrevs; std::vector<unsigned int> bitrevs;
...@@ -178,7 +178,7 @@ struct NttParameters { ...@@ -178,7 +178,7 @@ struct NttParameters {
// Does not take ownership of params. // Does not take ownership of params.
template <typename ModularInt> template <typename ModularInt>
rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters( rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters(
int log_n, typename ModularInt::Params* params) { int log_n, const typename ModularInt::Params* params) {
// Abort if log_n is non-positive. // Abort if log_n is non-positive.
if (log_n <= 0) { if (log_n <= 0) {
return absl::InvalidArgumentError("log_n must be positive"); return absl::InvalidArgumentError("log_n must be positive");
...@@ -203,11 +203,10 @@ rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters( ...@@ -203,11 +203,10 @@ rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters(
absl::StrCat("modulus is not 1 mod 2n for logn, ", log_n)); absl::StrCat("modulus is not 1 mod 2n for logn, ", log_n));
} }
// 1-dimensional vector containing the inverse of n // Compute the inverse of n.
typename ModularInt::Int n = params->One() << log_n; typename ModularInt::Int n = params->One() << log_n;
RLWE_ASSIGN_OR_RETURN(auto mn, ModularInt::ImportInt(n, params)); RLWE_ASSIGN_OR_RETURN(auto mn, ModularInt::ImportInt(n, params));
auto minv = mn.MultiplicativeInverse(params); output.n_inv_ptr = mn.MultiplicativeInverse(params);
output.n_inv_ptr = absl::make_unique<ModularInt>(minv);
RLWE_ASSIGN_OR_RETURN(output.psis_bitrev, RLWE_ASSIGN_OR_RETURN(output.psis_bitrev,
NttPsisBitrev<ModularInt>(log_n, params)); NttPsisBitrev<ModularInt>(log_n, params));
......
...@@ -85,12 +85,16 @@ class PolynomialTest : public ::testing::Test { ...@@ -85,12 +85,16 @@ class PolynomialTest : public ::testing::Test {
q_.reset(new CoefficientPolynomial(q_coeffs, params14_.get())); q_.reset(new CoefficientPolynomial(q_coeffs, params14_.get()));
// Acquire all of the NTT parameters. // Acquire all of the NTT parameters.
ASSERT_OK_AND_ASSIGN(ntt_params_, rlwe::InitializeNttParameters<uint_m>( ASSERT_OK_AND_ASSIGN(auto ntt_params, rlwe::InitializeNttParameters<uint_m>(
log_n, params14_.get())); log_n, params14_.get()));
ntt_params_ =
absl::make_unique<rlwe::NttParameters<uint_m>>(std::move(ntt_params));
// Put p and q in the NTT domain. // Put p and q in the NTT domain.
ntt_p_ = Polynomial::ConvertToNtt(p_coeffs, ntt_params_, params14_.get()); ntt_p_ =
ntt_q_ = Polynomial::ConvertToNtt(q_coeffs, ntt_params_, params14_.get()); Polynomial::ConvertToNtt(p_coeffs, ntt_params_.get(), params14_.get());
ntt_q_ =
Polynomial::ConvertToNtt(q_coeffs, ntt_params_.get(), params14_.get());
} }
std::unique_ptr<Prng> MakePrng(absl::string_view seed) { std::unique_ptr<Prng> MakePrng(absl::string_view seed) {
...@@ -98,8 +102,8 @@ class PolynomialTest : public ::testing::Test { ...@@ -98,8 +102,8 @@ class PolynomialTest : public ::testing::Test {
return prng; return prng;
} }
std::unique_ptr<uint_m::Params> params14_; std::unique_ptr<const uint_m::Params> params14_;
rlwe::NttParameters<uint_m> ntt_params_; std::unique_ptr<rlwe::NttParameters<uint_m>> ntt_params_;
std::unique_ptr<CoefficientPolynomial> p_; std::unique_ptr<CoefficientPolynomial> p_;
std::unique_ptr<CoefficientPolynomial> q_; std::unique_ptr<CoefficientPolynomial> q_;
Polynomial ntt_p_; Polynomial ntt_p_;
...@@ -132,8 +136,8 @@ TYPED_TEST(PolynomialTest, CoeffsCorrectlyReturnsCoefficients) { ...@@ -132,8 +136,8 @@ TYPED_TEST(PolynomialTest, CoeffsCorrectlyReturnsCoefficients) {
v.push_back(elt); v.push_back(elt);
} }
Polynomial ntt_v = Polynomial ntt_v = Polynomial::ConvertToNtt(v, this->ntt_params_.get(),
Polynomial::ConvertToNtt(v, this->ntt_params_, this->params14_.get()); this->params14_.get());
for (int j = 0; j < n; j++) { for (int j = 0; j < n; j++) {
EXPECT_EQ(ntt_v.Coeffs()[j].ExportInt(this->params14_.get()), EXPECT_EQ(ntt_v.Coeffs()[j].ExportInt(this->params14_.get()),
...@@ -167,7 +171,7 @@ TYPED_TEST(PolynomialTest, Symmetry) { ...@@ -167,7 +171,7 @@ TYPED_TEST(PolynomialTest, Symmetry) {
this->SetParams(1 << i, i); this->SetParams(1 << i, i);
EXPECT_TRUE(this->ntt_p_.IsValid()); EXPECT_TRUE(this->ntt_p_.IsValid());
CoefficientPolynomial p_prime( CoefficientPolynomial p_prime(
this->ntt_p_.InverseNtt(this->ntt_params_, this->params14_.get()), this->ntt_p_.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
EXPECT_EQ(*this->p_, p_prime); EXPECT_EQ(*this->p_, p_prime);
} }
...@@ -218,12 +222,12 @@ TYPED_TEST(PolynomialTest, BinopOfDifferentLengths) { ...@@ -218,12 +222,12 @@ TYPED_TEST(PolynomialTest, BinopOfDifferentLengths) {
std::vector<uint_m> x(1 << i, this->zero_); std::vector<uint_m> x(1 << i, this->zero_);
std::vector<uint_m> y(1 << j, this->zero_); std::vector<uint_m> y(1 << j, this->zero_);
this->ntt_params_.bitrevs = rlwe::internal::BitrevArray(i); this->ntt_params_->bitrevs = rlwe::internal::BitrevArray(i);
Polynomial ntt_x = Polynomial ntt_x = Polynomial::ConvertToNtt(x, this->ntt_params_.get(),
Polynomial::ConvertToNtt(x, this->ntt_params_, this->params14_.get()); this->params14_.get());
this->ntt_params_.bitrevs = rlwe::internal::BitrevArray(j); this->ntt_params_->bitrevs = rlwe::internal::BitrevArray(j);
Polynomial ntt_y = Polynomial ntt_y = Polynomial::ConvertToNtt(y, this->ntt_params_.get(),
Polynomial::ConvertToNtt(y, this->ntt_params_, this->params14_.get()); this->params14_.get());
EXPECT_TRUE(ntt_x.IsValid()); EXPECT_TRUE(ntt_x.IsValid());
EXPECT_TRUE(ntt_y.IsValid()); EXPECT_TRUE(ntt_y.IsValid());
...@@ -293,10 +297,10 @@ TYPED_TEST(PolynomialTest, Multiply) { ...@@ -293,10 +297,10 @@ TYPED_TEST(PolynomialTest, Multiply) {
this->ntt_q_.Mul(this->ntt_p_, this->params14_.get())); this->ntt_q_.Mul(this->ntt_p_, this->params14_.get()));
CoefficientPolynomial res1( CoefficientPolynomial res1(
ntt_res1.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res1.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
CoefficientPolynomial res2( CoefficientPolynomial res2(
ntt_res2.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res2.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected, ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected,
...@@ -322,7 +326,7 @@ TYPED_TEST(PolynomialTest, ScalarMultiply) { ...@@ -322,7 +326,7 @@ TYPED_TEST(PolynomialTest, ScalarMultiply) {
this->ntt_p_.Mul(scalar, this->params14_.get())); this->ntt_p_.Mul(scalar, this->params14_.get()));
CoefficientPolynomial res( CoefficientPolynomial res(
ntt_res.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
CoefficientPolynomial expected = (*this->p_) * scalar; CoefficientPolynomial expected = (*this->p_) * scalar;
...@@ -337,9 +341,9 @@ TYPED_TEST(PolynomialTest, Negate) { ...@@ -337,9 +341,9 @@ TYPED_TEST(PolynomialTest, Negate) {
this->SetParams(1 << i, i); this->SetParams(1 << i, i);
// An NTT polynomial of all zeros. // An NTT polynomial of all zeros.
Polynomial zeros_ntt = Polynomial zeros_ntt = Polynomial::ConvertToNtt(
Polynomial::ConvertToNtt(std::vector<uint_m>(1 << i, this->zero_), std::vector<uint_m>(1 << i, this->zero_), this->ntt_params_.get(),
this->ntt_params_, this->params14_.get()); this->params14_.get());
auto minus_p = this->ntt_p_.Negate(this->params14_.get()); auto minus_p = this->ntt_p_.Negate(this->params14_.get());
ASSERT_OK_AND_ASSIGN(auto p0, ASSERT_OK_AND_ASSIGN(auto p0,
...@@ -363,10 +367,10 @@ TYPED_TEST(PolynomialTest, Add) { ...@@ -363,10 +367,10 @@ TYPED_TEST(PolynomialTest, Add) {
this->ntt_q_.Add(this->ntt_p_, this->params14_.get())); this->ntt_q_.Add(this->ntt_p_, this->params14_.get()));
CoefficientPolynomial res1( CoefficientPolynomial res1(
ntt_res1.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res1.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
CoefficientPolynomial res2( CoefficientPolynomial res2(
ntt_res2.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res2.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected, ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected,
...@@ -391,10 +395,10 @@ TYPED_TEST(PolynomialTest, Sub) { ...@@ -391,10 +395,10 @@ TYPED_TEST(PolynomialTest, Sub) {
this->ntt_q_.Sub(this->ntt_p_, this->params14_.get())); this->ntt_q_.Sub(this->ntt_p_, this->params14_.get()));
CoefficientPolynomial res1( CoefficientPolynomial res1(
ntt_res1.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res1.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
CoefficientPolynomial res2( CoefficientPolynomial res2(
ntt_res2.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res2.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected_res1, ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected_res1,
...@@ -413,18 +417,20 @@ TYPED_TEST(PolynomialTest, SubstitutionPowerMalformed) { ...@@ -413,18 +417,20 @@ TYPED_TEST(PolynomialTest, SubstitutionPowerMalformed) {
this->SetParams(1 << i, i); this->SetParams(1 << i, i);
EXPECT_THAT( EXPECT_THAT(
this->ntt_p_.Substitute(2, this->ntt_params_, this->params14_.get()), this->ntt_p_.Substitute(2, this->ntt_params_.get(),
this->params14_.get()),
StatusIs(::absl::StatusCode::kInvalidArgument, StatusIs(::absl::StatusCode::kInvalidArgument,
HasSubstr("must be a non-negative odd integer less than"))); HasSubstr("must be a non-negative odd integer less than")));
// Even when not in debugging mode, the following two tests will yield a // Even when not in debugging mode, the following two tests will yield a
// segmentation fault. We therefore only do the tests in debug mode. // segmentation fault. We therefore only do the tests in debug mode.
EXPECT_THAT( EXPECT_THAT(
this->ntt_p_.Substitute(-10, this->ntt_params_, this->params14_.get()), this->ntt_p_.Substitute(-10, this->ntt_params_.get(),
this->params14_.get()),
StatusIs(::absl::StatusCode::kInvalidArgument, StatusIs(::absl::StatusCode::kInvalidArgument,
HasSubstr("must be a non-negative odd integer less than"))); HasSubstr("must be a non-negative odd integer less than")));
EXPECT_THAT( EXPECT_THAT(
this->ntt_p_.Substitute(2 * (1 << i) + 1, this->ntt_params_, this->ntt_p_.Substitute(2 * (1 << i) + 1, this->ntt_params_.get(),
this->params14_.get()), this->params14_.get()),
StatusIs(::absl::StatusCode::kInvalidArgument, StatusIs(::absl::StatusCode::kInvalidArgument,
HasSubstr("must be a non-negative odd integer less than"))); HasSubstr("must be a non-negative odd integer less than")));
...@@ -438,11 +444,11 @@ TYPED_TEST(PolynomialTest, Substitution) { ...@@ -438,11 +444,11 @@ TYPED_TEST(PolynomialTest, Substitution) {
int dimension = 1 << i; int dimension = 1 << i;
for (int k = 0; k < i; k++) { for (int k = 0; k < i; k++) {
int power = (dimension >> k) + 1; int power = (dimension >> k) + 1;
ASSERT_OK_AND_ASSIGN(auto ntt_res, ASSERT_OK_AND_ASSIGN(
this->ntt_p_.Substitute(power, this->ntt_params_, auto ntt_res, this->ntt_p_.Substitute(power, this->ntt_params_.get(),
this->params14_.get())); this->params14_.get()));
CoefficientPolynomial res( CoefficientPolynomial res(
ntt_res.InverseNtt(this->ntt_params_, this->params14_.get()), ntt_res.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get()); this->params14_.get());
ASSERT_OK_AND_ASSIGN(auto r, this->p_->Substitute(power)); ASSERT_OK_AND_ASSIGN(auto r, this->p_->Substitute(power));
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"])
exports_files(["LICENSE"])
# PRNG interface. # PRNG interface.
...@@ -44,7 +46,7 @@ cc_library( ...@@ -44,7 +46,7 @@ cc_library(
deps = [ deps = [
":chacha_prng", ":chacha_prng",
":single_thread_chacha_prng", ":single_thread_chacha_prng",
"@com_google_googletest//:gtest", "@com_github_google_googletest//:gtest",
], ],
) )
...@@ -61,8 +63,8 @@ cc_test( ...@@ -61,8 +63,8 @@ cc_test(
":integral_prng_types", ":integral_prng_types",
"//testing:matchers", "//testing:matchers",
"//testing:status_testing", "//testing:status_testing",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
], ],
) )
...@@ -124,8 +126,8 @@ cc_test( ...@@ -124,8 +126,8 @@ cc_test(
":chacha_prng", ":chacha_prng",
"//testing:matchers", "//testing:matchers",
"//testing:status_testing", "//testing:status_testing",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
], ],
) )
...@@ -139,7 +141,7 @@ cc_test( ...@@ -139,7 +141,7 @@ cc_test(
":single_thread_chacha_prng", ":single_thread_chacha_prng",
"//testing:matchers", "//testing:matchers",
"//testing:status_testing", "//testing:status_testing",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
], ],
) )
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "prng/chacha_prng_util.h" #include "prng/chacha_prng_util.h"
#include <cstdint> #include <cstdint>
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// An implementation of a PRNG using the ChaCha20 stream cipher. Since this is // An implementation of a PRNG using the ChaCha20 stream cipher. Since this is
// a stream cipher, the key stream can be obtained by "encrypting" the plaintext // a stream cipher, the key stream can be obtained by "encrypting" the plaintext
// 0....0. // 0....0.
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_ #ifndef RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
#define RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_ #define RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_ #ifndef RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
#define RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_ #define RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "relinearization_key.h" #include "relinearization_key.h"
#include "absl/numeric/int128.h" #include "absl/numeric/int128.h"
#include "bits_util.h"
#include "montgomery.h" #include "montgomery.h"
#include "prng/integral_prng_types.h" #include "prng/integral_prng_types.h"
#include "status_macros.h" #include "status_macros.h"
...@@ -25,18 +26,13 @@ ...@@ -25,18 +26,13 @@
namespace rlwe { namespace rlwe {
namespace { namespace {
// Method to compute the number of digits needed to represent integers mod // Method to compute the number of digits needed to represent integers mod
// q in base T. Upcasts the modulus to absl::uint128 to handle all uint*_t // q in base T. Upcasts the modulus to absl::uint128 to handle all Uint*
// types. // types.
inline int ComputeDimension(Uint64 log_decomposition_modulus, inline int ComputeDimension(Uint64 log_decomposition_modulus,
absl::uint128 modulus) { absl::uint128 modulus) {
double modulus_bits; Uint64 modulus_bits = static_cast<Uint64>(internal::BitLength(modulus));
// Compute bit lengths of each as a double and divide. return (modulus_bits + (log_decomposition_modulus - 1)) /
if (absl::Uint128High64(modulus) > 0) { log_decomposition_modulus;
modulus_bits = std::log2(absl::Uint128High64(modulus)) + 64;
} else {
modulus_bits = std::log2(absl::Uint128Low64(modulus));
}
return std::ceil(modulus_bits / log_decomposition_modulus);
} }
// Returns a random vector r orthogonal to (1,s). The second component is chosen // Returns a random vector r orthogonal to (1,s). The second component is chosen
...@@ -83,30 +79,26 @@ rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> PowersOfT( ...@@ -83,30 +79,26 @@ rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> PowersOfT(
template <typename ModularInt> template <typename ModularInt>
rlwe::StatusOr<std::vector<std::vector<ModularInt>>> BitDecompose( rlwe::StatusOr<std::vector<std::vector<ModularInt>>> BitDecompose(
const std::vector<ModularInt>& coefficients, const std::vector<ModularInt>& coefficients,
typename ModularInt::Params* modulus_params, const typename ModularInt::Params* modulus_params,
const Uint64 log_decomposition_modulus, int dimension) { const Uint64 log_decomposition_modulus, int dimension) {
std::vector<std::vector<ModularInt>> result;
std::vector<typename ModularInt::Int> ciphertext_coeffs(coefficients.size(), std::vector<typename ModularInt::Int> ciphertext_coeffs(coefficients.size(),
0); 0);
std::transform( std::transform(
coefficients.begin(), coefficients.end(), ciphertext_coeffs.begin(), coefficients.begin(), coefficients.end(), ciphertext_coeffs.begin(),
[modulus_params](ModularInt x) { return x.ExportInt(modulus_params); }); [modulus_params](ModularInt x) { return x.ExportInt(modulus_params); });
auto zero = ModularInt::ImportZero(modulus_params); std::vector<std::vector<ModularInt>> result(dimension);
std::vector<ModularInt> sum_part(ciphertext_coeffs.size(), zero);
for (int i = 0; i < dimension; i++) { for (int i = 0; i < dimension; i++) {
result[i].reserve(ciphertext_coeffs.size());
for (int j = 0; j < ciphertext_coeffs.size(); ++j) { for (int j = 0; j < ciphertext_coeffs.size(); ++j) {
RLWE_ASSIGN_OR_RETURN( RLWE_ASSIGN_OR_RETURN(
sum_part[j], auto coefficient_part,
ModularInt::ImportInt( ModularInt::ImportInt(
(ciphertext_coeffs[j] % (1L << log_decomposition_modulus)), (ciphertext_coeffs[j] % (1L << log_decomposition_modulus)),
modulus_params)); modulus_params));
result[i].push_back(std::move(coefficient_part));
ciphertext_coeffs[j] = ciphertext_coeffs[j] >> log_decomposition_modulus; ciphertext_coeffs[j] = ciphertext_coeffs[j] >> log_decomposition_modulus;
} }
result.push_back(sum_part);
} }
return result; return result;
...@@ -116,7 +108,7 @@ template <typename ModularInt> ...@@ -116,7 +108,7 @@ template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply( rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply(
std::vector<std::vector<ModularInt>> decomposed_coefficients, std::vector<std::vector<ModularInt>> decomposed_coefficients,
const std::vector<std::vector<Polynomial<ModularInt>>>& matrix, const std::vector<std::vector<Polynomial<ModularInt>>>& matrix,
typename ModularInt::Params* modulus_params, const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) { const NttParameters<ModularInt>* ntt_params) {
Polynomial<ModularInt> temp(matrix[0][0].Len(), modulus_params); Polynomial<ModularInt> temp(matrix[0][0].Len(), modulus_params);
Polynomial<ModularInt> ntt_part(matrix[0][0].Len(), modulus_params); Polynomial<ModularInt> ntt_part(matrix[0][0].Len(), modulus_params);
...@@ -125,7 +117,7 @@ rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply( ...@@ -125,7 +117,7 @@ rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply(
for (int i = 0; i < matrix[0].size(); i++) { for (int i = 0; i < matrix[0].size(); i++) {
ntt_part = Polynomial<ModularInt>::ConvertToNtt( ntt_part = Polynomial<ModularInt>::ConvertToNtt(
std::move(decomposed_coefficients[i]), *ntt_params, modulus_params); std::move(decomposed_coefficients[i]), ntt_params, modulus_params);
RLWE_ASSIGN_OR_RETURN(temp, ntt_part.Mul(matrix[0][i], modulus_params)); RLWE_ASSIGN_OR_RETURN(temp, ntt_part.Mul(matrix[0][i], modulus_params));
RLWE_RETURN_IF_ERROR(result[0].AddInPlace(temp, modulus_params)); RLWE_RETURN_IF_ERROR(result[0].AddInPlace(temp, modulus_params));
RLWE_RETURN_IF_ERROR(ntt_part.MulInPlace(matrix[1][i], modulus_params)) RLWE_RETURN_IF_ERROR(ntt_part.MulInPlace(matrix[1][i], modulus_params))
...@@ -166,8 +158,8 @@ RelinearizationKey<ModularInt>::RelinearizationKeyPart::Create( ...@@ -166,8 +158,8 @@ RelinearizationKey<ModularInt>::RelinearizationKeyPart::Create(
key_power.Len(), key.Variance(), prng_encryption, key_power.Len(), key.Variance(), prng_encryption,
key.ModulusParams())); key.ModulusParams()));
// Convert the error coefficients into an error polynomial. // Convert the error coefficients into an error polynomial.
auto e = Polynomial<ModularInt>::ConvertToNtt(error, *key.NttParams(), auto e = Polynomial<ModularInt>::ConvertToNtt(
key.ModulusParams()); std::move(error), key.NttParams(), key.ModulusParams());
// Set the column of the Relinearization matrix. // Set the column of the Relinearization matrix.
RLWE_RETURN_IF_ERROR( RLWE_RETURN_IF_ERROR(
e.MulInPlace(key.PlaintextModulus(), key.ModulusParams())); e.MulInPlace(key.PlaintextModulus(), key.ModulusParams()));
...@@ -184,11 +176,11 @@ template <typename ModularInt> ...@@ -184,11 +176,11 @@ template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> rlwe::StatusOr<std::vector<Polynomial<ModularInt>>>
RelinearizationKey<ModularInt>::RelinearizationKeyPart::ApplyPartTo( RelinearizationKey<ModularInt>::RelinearizationKeyPart::ApplyPartTo(
const Polynomial<ModularInt>& ciphertext_part, const Polynomial<ModularInt>& ciphertext_part,
typename ModularInt::Params* modulus_params, const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) const { const NttParameters<ModularInt>* ntt_params) const {
// Convert ciphertext out of NTT form. // Convert ciphertext out of NTT form.
std::vector<ModularInt> ciphertext_coefficients = std::vector<ModularInt> ciphertext_coefficients =
ciphertext_part.InverseNtt(*ntt_params, modulus_params); ciphertext_part.InverseNtt(ntt_params, modulus_params);
// Bit-decompose the vector of coefficients in the ciphertext. // Bit-decompose the vector of coefficients in the ciphertext.
RLWE_ASSIGN_OR_RETURN( RLWE_ASSIGN_OR_RETURN(
...@@ -206,14 +198,16 @@ rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart> ...@@ -206,14 +198,16 @@ rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
RelinearizationKey<ModularInt>::RelinearizationKeyPart::Deserialize( RelinearizationKey<ModularInt>::RelinearizationKeyPart::Deserialize(
const std::vector<SerializedNttPolynomial>& polynomials, const std::vector<SerializedNttPolynomial>& polynomials,
Uint64 log_decomposition_modulus, SecurePrng* prng, Uint64 log_decomposition_modulus, SecurePrng* prng,
ModularIntParams* modulus_params, const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params) { const NttParameters<ModularInt>* ntt_params) {
// The polynomials input is a flattened representation of a 2 x dimension // The polynomials input is a flattened representation of a 2 x dimension
// matrix where the first half corresponds to the first row of matrix and the // matrix where the first half corresponds to the first row of matrix and the
// second half corresponds to the second row of matrix. This matrix makes up // second half corresponds to the second row of matrix. This matrix makes up
// the RelinearizationKeyPart. // the RelinearizationKeyPart.
auto matrix = std::vector<std::vector<Polynomial<ModularInt>>>(2);
int dimension = polynomials.size(); int dimension = polynomials.size();
auto matrix = std::vector<std::vector<Polynomial<ModularInt>>>(2);
matrix[0].reserve(dimension);
matrix[1].reserve(dimension);
for (int i = 0; i < dimension; i++) { for (int i = 0; i < dimension; i++) {
RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize( RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize(
...@@ -282,6 +276,7 @@ RelinearizationKey<ModularInt>::Create(const SymmetricRlweKey<ModularInt>& key, ...@@ -282,6 +276,7 @@ RelinearizationKey<ModularInt>::Create(const SymmetricRlweKey<ModularInt>& key,
auto dimension = auto dimension =
ComputeDimension(log_decomposition_modulus, key.ModulusParams()->modulus); ComputeDimension(log_decomposition_modulus, key.ModulusParams()->modulus);
std::vector<RelinearizationKeyPart> relinearization_key; std::vector<RelinearizationKeyPart> relinearization_key;
relinearization_key.reserve(num_parts);
// Create RealinearizationKeyPart for each of the secret key parts: s, ..., // Create RealinearizationKeyPart for each of the secret key parts: s, ...,
// s^k. // s^k.
for (int i = 1; i < num_parts; i++) { for (int i = 1; i < num_parts; i++) {
...@@ -362,7 +357,7 @@ template <typename ModularInt> ...@@ -362,7 +357,7 @@ template <typename ModularInt>
rlwe::StatusOr<RelinearizationKey<ModularInt>> rlwe::StatusOr<RelinearizationKey<ModularInt>>
RelinearizationKey<ModularInt>::Deserialize( RelinearizationKey<ModularInt>::Deserialize(
const SerializedRelinearizationKey& serialized, const SerializedRelinearizationKey& serialized,
typename ModularInt::Params* modulus_params, const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) { const NttParameters<ModularInt>* ntt_params) {
// Verifies that the number of polynomials in serialized is expected. // Verifies that the number of polynomials in serialized is expected.
// A RelinearizationKey can decrypt ciphertexts with num_parts number of // A RelinearizationKey can decrypt ciphertexts with num_parts number of
...@@ -419,6 +414,7 @@ RelinearizationKey<ModularInt>::Deserialize( ...@@ -419,6 +414,7 @@ RelinearizationKey<ModularInt>::Deserialize(
// Takes each polynomials_per_matrix chunk of serialized.c()'s and places them // Takes each polynomials_per_matrix chunk of serialized.c()'s and places them
// into a KeyPart. // into a KeyPart.
output.relinearization_key_.reserve(serialized.num_parts() - 1);
for (int i = 0; i < (serialized.num_parts() - 1); i++) { for (int i = 0; i < (serialized.num_parts() - 1); i++) {
auto start = serialized.c().begin() + i * polynomials_per_matrix; auto start = serialized.c().begin() + i * polynomials_per_matrix;
auto end = start + polynomials_per_matrix; auto end = start + polynomials_per_matrix;
......
...@@ -96,7 +96,7 @@ class RelinearizationKey { ...@@ -96,7 +96,7 @@ class RelinearizationKey {
// input parameters. // input parameters.
static rlwe::StatusOr<RelinearizationKey> Deserialize( static rlwe::StatusOr<RelinearizationKey> Deserialize(
const SerializedRelinearizationKey& serialized, const SerializedRelinearizationKey& serialized,
ModularIntParams* modulus_params, const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params); const NttParameters<ModularInt>* ntt_params);
// Substitution Power accessor. // Substitution Power accessor.
...@@ -120,14 +120,14 @@ class RelinearizationKey { ...@@ -120,14 +120,14 @@ class RelinearizationKey {
// of polynomials. // of polynomials.
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> ApplyPartTo( rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> ApplyPartTo(
const Polynomial<ModularInt>& ciphertext_part, const Polynomial<ModularInt>& ciphertext_part,
ModularIntParams* modulus_params, const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params) const; const NttParameters<ModularInt>* ntt_params) const;
// Creates a RelinearizationKeyPart out of a vector of Polynomials. // Creates a RelinearizationKeyPart out of a vector of Polynomials.
static rlwe::StatusOr<RelinearizationKeyPart> Deserialize( static rlwe::StatusOr<RelinearizationKeyPart> Deserialize(
const std::vector<SerializedNttPolynomial>& polynomials, const std::vector<SerializedNttPolynomial>& polynomials,
Uint64 log_decomposition_modulus, SecurePrng* prng, Uint64 log_decomposition_modulus, SecurePrng* prng,
ModularIntParams* modulus_params, const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params); const NttParameters<ModularInt>* ntt_params);
std::vector<Polynomial<ModularInt>> Matrix() const { return matrix_[0]; } std::vector<Polynomial<ModularInt>> Matrix() const { return matrix_[0]; }
...@@ -147,7 +147,7 @@ class RelinearizationKey { ...@@ -147,7 +147,7 @@ class RelinearizationKey {
// Creates an empty RelinearizationKey. // Creates an empty RelinearizationKey.
RelinearizationKey(Uint64 log_decomposition_modulus, RelinearizationKey(Uint64 log_decomposition_modulus,
const ModularInt& decomposition_modulus, const ModularInt& decomposition_modulus,
ModularIntParams* params, const ModularIntParams* params,
const NttParameters<ModularInt>* ntt_params) const NttParameters<ModularInt>* ntt_params)
: log_decomposition_modulus_(log_decomposition_modulus), : log_decomposition_modulus_(log_decomposition_modulus),
decomposition_modulus_(decomposition_modulus), decomposition_modulus_(decomposition_modulus),
...@@ -176,7 +176,7 @@ class RelinearizationKey { ...@@ -176,7 +176,7 @@ class RelinearizationKey {
int substitution_power_; int substitution_power_;
// Modulus parameters. Does not take ownership. // Modulus parameters. Does not take ownership.
ModularIntParams* modulus_params_; const ModularIntParams* modulus_params_;
// NTT parameters. Does not take ownership. // NTT parameters. Does not take ownership.
const NttParameters<ModularInt>* ntt_params_; const NttParameters<ModularInt>* ntt_params_;
......
...@@ -43,7 +43,7 @@ namespace rlwe { ...@@ -43,7 +43,7 @@ namespace rlwe {
template <typename ModularInt> template <typename ModularInt>
static rlwe::StatusOr<std::vector<ModularInt>> SampleFromErrorDistribution( static rlwe::StatusOr<std::vector<ModularInt>> SampleFromErrorDistribution(
unsigned int num_coeffs, Uint64 variance, SecurePrng* prng, unsigned int num_coeffs, Uint64 variance, SecurePrng* prng,
typename ModularInt::Params* modulus_params) { const typename ModularInt::Params* modulus_params) {
if (variance > kMaxVariance) { if (variance > kMaxVariance) {
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"The variance, ", variance, ", must be at most ", kMaxVariance, ".")); "The variance, ", variance, ", must be at most ", kMaxVariance, "."));
...@@ -54,48 +54,46 @@ static rlwe::StatusOr<std::vector<ModularInt>> SampleFromErrorDistribution( ...@@ -54,48 +54,46 @@ static rlwe::StatusOr<std::vector<ModularInt>> SampleFromErrorDistribution(
// Sample from the centered binomial distribution. To do so, we sample k pairs // Sample from the centered binomial distribution. To do so, we sample k pairs
// of bits (a, b), where k = 2 * variance. The sample of the binomial // of bits (a, b), where k = 2 * variance. The sample of the binomial
// distribution is the sum of the differences between each pair of bits. // distribution is the sum of the differences between each pair of bits.
// This is implemented by splitting k in words k', drawing 2*k' bits, and
// computing the difference of Hamming weight between the first k' bits and
// the last k' bits, where
Uint64 k; Uint64 k;
Uint64 coefficient_positive, coefficient_negative; typename ModularInt::Int coefficient;
for (int i = 0; i < num_coeffs; i++) { for (int i = 0; i < num_coeffs; i++) {
coefficient_positive = 0; coefficient = modulus_params->modulus;
coefficient_negative = 0;
k = variance << 1; k = variance << 1;
while (k > 0) { while (k > 0) {
if (k >= 64) { if (k >= 64) {
// Use all the bits // Use 64 bits of randomness
RLWE_ASSIGN_OR_RETURN(auto r64, prng->Rand64()); RLWE_ASSIGN_OR_RETURN(auto r64, prng->Rand64());
coefficient_positive += rlwe::internal::CountOnes64(r64); coefficient += rlwe::internal::CountOnes64(r64);
RLWE_ASSIGN_OR_RETURN(r64, prng->Rand64()); RLWE_ASSIGN_OR_RETURN(r64, prng->Rand64());
coefficient_negative += rlwe::internal::CountOnes64(r64); coefficient -= rlwe::internal::CountOnes64(r64);
k -= 64; k -= 64;
} else if (k > 8) { } else if (k >= 8) {
Uint64 mask = (1ULL << k) - 1; // Use 8 bits of randomness
RLWE_ASSIGN_OR_RETURN(auto r64, prng->Rand64()); RLWE_ASSIGN_OR_RETURN(auto r8, prng->Rand8());
coefficient_positive += rlwe::internal::CountOnes64(r64 & mask); coefficient += rlwe::internal::CountOnesInByte(r8);
RLWE_ASSIGN_OR_RETURN(r64, prng->Rand64()); RLWE_ASSIGN_OR_RETURN(r8, prng->Rand8());
coefficient_negative += rlwe::internal::CountOnes64(r64 & mask); coefficient -= rlwe::internal::CountOnesInByte(r8);
k = 0; k -= 8;
} else { } else {
Uint8 mask = (1 << k) - 1; Uint8 mask = (1 << k) - 1;
RLWE_ASSIGN_OR_RETURN(auto r8, prng->Rand8()); RLWE_ASSIGN_OR_RETURN(auto r8, prng->Rand8());
coefficient_positive += rlwe::internal::CountOnesInByte(r8 & mask); coefficient += rlwe::internal::CountOnesInByte(r8 & mask);
RLWE_ASSIGN_OR_RETURN(r8, prng->Rand8()); RLWE_ASSIGN_OR_RETURN(r8, prng->Rand8());
coefficient_negative += rlwe::internal::CountOnesInByte(r8 & mask); coefficient -= rlwe::internal::CountOnesInByte(r8 & mask);
k = 0; break; // all k remaining pairs have been sampled.
} }
} }
RLWE_ASSIGN_OR_RETURN( // coefficient is in the interval [modulus - 2k, modulus + 2k]. We reduce
coeffs[i], ModularInt::ImportInt(coefficient_positive, modulus_params)); // it in [0, modulus). Since ModularInt::Int is unsigned, we create a mask
RLWE_ASSIGN_OR_RETURN( // equal to 0xFF...FF when coefficient >= modulus, and equal to 0 otherwise.
auto v, ModularInt::ImportInt(coefficient_negative, modulus_params)); typename ModularInt::Int mask = -(coefficient >= modulus_params->modulus);
coeffs[i].SubInPlace(v, modulus_params); coefficient -= mask & modulus_params->modulus;
RLWE_ASSIGN_OR_RETURN(coeffs[i],
ModularInt::ImportInt(coefficient, modulus_params));
} }
return coeffs; return coeffs;
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_STATUS_MACROS_H_ #ifndef RLWE_STATUS_MACROS_H_
#define RLWE_STATUS_MACROS_H_ #define RLWE_STATUS_MACROS_H_
......
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "status_macros.h" #include "status_macros.h"
#include <sstream> #include <sstream>
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment