Commit fea9d6a1 authored by Amr Aboelkher's avatar Amr Aboelkher Committed by Commit Bot

Roll shell-encryption e2a4af88a0e..4e0598f826

This CL is doing the following:
- Update the BUILD.gn correspondingly with the latest changes
- Adding new patch for using absl optional instead of std
- Update the Support-SHELL-in-chromium patch

Bug: 1072286
Change-Id: Id3c4bc2da5183f92345440eb887f5391f8d7b2e8
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2153443Reviewed-by: default avatarNico Weber <thakis@chromium.org>
Reviewed-by: default avatarAmr Aboelkher <amraboelkher@google.com>
Commit-Queue: Amr Aboelkher <amraboelkher@google.com>
Cr-Commit-Position: refs/heads/master@{#760979}
parent a4d01f33
......@@ -46,6 +46,7 @@ source_set("shell_encryption") {
"glog/logging.h",
"src/bits_util.h",
"src/constants.h",
"src/context.h",
"src/error_params.h",
"src/galois_key.h",
"src/int256.h",
......@@ -68,6 +69,7 @@ source_set("shell_encryption") {
]
sources = [
"src/int256.cc",
"src/montgomery.cc",
"src/ntt_parameters.cc",
"src/prng/chacha_prng.cc",
"src/prng/chacha_prng_util.cc",
......@@ -96,6 +98,7 @@ test("shell_encryption_test") {
public = [
"src/testing/coefficient_polynomial.h",
"src/testing/coefficient_polynomial_ciphertext.h",
"src/testing/parameter.h",
"src/testing/protobuf_matchers.h",
"src/testing/status_matchers.h",
"src/testing/status_testing.h",
......@@ -104,6 +107,7 @@ test("shell_encryption_test") {
]
sources = [
"src/bits_util_test.cc",
"src/context_test.cc",
"src/error_params_test.cc",
"src/galois_key_test.cc",
"src/int256_test.cc",
......
diff --git a/montgomery.h b/montgomery.h
index 50fb08a..336e925 100644
--- a/montgomery.h
+++ b/montgomery.h
@@ -105,8 +105,8 @@ struct MontgomeryIntParams {
static rlwe::StatusOr<std::unique_ptr<MontgomeryIntParams>> Create(
Int modulus) {
// Check that the modulus is smaller than max(Int) / 4.
- if (Int most_significant_bit = modulus >> (bitsize_int - 2);
- most_significant_bit != 0) {
+ Int most_significant_bit = modulus >> (bitsize_int - 2);
+ if (most_significant_bit != 0) {
return absl::InvalidArgumentError(absl::StrCat(
"The modulus should be less than 2^", (bitsize_int - 2), "."));
}
diff --git a/montgomery.cc b/montgomery.cc
index d5221f4..9fbc5da 100644
--- a/montgomery.cc
+++ b/montgomery.cc
@@ -22,8 +22,8 @@ template <typename T>
rlwe::StatusOr<std::unique_ptr<const MontgomeryIntParams<T>>>
MontgomeryIntParams<T>::Create(Int modulus) {
// Check that the modulus is smaller than max(Int) / 4.
- if (Int most_significant_bit = modulus >> (bitsize_int - 2);
- most_significant_bit != 0) {
+ Int most_significant_bit = modulus >> (bitsize_int - 2);
+ if (most_significant_bit != 0) {
return absl::InvalidArgumentError(absl::StrCat(
"The modulus should be less than 2^", (bitsize_int - 2), "."));
}
diff --git a/ntt_parameters.h b/ntt_parameters.h
index da03bfe..270a5a2 100644
index 56e1871..c3da197 100644
--- a/ntt_parameters.h
+++ b/ntt_parameters.h
@@ -103,7 +103,8 @@ static void BitrevHelper(const std::vector<unsigned int>& bitrevs,
......@@ -28,12 +28,12 @@ index da03bfe..270a5a2 100644
}
}
diff --git a/polynomial.h b/polynomial.h
index 0b73dfa..fc0e38a 100644
index 07843b2..3cf0c77 100644
--- a/polynomial.h
+++ b/polynomial.h
@@ -80,7 +80,8 @@ class Polynomial {
const NttParameters<ModularInt>& ntt_params,
ModularIntParams* modular_params) {
const NttParameters<ModularInt>* ntt_params,
const ModularIntParams* modular_params) {
// Check to ensure that the coefficient vector is of the correct length.
- if (int len = poly_coeffs.size(); len <= 0 || (len & (len - 1)) != 0) {
+ int len = poly_coeffs.size();
......@@ -42,10 +42,10 @@ index 0b73dfa..fc0e38a 100644
return Polynomial();
}
diff --git a/prng/chacha_prng_util.cc b/prng/chacha_prng_util.cc
index bd78ab8..39099e1 100644
index dfab1d9..c49c82d 100644
--- a/prng/chacha_prng_util.cc
+++ b/prng/chacha_prng_util.cc
@@ -10,7 +10,8 @@
@@ -24,7 +24,8 @@
#include <openssl/rand.h>
#include "status_macros.h"
......@@ -55,7 +55,7 @@ index bd78ab8..39099e1 100644
absl::Status ChaChaPrngResalt(absl::string_view key, int buffer_size,
int* salt_counter, int* position_in_buffer,
@@ -71,4 +72,5 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key,
@@ -85,4 +86,5 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key,
return rand64;
}
......@@ -63,10 +63,10 @@ index bd78ab8..39099e1 100644
+} // namespace internal
+} // namespace rlwe
diff --git a/prng/chacha_prng_util.h b/prng/chacha_prng_util.h
index 2bb2fcf..5490505 100644
index 32cac5b..8eb8118 100644
--- a/prng/chacha_prng_util.h
+++ b/prng/chacha_prng_util.h
@@ -12,7 +12,8 @@
@@ -28,7 +28,8 @@
#include "integral_types.h"
#include "statusor.h"
......@@ -76,7 +76,7 @@ index 2bb2fcf..5490505 100644
const int kChaChaKeyBytesSize = 32;
const int kChaChaNonceSize = 12;
@@ -43,6 +44,7 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key,
@@ -59,6 +60,7 @@ rlwe::StatusOr<Uint64> ChaChaPrngRand64(absl::string_view key,
int* salt_counter,
std::vector<Uint8>* buffer);
......@@ -86,10 +86,10 @@ index 2bb2fcf..5490505 100644
#endif // RLWE_CHACHA_PRNG_UTIL_H_
diff --git a/statusor.h b/statusor.h
index 4fdeade..42761e6 100644
index d8addb5..200f62d 100644
--- a/statusor.h
+++ b/statusor.h
@@ -74,7 +74,7 @@ class StatusOr {
@@ -96,7 +96,7 @@ class StatusOr {
operator absl::Status() const { return status(); }
......@@ -99,11 +99,11 @@ index 4fdeade..42761e6 100644
if (value_) {
return OtherStatusOrType<T>(std::move(value_.value()));
diff --git a/symmetric_encryption.h b/symmetric_encryption.h
index d4ad730..0223149 100644
index e120b18..987e86f 100644
--- a/symmetric_encryption.h
+++ b/symmetric_encryption.h
@@ -584,8 +584,8 @@ class SymmetricRlweKey {
typename ModularIntQ::Params* modulus_params_q,
@@ -571,8 +571,8 @@ class SymmetricRlweKey {
const typename ModularIntQ::Params* modulus_params_q,
const NttParameters<ModularIntQ>* ntt_params_q) const {
// Configuration failure.
- if (Int t = (modulus_params_q->One() << log_t_) + modulus_params_q->One();
......
diff --git a/ntt_parameters.h b/ntt_parameters.h
index c3da197..55671ec 100644
--- a/ntt_parameters.h
+++ b/ntt_parameters.h
@@ -168,7 +168,7 @@ struct NttParameters {
~NttParameters() = default;
int number_coeffs;
- std::optional<ModularInt> n_inv_ptr;
+ absl::optional<ModularInt> n_inv_ptr;
std::vector<ModularInt> psis_bitrev;
std::vector<ModularInt> psis_inv_bitrev;
std::vector<unsigned int> bitrevs;
This diff is collapsed.
......@@ -8,10 +8,6 @@ both add and multiply encrypted data. It uses modulus-switching to enable
arbitrary-depth homomorphic encryption (provided sufficiently large parameters
are set). RLWE is also believed to be secure in the face of quantum computers.
This library is designed to be compact and readable. The library includes just
1500 lines of heavily-commented source (700 lines of source, 600 lines of
comments and 200 blank lines) and 900 lines of unit tests.
We intend this project to be both a useful experimental library and a
comprehensible tool for learning about and extending RLWE.
......@@ -42,8 +38,7 @@ above, Alice can securely offload computation to another entity without worrying
that doing so will reveal any of her private information. Among many other
applications, it enables *private information retrieval* (PIR) - databases that
can serve user requests without learning which pieces of data the users
requested. (For more information on PIR, see [XPIR: Private Information
Retrieval for Everyone](https://eprint.iacr.org/2014/1025.pdf).)
requested. (For more information on PIR, see [Communication--Computation Trade-offs in PIR](https://eprint.iacr.org/2019/1483.pdf).)
## Ring Learning with Errors
......@@ -56,7 +51,7 @@ hardness.
The cryptosystem implemented in this library is from [Fully Homomorphic
Encryption from Ring-LWE and Security for Key Dependent
Messages](http://www.wisdom.weizmann.ac.il/~zvikab/localpapers/IdealHom.pdf).
The cryptosystem works as follows:
The cryptosystem works as follows.
### Preliminaries
......@@ -72,7 +67,7 @@ We also need a modulus *t* that is much smaller than *q*. *log(t)* is the number
of bits of plaintext information we are able to fit into each coefficient of a
ciphertext polynomial. The importance of *t* will become apparent soon.
Finally, we need two other components: a Gaussian distribution *Y* with mean 0
Finally, we need two other components: a binomial distribution *Y* with mean 0
and standard deviation *w*, where *w* is a parameter of the cryptosystem. The
importance of this distribution will become apparent soon.
......@@ -230,10 +225,10 @@ This library consists of four major components that form a RLWE stack.
### Montgomery Integers
This library is implemented in `montgomery.h`. At the lowest level is a library
that represents modular integers in Montgomery form, which speeds up the
repeated use of the modulo operator. This library supports 64-bit integers,
meaning that it can support a modulus of up to 30-bits. For larger modulus
This library is implemented in `montgomery.(cch|h)`. At the lowest level is a
library that represents modular integers in Montgomery form, which speeds up the
repeated use of the modulo operator. This library supports 128-bit integers,
meaning that it can support a modulus of up to 126 bits. For larger modulus
sizes, the higher levels of the stack can be parameterized with a different type
(such as a BigInteger). Montgomery integers require several parameters in
addition to the modulus to perform the modular operations efficiently. These
......@@ -306,10 +301,10 @@ cd shell-encryption
bazel build :all --cxxopt='-std=c++17'
```
You may also run all tests using the following command:
You may also run all tests (recursively) using the following command:
```bash
bazel test :all --cxxopt='-std=c++17'
bazel test ... --cxxopt='-std=c++17'
```
If you get an error, you may need to build/test with the following flags:
......
......@@ -12,63 +12,62 @@ http_archive(
],
)
load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")
rules_cc_dependencies()
# rules_proto defines abstract rules for building Protocol Buffers.
# https://github.com/bazelbuild/rules_proto
http_archive(
name = "rules_proto",
sha256 = "57001a3b33ec690a175cdf0698243431ef27233017b9bed23f96d44b9c98242f",
strip_prefix = "rules_proto-9cd4f8f1ede19d81c6d48910429fe96776e567b1",
sha256 = "602e7161d9195e50246177e7c55b2f39950a9cf7366f74ed5f22fd45750cd208",
strip_prefix = "rules_proto-97d8af4dc474595af3900dd85cb3a29ad28cc313",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/9cd4f8f1ede19d81c6d48910429fe96776e567b1.tar.gz",
"https://github.com/bazelbuild/rules_proto/archive/9cd4f8f1ede19d81c6d48910429fe96776e567b1.tar.gz",
"https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
"https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
],
)
load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")
rules_cc_dependencies()
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
rules_proto_dependencies()
rules_proto_toolchains()
# Install gtest.
http_archive(
name = "com_google_googletest",
urls = [
"https://github.com/google/googletest/archive/release-1.10.0.zip",
],
sha256 = "94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91",
strip_prefix = "googletest-release-1.10.0",
)
# abseil-cpp
git_repository(
name = "com_google_absl",
remote = "https://github.com/abseil/abseil-cpp.git",
commit = "0d5ce2797eb695aee7e019e25323251ef6ffc277",
name = "com_github_google_googletest",
commit = "703bd9caab50b139428cea1aaff9974ebee5742e", # tag = "release-1.10.0"
remote = "https://github.com/google/googletest.git",
shallow_since = "1570114335 -0400",
)
# BoringSSL
# abseil-cpp
http_archive(
name = "boringssl",
name = "com_google_absl",
sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353",
strip_prefix = "abseil-cpp-20200225",
urls = [
"https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz",
"https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz",
"https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz",
],
sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3",
strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778",
)
# BoringSSL
git_repository(
name = "boringssl",
commit = "67ffb9606462a1897d3a5edf5c06d329878ba600", # https://boringssl.googlesource.com/boringssl/+/refs/heads/master-with-bazel
remote = "https://boringssl.googlesource.com/boringssl",
shallow_since = "1585767053 +0000"
)
# Logging
http_archive(
name = "com_github_glog_glog",
build_file = "@//glog:BUILD",
urls = ["https://github.com/google/glog/archive/v0.3.5.tar.gz"],
strip_prefix = "glog-0.3.5",
name = "com_github_google_glog",
urls = ["https://github.com/google/glog/archive/96a2f23dca4cc7180821ca5f32e526314395d26a.zip"],
strip_prefix = "glog-96a2f23dca4cc7180821ca5f32e526314395d26a",
sha256 = "6281aa4eeecb9e932d7091f99872e7b26fa6aacece49c15ce5b14af2b7ec050f",
)
# gflags, needed for glog
git_repository(
http_archive(
name = "com_github_gflags_gflags",
remote = "https://github.com/gflags/gflags.git",
tag = "v2.2.2",
urls = ["https://github.com/gflags/gflags/archive/v2.2.2.tar.gz"],
sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
strip_prefix = "gflags-2.2.2",
)
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_BITS_UTIL_H_
#define RLWE_BITS_UTIL_H_
......@@ -83,6 +99,10 @@ inline unsigned int CountLeadingZeros128(absl::uint128 x) {
return CountLeadingZeros64(absl::Uint128Low64(x)) + 64;
}
inline unsigned int BitLength(absl::uint128 x) {
return 128 - CountLeadingZeros128(x);
}
} // namespace internal
} // namespace rlwe
......
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bits_util.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/numeric/int128.h"
using ::testing::Eq;
......@@ -33,4 +48,12 @@ TEST(BitsUtilTest, CountLeadingZeros64Works) {
}
}
TEST(BitsUtilTest, BitLengthWorks) {
absl::uint128 value = absl::MakeUint128(0x8000000000000000, 0);
for (int i = 0; i <= 128; i++) {
EXPECT_THAT(rlwe::internal::BitLength(value), Eq(128 - i));
value >>= 1;
}
}
} // namespace
/*
* Copyright 2020 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_CONTEXT_H_
#define RLWE_CONTEXT_H_
#include <memory>
#include "absl/memory/memory.h"
#include "error_params.h"
#include "ntt_parameters.h"
#include "status_macros.h"
#include "statusor.h"
namespace rlwe {
// Defines and holds the context of the RLWE encryption scheme.
// Thread safe..
template <typename ModularInt>
class RlweContext {
using Int = typename ModularInt::Int;
using ModulusParams = typename ModularInt::Params;
public:
// Structure to hold parameters for the RLWE encryption scheme. The parameters
// include:
// - modulus, an Int which needs to be congruent to 1 modulo 2 * (1 << log_n);
// - log_n: the logarithm of the number of coefficients of the polynomials;
// - log_t: the number of bits of the plaintext space, which will be equal to
// (1 << log_t) + 1;
// - variance, the error variance to use when sampling noises or secrets.
struct Parameters {
Int modulus;
size_t log_n;
size_t log_t;
size_t variance;
};
// Factory function to create a context from a context_params.
static rlwe::StatusOr<std::unique_ptr<const RlweContext>> Create(
Parameters context_params) {
// Create the modulus parameters.
RLWE_ASSIGN_OR_RETURN(
std::unique_ptr<const ModulusParams> modulus_parameters,
ModulusParams::Create(context_params.modulus));
// Create the NTT parameters.
RLWE_ASSIGN_OR_RETURN(NttParameters<ModularInt> ntt_params_temp,
InitializeNttParameters<ModularInt>(
context_params.log_n, modulus_parameters.get()));
std::unique_ptr<const NttParameters<ModularInt>> ntt_params =
std::make_unique<const NttParameters<ModularInt>>(
std::move(ntt_params_temp));
// Create the error parameters.
RLWE_ASSIGN_OR_RETURN(ErrorParams<ModularInt> error_params_temp,
ErrorParams<ModularInt>::Create(
context_params.log_t, context_params.variance,
modulus_parameters.get(), ntt_params.get()));
std::unique_ptr<const ErrorParams<ModularInt>> error_params =
std::make_unique<const ErrorParams<ModularInt>>(
std::move(error_params_temp));
return absl::WrapUnique<const RlweContext>(
new RlweContext(std::move(modulus_parameters), std::move(ntt_params),
std::move(error_params), std::move(context_params)));
}
// Disallow copy and copy-assign, allow move and move-assign.
RlweContext(const RlweContext&) = delete;
RlweContext& operator=(const RlweContext&) = delete;
RlweContext(RlweContext&&) = default;
RlweContext& operator=(RlweContext&&) = default;
~RlweContext() = default;
// Getters.
const ModulusParams* GetModulusParams() const {
return modulus_parameters_.get();
}
const NttParameters<ModularInt>* GetNttParams() const {
return ntt_parameters_.get();
}
const ErrorParams<ModularInt>* GetErrorParams() const {
return error_parameters_.get();
}
const Int GetModulus() const { return context_params_.modulus; }
const size_t GetLogN() const { return context_params_.log_n; }
const size_t GetN() const { return 1 << context_params_.log_n; }
const size_t GetLogT() const { return context_params_.log_t; }
const Int GetT() const {
return (static_cast<Int>(1) << context_params_.log_t) + static_cast<Int>(1);
}
const size_t GetVariance() const { return context_params_.variance; }
private:
RlweContext(std::unique_ptr<const ModulusParams> modulus_parameters,
std::unique_ptr<const NttParameters<ModularInt>> ntt_parameters,
std::unique_ptr<const ErrorParams<ModularInt>> error_parameters,
Parameters context_params)
: modulus_parameters_(std::move(modulus_parameters)),
ntt_parameters_(std::move(ntt_parameters)),
error_parameters_(std::move(error_parameters)),
context_params_(std::move(context_params)) {}
const std::unique_ptr<const ModulusParams> modulus_parameters_;
const std::unique_ptr<const NttParameters<ModularInt>> ntt_parameters_;
const std::unique_ptr<const ErrorParams<ModularInt>> error_parameters_;
const Parameters context_params_;
};
} // namespace rlwe
#endif // RLWE_CONTEXT_H_
/*
* Copyright 2020 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "context.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/numeric/int128.h"
#include "constants.h"
#include "integral_types.h"
#include "montgomery.h"
#include "status_macros.h"
#include "testing/parameters.h"
#include "testing/status_testing.h"
namespace {
template <typename ModularInt>
class ContextTest : public ::testing::Test {};
TYPED_TEST_SUITE(ContextTest, rlwe::testing::ModularIntTypes);
TYPED_TEST(ContextTest, CreateWorks) {
for (const auto& params :
rlwe::testing::ContextParameters<TypeParam>::value) {
ASSERT_OK_AND_ASSIGN(auto context,
rlwe::RlweContext<TypeParam>::Create(params));
}
}
TYPED_TEST(ContextTest, ParametersMatch) {
for (const auto& params :
rlwe::testing::ContextParameters<TypeParam>::value) {
ASSERT_OK_AND_ASSIGN(auto context,
rlwe::RlweContext<TypeParam>::Create(params));
ASSERT_EQ(context->GetLogN(), params.log_n);
ASSERT_EQ(context->GetN(), context->GetNttParams()->number_coeffs);
ASSERT_EQ(context->GetLogT(), params.log_t);
ASSERT_EQ(context->GetModulus(), params.modulus);
ASSERT_EQ(context->GetModulus(), context->GetModulusParams()->modulus);
ASSERT_EQ(context->GetVariance(), params.variance);
}
}
} // namespace
......@@ -37,7 +37,8 @@ template <typename ModularInt>
class ErrorParams {
public:
static rlwe::StatusOr<ErrorParams> Create(
const int log_t, Uint64 variance, typename ModularInt::Params* params,
const int log_t, Uint64 variance,
const typename ModularInt::Params* params,
const rlwe::NttParameters<ModularInt>* ntt_params) {
if (log_t > params->log_modulus - 1) {
return absl::InvalidArgumentError(
......@@ -74,7 +75,7 @@ class ErrorParams {
private:
// Constructor to set up the params.
ErrorParams(const int log_t, Uint64 variance,
typename ModularInt::Params* params,
const typename ModularInt::Params* params,
const rlwe::NttParameters<ModularInt>* ntt_params)
: t_(params->One()) {
t_ = (params->One() << log_t) + params->One();
......
......@@ -90,7 +90,7 @@ class GaloisKey {
// 2^{log_decomposition_modulus}. Crashes for non-valid input parameters.
static rlwe::StatusOr<GaloisKey> Deserialize(
const SerializedGaloisKey& serialized,
typename ModularInt::Params* modulus_params,
const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) {
RLWE_ASSIGN_OR_RETURN(RelinearizationKey<ModularInt> key,
RelinearizationKey<ModularInt>::Deserialize(
......
package(default_visibility = ["//visibility:public"])
licenses(["notice"])
exports_files(["LICENSE"])
cc_library(
name = "glog",
srcs = [
......
......@@ -19,7 +19,7 @@
#include <iostream>
#include <sstream>
#include "glog/logging.h"
#include <glog/logging.h>
#include "absl/numeric/int128.h"
namespace rlwe {
......
......@@ -21,7 +21,7 @@
#include <type_traits>
#include <utility>
#include "glog/logging.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "absl/container/fixed_array.h"
#include "absl/numeric/int128.h"
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_INTEGRAL_TYPES_H_
#define RLWE_INTEGRAL_TYPES_H_
......
This diff is collapsed.
......@@ -32,7 +32,7 @@ namespace internal {
template <typename ModularInt>
void FillWithEveryPower(const ModularInt& base, unsigned int n,
std::vector<ModularInt>* row,
typename ModularInt::Params* params) {
const typename ModularInt::Params* params) {
for (int i = 0; i < n; i++) {
(*row)[i].AddInPlace(base.ModExp(i, params), params);
}
......@@ -40,7 +40,7 @@ void FillWithEveryPower(const ModularInt& base, unsigned int n,
template <typename ModularInt>
rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity(
unsigned int log_n, typename ModularInt::Params* params) {
unsigned int log_n, const typename ModularInt::Params* params) {
typename ModularInt::Int n = params->One() << log_n;
typename ModularInt::Int half_n = n >> 1;
......@@ -78,7 +78,7 @@ rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity(
// Each item of the vector is in modular integer representation.
template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsis(
unsigned int log_n, typename ModularInt::Params* params) {
unsigned int log_n, const typename ModularInt::Params* params) {
// Obtain psi, a primitive 2n-th root of unity (hence log_n + 1).
RLWE_ASSIGN_OR_RETURN(
ModularInt psi,
......@@ -116,7 +116,7 @@ static void BitrevHelper(const std::vector<unsigned int>& bitrevs,
// bitreversed powers of the primitive 2n-th root of unity.
template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev(
unsigned int log_n, typename ModularInt::Params* params) {
unsigned int log_n, const typename ModularInt::Params* params) {
// Retrieve the table for the forward transformation.
RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> psis,
internal::NttPsis<ModularInt>(log_n, params));
......@@ -129,7 +129,7 @@ rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev(
// of the bitreversed powers of the primitive 2n-th root of unity plus 1.
template <typename ModularInt>
rlwe::StatusOr<std::vector<ModularInt>> NttPsisInvBitrev(
unsigned int log_n, typename ModularInt::Params* params) {
unsigned int log_n, const typename ModularInt::Params* params) {
// Retrieve the table for the forward transformation.
RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> row,
internal::NttPsis<ModularInt>(log_n, params));
......@@ -168,7 +168,7 @@ struct NttParameters {
~NttParameters() = default;
int number_coeffs;
std::unique_ptr<ModularInt> n_inv_ptr;
absl::optional<ModularInt> n_inv_ptr;
std::vector<ModularInt> psis_bitrev;
std::vector<ModularInt> psis_inv_bitrev;
std::vector<unsigned int> bitrevs;
......@@ -178,7 +178,7 @@ struct NttParameters {
// Does not take ownership of params.
template <typename ModularInt>
rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters(
int log_n, typename ModularInt::Params* params) {
int log_n, const typename ModularInt::Params* params) {
// Abort if log_n is non-positive.
if (log_n <= 0) {
return absl::InvalidArgumentError("log_n must be positive");
......@@ -203,11 +203,10 @@ rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters(
absl::StrCat("modulus is not 1 mod 2n for logn, ", log_n));
}
// 1-dimensional vector containing the inverse of n
// Compute the inverse of n.
typename ModularInt::Int n = params->One() << log_n;
RLWE_ASSIGN_OR_RETURN(auto mn, ModularInt::ImportInt(n, params));
auto minv = mn.MultiplicativeInverse(params);
output.n_inv_ptr = absl::make_unique<ModularInt>(minv);
output.n_inv_ptr = mn.MultiplicativeInverse(params);
RLWE_ASSIGN_OR_RETURN(output.psis_bitrev,
NttPsisBitrev<ModularInt>(log_n, params));
......
......@@ -85,12 +85,16 @@ class PolynomialTest : public ::testing::Test {
q_.reset(new CoefficientPolynomial(q_coeffs, params14_.get()));
// Acquire all of the NTT parameters.
ASSERT_OK_AND_ASSIGN(ntt_params_, rlwe::InitializeNttParameters<uint_m>(
log_n, params14_.get()));
ASSERT_OK_AND_ASSIGN(auto ntt_params, rlwe::InitializeNttParameters<uint_m>(
log_n, params14_.get()));
ntt_params_ =
absl::make_unique<rlwe::NttParameters<uint_m>>(std::move(ntt_params));
// Put p and q in the NTT domain.
ntt_p_ = Polynomial::ConvertToNtt(p_coeffs, ntt_params_, params14_.get());
ntt_q_ = Polynomial::ConvertToNtt(q_coeffs, ntt_params_, params14_.get());
ntt_p_ =
Polynomial::ConvertToNtt(p_coeffs, ntt_params_.get(), params14_.get());
ntt_q_ =
Polynomial::ConvertToNtt(q_coeffs, ntt_params_.get(), params14_.get());
}
std::unique_ptr<Prng> MakePrng(absl::string_view seed) {
......@@ -98,8 +102,8 @@ class PolynomialTest : public ::testing::Test {
return prng;
}
std::unique_ptr<uint_m::Params> params14_;
rlwe::NttParameters<uint_m> ntt_params_;
std::unique_ptr<const uint_m::Params> params14_;
std::unique_ptr<rlwe::NttParameters<uint_m>> ntt_params_;
std::unique_ptr<CoefficientPolynomial> p_;
std::unique_ptr<CoefficientPolynomial> q_;
Polynomial ntt_p_;
......@@ -132,8 +136,8 @@ TYPED_TEST(PolynomialTest, CoeffsCorrectlyReturnsCoefficients) {
v.push_back(elt);
}
Polynomial ntt_v =
Polynomial::ConvertToNtt(v, this->ntt_params_, this->params14_.get());
Polynomial ntt_v = Polynomial::ConvertToNtt(v, this->ntt_params_.get(),
this->params14_.get());
for (int j = 0; j < n; j++) {
EXPECT_EQ(ntt_v.Coeffs()[j].ExportInt(this->params14_.get()),
......@@ -167,7 +171,7 @@ TYPED_TEST(PolynomialTest, Symmetry) {
this->SetParams(1 << i, i);
EXPECT_TRUE(this->ntt_p_.IsValid());
CoefficientPolynomial p_prime(
this->ntt_p_.InverseNtt(this->ntt_params_, this->params14_.get()),
this->ntt_p_.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
EXPECT_EQ(*this->p_, p_prime);
}
......@@ -218,12 +222,12 @@ TYPED_TEST(PolynomialTest, BinopOfDifferentLengths) {
std::vector<uint_m> x(1 << i, this->zero_);
std::vector<uint_m> y(1 << j, this->zero_);
this->ntt_params_.bitrevs = rlwe::internal::BitrevArray(i);
Polynomial ntt_x =
Polynomial::ConvertToNtt(x, this->ntt_params_, this->params14_.get());
this->ntt_params_.bitrevs = rlwe::internal::BitrevArray(j);
Polynomial ntt_y =
Polynomial::ConvertToNtt(y, this->ntt_params_, this->params14_.get());
this->ntt_params_->bitrevs = rlwe::internal::BitrevArray(i);
Polynomial ntt_x = Polynomial::ConvertToNtt(x, this->ntt_params_.get(),
this->params14_.get());
this->ntt_params_->bitrevs = rlwe::internal::BitrevArray(j);
Polynomial ntt_y = Polynomial::ConvertToNtt(y, this->ntt_params_.get(),
this->params14_.get());
EXPECT_TRUE(ntt_x.IsValid());
EXPECT_TRUE(ntt_y.IsValid());
......@@ -293,10 +297,10 @@ TYPED_TEST(PolynomialTest, Multiply) {
this->ntt_q_.Mul(this->ntt_p_, this->params14_.get()));
CoefficientPolynomial res1(
ntt_res1.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res1.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
CoefficientPolynomial res2(
ntt_res2.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res2.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected,
......@@ -322,7 +326,7 @@ TYPED_TEST(PolynomialTest, ScalarMultiply) {
this->ntt_p_.Mul(scalar, this->params14_.get()));
CoefficientPolynomial res(
ntt_res.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
CoefficientPolynomial expected = (*this->p_) * scalar;
......@@ -337,9 +341,9 @@ TYPED_TEST(PolynomialTest, Negate) {
this->SetParams(1 << i, i);
// An NTT polynomial of all zeros.
Polynomial zeros_ntt =
Polynomial::ConvertToNtt(std::vector<uint_m>(1 << i, this->zero_),
this->ntt_params_, this->params14_.get());
Polynomial zeros_ntt = Polynomial::ConvertToNtt(
std::vector<uint_m>(1 << i, this->zero_), this->ntt_params_.get(),
this->params14_.get());
auto minus_p = this->ntt_p_.Negate(this->params14_.get());
ASSERT_OK_AND_ASSIGN(auto p0,
......@@ -363,10 +367,10 @@ TYPED_TEST(PolynomialTest, Add) {
this->ntt_q_.Add(this->ntt_p_, this->params14_.get()));
CoefficientPolynomial res1(
ntt_res1.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res1.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
CoefficientPolynomial res2(
ntt_res2.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res2.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected,
......@@ -391,10 +395,10 @@ TYPED_TEST(PolynomialTest, Sub) {
this->ntt_q_.Sub(this->ntt_p_, this->params14_.get()));
CoefficientPolynomial res1(
ntt_res1.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res1.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
CoefficientPolynomial res2(
ntt_res2.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res2.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
ASSERT_OK_AND_ASSIGN(CoefficientPolynomial expected_res1,
......@@ -413,18 +417,20 @@ TYPED_TEST(PolynomialTest, SubstitutionPowerMalformed) {
this->SetParams(1 << i, i);
EXPECT_THAT(
this->ntt_p_.Substitute(2, this->ntt_params_, this->params14_.get()),
this->ntt_p_.Substitute(2, this->ntt_params_.get(),
this->params14_.get()),
StatusIs(::absl::StatusCode::kInvalidArgument,
HasSubstr("must be a non-negative odd integer less than")));
// Even when not in debugging mode, the following two tests will yield a
// segmentation fault. We therefore only do the tests in debug mode.
EXPECT_THAT(
this->ntt_p_.Substitute(-10, this->ntt_params_, this->params14_.get()),
this->ntt_p_.Substitute(-10, this->ntt_params_.get(),
this->params14_.get()),
StatusIs(::absl::StatusCode::kInvalidArgument,
HasSubstr("must be a non-negative odd integer less than")));
EXPECT_THAT(
this->ntt_p_.Substitute(2 * (1 << i) + 1, this->ntt_params_,
this->ntt_p_.Substitute(2 * (1 << i) + 1, this->ntt_params_.get(),
this->params14_.get()),
StatusIs(::absl::StatusCode::kInvalidArgument,
HasSubstr("must be a non-negative odd integer less than")));
......@@ -438,11 +444,11 @@ TYPED_TEST(PolynomialTest, Substitution) {
int dimension = 1 << i;
for (int k = 0; k < i; k++) {
int power = (dimension >> k) + 1;
ASSERT_OK_AND_ASSIGN(auto ntt_res,
this->ntt_p_.Substitute(power, this->ntt_params_,
this->params14_.get()));
ASSERT_OK_AND_ASSIGN(
auto ntt_res, this->ntt_p_.Substitute(power, this->ntt_params_.get(),
this->params14_.get()));
CoefficientPolynomial res(
ntt_res.InverseNtt(this->ntt_params_, this->params14_.get()),
ntt_res.InverseNtt(this->ntt_params_.get(), this->params14_.get()),
this->params14_.get());
ASSERT_OK_AND_ASSIGN(auto r, this->p_->Substitute(power));
......
......@@ -14,7 +14,9 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
licenses(["notice"])
exports_files(["LICENSE"])
# PRNG interface.
......@@ -44,7 +46,7 @@ cc_library(
deps = [
":chacha_prng",
":single_thread_chacha_prng",
"@com_google_googletest//:gtest",
"@com_github_google_googletest//:gtest",
],
)
......@@ -61,8 +63,8 @@ cc_test(
":integral_prng_types",
"//testing:matchers",
"//testing:status_testing",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
......@@ -124,8 +126,8 @@ cc_test(
":chacha_prng",
"//testing:matchers",
"//testing:status_testing",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
......@@ -139,7 +141,7 @@ cc_test(
":single_thread_chacha_prng",
"//testing:matchers",
"//testing:status_testing",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "prng/chacha_prng_util.h"
#include <cstdint>
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// An implementation of a PRNG using the ChaCha20 stream cipher. Since this is
// a stream cipher, the key stream can be obtained by "encrypting" the plaintext
// 0....0.
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
#define RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
#define RLWE_PRNG_INTEGRAL_PRNG_TYPE_H_
......
......@@ -16,6 +16,7 @@
#include "relinearization_key.h"
#include "absl/numeric/int128.h"
#include "bits_util.h"
#include "montgomery.h"
#include "prng/integral_prng_types.h"
#include "status_macros.h"
......@@ -25,18 +26,13 @@
namespace rlwe {
namespace {
// Method to compute the number of digits needed to represent integers mod
// q in base T. Upcasts the modulus to absl::uint128 to handle all uint*_t
// q in base T. Upcasts the modulus to absl::uint128 to handle all Uint*
// types.
inline int ComputeDimension(Uint64 log_decomposition_modulus,
absl::uint128 modulus) {
double modulus_bits;
// Compute bit lengths of each as a double and divide.
if (absl::Uint128High64(modulus) > 0) {
modulus_bits = std::log2(absl::Uint128High64(modulus)) + 64;
} else {
modulus_bits = std::log2(absl::Uint128Low64(modulus));
}
return std::ceil(modulus_bits / log_decomposition_modulus);
Uint64 modulus_bits = static_cast<Uint64>(internal::BitLength(modulus));
return (modulus_bits + (log_decomposition_modulus - 1)) /
log_decomposition_modulus;
}
// Returns a random vector r orthogonal to (1,s). The second component is chosen
......@@ -83,30 +79,26 @@ rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> PowersOfT(
template <typename ModularInt>
rlwe::StatusOr<std::vector<std::vector<ModularInt>>> BitDecompose(
const std::vector<ModularInt>& coefficients,
typename ModularInt::Params* modulus_params,
const typename ModularInt::Params* modulus_params,
const Uint64 log_decomposition_modulus, int dimension) {
std::vector<std::vector<ModularInt>> result;
std::vector<typename ModularInt::Int> ciphertext_coeffs(coefficients.size(),
0);
std::transform(
coefficients.begin(), coefficients.end(), ciphertext_coeffs.begin(),
[modulus_params](ModularInt x) { return x.ExportInt(modulus_params); });
auto zero = ModularInt::ImportZero(modulus_params);
std::vector<ModularInt> sum_part(ciphertext_coeffs.size(), zero);
std::vector<std::vector<ModularInt>> result(dimension);
for (int i = 0; i < dimension; i++) {
result[i].reserve(ciphertext_coeffs.size());
for (int j = 0; j < ciphertext_coeffs.size(); ++j) {
RLWE_ASSIGN_OR_RETURN(
sum_part[j],
auto coefficient_part,
ModularInt::ImportInt(
(ciphertext_coeffs[j] % (1L << log_decomposition_modulus)),
modulus_params));
result[i].push_back(std::move(coefficient_part));
ciphertext_coeffs[j] = ciphertext_coeffs[j] >> log_decomposition_modulus;
}
result.push_back(sum_part);
}
return result;
......@@ -116,7 +108,7 @@ template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply(
std::vector<std::vector<ModularInt>> decomposed_coefficients,
const std::vector<std::vector<Polynomial<ModularInt>>>& matrix,
typename ModularInt::Params* modulus_params,
const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) {
Polynomial<ModularInt> temp(matrix[0][0].Len(), modulus_params);
Polynomial<ModularInt> ntt_part(matrix[0][0].Len(), modulus_params);
......@@ -125,7 +117,7 @@ rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply(
for (int i = 0; i < matrix[0].size(); i++) {
ntt_part = Polynomial<ModularInt>::ConvertToNtt(
std::move(decomposed_coefficients[i]), *ntt_params, modulus_params);
std::move(decomposed_coefficients[i]), ntt_params, modulus_params);
RLWE_ASSIGN_OR_RETURN(temp, ntt_part.Mul(matrix[0][i], modulus_params));
RLWE_RETURN_IF_ERROR(result[0].AddInPlace(temp, modulus_params));
RLWE_RETURN_IF_ERROR(ntt_part.MulInPlace(matrix[1][i], modulus_params))
......@@ -166,8 +158,8 @@ RelinearizationKey<ModularInt>::RelinearizationKeyPart::Create(
key_power.Len(), key.Variance(), prng_encryption,
key.ModulusParams()));
// Convert the error coefficients into an error polynomial.
auto e = Polynomial<ModularInt>::ConvertToNtt(error, *key.NttParams(),
key.ModulusParams());
auto e = Polynomial<ModularInt>::ConvertToNtt(
std::move(error), key.NttParams(), key.ModulusParams());
// Set the column of the Relinearization matrix.
RLWE_RETURN_IF_ERROR(
e.MulInPlace(key.PlaintextModulus(), key.ModulusParams()));
......@@ -184,11 +176,11 @@ template <typename ModularInt>
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>>
RelinearizationKey<ModularInt>::RelinearizationKeyPart::ApplyPartTo(
const Polynomial<ModularInt>& ciphertext_part,
typename ModularInt::Params* modulus_params,
const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) const {
// Convert ciphertext out of NTT form.
std::vector<ModularInt> ciphertext_coefficients =
ciphertext_part.InverseNtt(*ntt_params, modulus_params);
ciphertext_part.InverseNtt(ntt_params, modulus_params);
// Bit-decompose the vector of coefficients in the ciphertext.
RLWE_ASSIGN_OR_RETURN(
......@@ -206,14 +198,16 @@ rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
RelinearizationKey<ModularInt>::RelinearizationKeyPart::Deserialize(
const std::vector<SerializedNttPolynomial>& polynomials,
Uint64 log_decomposition_modulus, SecurePrng* prng,
ModularIntParams* modulus_params,
const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params) {
// The polynomials input is a flattened representation of a 2 x dimension
// matrix where the first half corresponds to the first row of matrix and the
// second half corresponds to the second row of matrix. This matrix makes up
// the RelinearizationKeyPart.
auto matrix = std::vector<std::vector<Polynomial<ModularInt>>>(2);
int dimension = polynomials.size();
auto matrix = std::vector<std::vector<Polynomial<ModularInt>>>(2);
matrix[0].reserve(dimension);
matrix[1].reserve(dimension);
for (int i = 0; i < dimension; i++) {
RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize(
......@@ -282,6 +276,7 @@ RelinearizationKey<ModularInt>::Create(const SymmetricRlweKey<ModularInt>& key,
auto dimension =
ComputeDimension(log_decomposition_modulus, key.ModulusParams()->modulus);
std::vector<RelinearizationKeyPart> relinearization_key;
relinearization_key.reserve(num_parts);
// Create RealinearizationKeyPart for each of the secret key parts: s, ...,
// s^k.
for (int i = 1; i < num_parts; i++) {
......@@ -362,7 +357,7 @@ template <typename ModularInt>
rlwe::StatusOr<RelinearizationKey<ModularInt>>
RelinearizationKey<ModularInt>::Deserialize(
const SerializedRelinearizationKey& serialized,
typename ModularInt::Params* modulus_params,
const typename ModularInt::Params* modulus_params,
const NttParameters<ModularInt>* ntt_params) {
// Verifies that the number of polynomials in serialized is expected.
// A RelinearizationKey can decrypt ciphertexts with num_parts number of
......@@ -419,6 +414,7 @@ RelinearizationKey<ModularInt>::Deserialize(
// Takes each polynomials_per_matrix chunk of serialized.c()'s and places them
// into a KeyPart.
output.relinearization_key_.reserve(serialized.num_parts() - 1);
for (int i = 0; i < (serialized.num_parts() - 1); i++) {
auto start = serialized.c().begin() + i * polynomials_per_matrix;
auto end = start + polynomials_per_matrix;
......
......@@ -96,7 +96,7 @@ class RelinearizationKey {
// input parameters.
static rlwe::StatusOr<RelinearizationKey> Deserialize(
const SerializedRelinearizationKey& serialized,
ModularIntParams* modulus_params,
const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params);
// Substitution Power accessor.
......@@ -120,14 +120,14 @@ class RelinearizationKey {
// of polynomials.
rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> ApplyPartTo(
const Polynomial<ModularInt>& ciphertext_part,
ModularIntParams* modulus_params,
const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params) const;
// Creates a RelinearizationKeyPart out of a vector of Polynomials.
static rlwe::StatusOr<RelinearizationKeyPart> Deserialize(
const std::vector<SerializedNttPolynomial>& polynomials,
Uint64 log_decomposition_modulus, SecurePrng* prng,
ModularIntParams* modulus_params,
const ModularIntParams* modulus_params,
const NttParameters<ModularInt>* ntt_params);
std::vector<Polynomial<ModularInt>> Matrix() const { return matrix_[0]; }
......@@ -147,7 +147,7 @@ class RelinearizationKey {
// Creates an empty RelinearizationKey.
RelinearizationKey(Uint64 log_decomposition_modulus,
const ModularInt& decomposition_modulus,
ModularIntParams* params,
const ModularIntParams* params,
const NttParameters<ModularInt>* ntt_params)
: log_decomposition_modulus_(log_decomposition_modulus),
decomposition_modulus_(decomposition_modulus),
......@@ -176,7 +176,7 @@ class RelinearizationKey {
int substitution_power_;
// Modulus parameters. Does not take ownership.
ModularIntParams* modulus_params_;
const ModularIntParams* modulus_params_;
// NTT parameters. Does not take ownership.
const NttParameters<ModularInt>* ntt_params_;
......
......@@ -43,7 +43,7 @@ namespace rlwe {
template <typename ModularInt>
static rlwe::StatusOr<std::vector<ModularInt>> SampleFromErrorDistribution(
unsigned int num_coeffs, Uint64 variance, SecurePrng* prng,
typename ModularInt::Params* modulus_params) {
const typename ModularInt::Params* modulus_params) {
if (variance > kMaxVariance) {
return absl::InvalidArgumentError(absl::StrCat(
"The variance, ", variance, ", must be at most ", kMaxVariance, "."));
......@@ -54,48 +54,46 @@ static rlwe::StatusOr<std::vector<ModularInt>> SampleFromErrorDistribution(
// Sample from the centered binomial distribution. To do so, we sample k pairs
// of bits (a, b), where k = 2 * variance. The sample of the binomial
// distribution is the sum of the differences between each pair of bits.
// This is implemented by splitting k in words k', drawing 2*k' bits, and
// computing the difference of Hamming weight between the first k' bits and
// the last k' bits, where
Uint64 k;
Uint64 coefficient_positive, coefficient_negative;
typename ModularInt::Int coefficient;
for (int i = 0; i < num_coeffs; i++) {
coefficient_positive = 0;
coefficient_negative = 0;
coefficient = modulus_params->modulus;
k = variance << 1;
while (k > 0) {
if (k >= 64) {
// Use all the bits
// Use 64 bits of randomness
RLWE_ASSIGN_OR_RETURN(auto r64, prng->Rand64());
coefficient_positive += rlwe::internal::CountOnes64(r64);
coefficient += rlwe::internal::CountOnes64(r64);
RLWE_ASSIGN_OR_RETURN(r64, prng->Rand64());
coefficient_negative += rlwe::internal::CountOnes64(r64);
coefficient -= rlwe::internal::CountOnes64(r64);
k -= 64;
} else if (k > 8) {
Uint64 mask = (1ULL << k) - 1;
RLWE_ASSIGN_OR_RETURN(auto r64, prng->Rand64());
coefficient_positive += rlwe::internal::CountOnes64(r64 & mask);
RLWE_ASSIGN_OR_RETURN(r64, prng->Rand64());
coefficient_negative += rlwe::internal::CountOnes64(r64 & mask);
k = 0;
} else if (k >= 8) {
// Use 8 bits of randomness
RLWE_ASSIGN_OR_RETURN(auto r8, prng->Rand8());
coefficient += rlwe::internal::CountOnesInByte(r8);
RLWE_ASSIGN_OR_RETURN(r8, prng->Rand8());
coefficient -= rlwe::internal::CountOnesInByte(r8);
k -= 8;
} else {
Uint8 mask = (1 << k) - 1;
RLWE_ASSIGN_OR_RETURN(auto r8, prng->Rand8());
coefficient_positive += rlwe::internal::CountOnesInByte(r8 & mask);
coefficient += rlwe::internal::CountOnesInByte(r8 & mask);
RLWE_ASSIGN_OR_RETURN(r8, prng->Rand8());
coefficient_negative += rlwe::internal::CountOnesInByte(r8 & mask);
k = 0;
coefficient -= rlwe::internal::CountOnesInByte(r8 & mask);
break; // all k remaining pairs have been sampled.
}
}
RLWE_ASSIGN_OR_RETURN(
coeffs[i], ModularInt::ImportInt(coefficient_positive, modulus_params));
RLWE_ASSIGN_OR_RETURN(
auto v, ModularInt::ImportInt(coefficient_negative, modulus_params));
coeffs[i].SubInPlace(v, modulus_params);
// coefficient is in the interval [modulus - 2k, modulus + 2k]. We reduce
// it in [0, modulus). Since ModularInt::Int is unsigned, we create a mask
// equal to 0xFF...FF when coefficient >= modulus, and equal to 0 otherwise.
typename ModularInt::Int mask = -(coefficient >= modulus_params->modulus);
coefficient -= mask & modulus_params->modulus;
RLWE_ASSIGN_OR_RETURN(coeffs[i],
ModularInt::ImportInt(coefficient, modulus_params));
}
return coeffs;
......
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef RLWE_STATUS_MACROS_H_
#define RLWE_STATUS_MACROS_H_
......
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "status_macros.h"
#include <sstream>
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment