Commit 34ac4b2f authored by Ryan Sleevi's avatar Ryan Sleevi Committed by Commit Bot

Switch //net/ntlm to use base::span<const uint8_t>

//net/ntlm makes extensive use of uint8_t* pointers with various
fixed size requirements. Convert these into defined-extent
base::span<>'s so that the compiler can do more work for us to
ensure that all the buffers are appropriately sized.

This also switches uses of std::basic_string<uint8_t> and
base::StringPiece into std::vector<uint8_t> / base::span<uint8_t>,
as these are all working on 'data' buffers.

Bug: 837308
Change-Id: Iae3f1933b2c948d77c841e6f9bc8ce04bfddc67d
Reviewed-on: https://chromium-review.googlesource.com/1056013
Commit-Queue: Ryan Sleevi <rsleevi@chromium.org>
Reviewed-by: default avatarAsanka Herath <asanka@chromium.org>
Cr-Commit-Position: refs/heads/master@{#558579}
parent 3e49d5cb
......@@ -921,6 +921,7 @@ component("net") {
"ntlm/ntlm_buffer_writer.h",
"ntlm/ntlm_client.cc",
"ntlm/ntlm_client.h",
"ntlm/ntlm_constants.cc",
"ntlm/ntlm_constants.h",
"proxy_resolution/dhcp_pac_file_adapter_fetcher_win.cc",
"proxy_resolution/dhcp_pac_file_adapter_fetcher_win.h",
......
......@@ -80,9 +80,8 @@ int HttpAuthHandlerNTLM::GenerateAuthTokenImpl(
}
}
ntlm::Buffer next_token = GetNextToken(
ntlm::Buffer(reinterpret_cast<const uint8_t*>(decoded_auth_data.data()),
decoded_auth_data.size()));
std::vector<uint8_t> next_token =
GetNextToken(base::as_bytes(base::make_span(decoded_auth_data)));
if (next_token.empty())
return ERR_UNEXPECTED;
......
......@@ -28,7 +28,9 @@
#endif
#include <string>
#include <vector>
#include "base/containers/span.h"
#include "base/strings/string16.h"
#include "net/base/net_export.h"
#include "net/http/http_auth_handler.h"
......@@ -149,7 +151,7 @@ class NET_EXPORT_PRIVATE HttpAuthHandlerNTLM : public HttpAuthHandler {
// Given an input token received from the server, generate the next output
// token to be sent to the server.
ntlm::Buffer GetNextToken(const ntlm::Buffer& in_token);
std::vector<uint8_t> GetNextToken(base::span<const uint8_t> in_token);
#endif
// Parse the challenge, saving the results into this instance.
......
......@@ -89,7 +89,8 @@ HttpAuthHandlerNTLM::Factory::Factory() = default;
HttpAuthHandlerNTLM::Factory::~Factory() = default;
ntlm::Buffer HttpAuthHandlerNTLM::GetNextToken(const ntlm::Buffer& in_token) {
std::vector<uint8_t> HttpAuthHandlerNTLM::GetNextToken(
base::span<const uint8_t> in_token) {
// If in_token is non-empty, then assume it contains a challenge message,
// and generate the Authenticate message in reply. Otherwise return the
// Negotiate message.
......@@ -99,7 +100,7 @@ ntlm::Buffer HttpAuthHandlerNTLM::GetNextToken(const ntlm::Buffer& in_token) {
std::string hostname = get_host_name_proc_();
if (hostname.empty())
return ntlm::Buffer();
return {};
uint8_t client_challenge[8];
generate_random_proc_(client_challenge, 8);
uint64_t client_time = get_ms_time_proc_();
......
......@@ -5,8 +5,10 @@
#include <string>
#include "base/base64.h"
#include "base/containers/span.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
#include "build/build_config.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/mock_host_resolver.h"
#include "net/http/http_auth_challenge_tokenizer.h"
......@@ -52,19 +54,16 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
&auth_handler_);
}
std::string CreateNtlmAuthHeader(ntlm::Buffer message) {
std::string CreateNtlmAuthHeader(base::span<const uint8_t> buffer) {
std::string output;
base::Base64Encode(
base::StringPiece(reinterpret_cast<const char*>(message.data()),
message.size()),
base::StringPiece(reinterpret_cast<const char*>(buffer.data()),
buffer.size()),
&output);
return "NTLM " + output;
}
std::string CreateNtlmAuthHeader(const uint8_t* buffer, size_t length) {
return CreateNtlmAuthHeader(ntlm::Buffer(buffer, length));
}
HttpAuth::AuthorizationResult HandleAnotherChallenge(
const std::string& challenge) {
......@@ -85,10 +84,10 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
}
bool ReadBytesPayload(ntlm::NtlmBufferReader* reader,
uint8_t* buffer,
size_t len) {
base::span<uint8_t> buffer) {
ntlm::SecurityBuffer sec_buf;
return reader->ReadSecurityBuffer(&sec_buf) && (sec_buf.length == len) &&
return reader->ReadSecurityBuffer(&sec_buf) &&
(sec_buf.length == buffer.size()) &&
reader->ReadBytesFrom(sec_buf, buffer);
}
......@@ -99,11 +98,13 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
if (!reader->ReadSecurityBuffer(&sec_buf))
return false;
std::unique_ptr<uint8_t[]> raw(new uint8_t[sec_buf.length]);
if (!reader->ReadBytesFrom(sec_buf, raw.get()))
if (!reader->ReadBytesFrom(
sec_buf,
base::as_writable_bytes(base::make_span(
base::WriteInto(str, sec_buf.length + 1), sec_buf.length)))) {
return false;
}
str->assign(reinterpret_cast<const char*>(raw.get()), sec_buf.length);
return true;
}
......@@ -116,17 +117,17 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
EXPECT_TRUE(reader->ReadSecurityBuffer(&sec_buf));
EXPECT_EQ(0, sec_buf.length % 2);
std::unique_ptr<uint8_t[]> raw(new uint8_t[sec_buf.length]);
EXPECT_TRUE(reader->ReadBytesFrom(sec_buf, raw.get()));
std::vector<uint8_t> raw(sec_buf.length);
EXPECT_TRUE(reader->ReadBytesFrom(sec_buf, raw));
#ifdef IS_BIG_ENDIAN
for (size_t i = 0; i < sec_buf.length; i += 2) {
#if defined(ARCH_CPU_BIG_ENDIAN)
for (size_t i = 0; i < raw.size(); i += 2) {
std::swap(raw[i], raw[i + 1]);
}
#endif
str->assign(reinterpret_cast<const base::char16*>(raw.get()),
sec_buf.length / 2);
str->assign(reinterpret_cast<const base::char16*>(raw.data()),
raw.size() / 2);
}
int GetGenerateAuthTokenResult() {
......@@ -223,8 +224,7 @@ TEST_F(HttpAuthHandlerNtlmPortableTest, NtlmV1AuthenticationSuccess) {
std::string token;
ASSERT_EQ(HttpAuth::AUTHORIZATION_RESULT_ACCEPT,
HandleAnotherChallenge(
CreateNtlmAuthHeader(ntlm::test::kChallengeMsgV1,
arraysize(ntlm::test::kChallengeMsgV1))));
CreateNtlmAuthHeader(ntlm::test::kChallengeMsgV1)));
ASSERT_EQ(OK, GenerateAuthToken(&token));
// Validate the authenticate message
......
This diff is collapsed.
This diff is collapsed.
......@@ -11,19 +11,10 @@
namespace net {
namespace ntlm {
NtlmBufferReader::NtlmBufferReader() : NtlmBufferReader(nullptr, 0) {}
NtlmBufferReader::NtlmBufferReader() {}
NtlmBufferReader::NtlmBufferReader(const Buffer& buffer)
: NtlmBufferReader(
base::StringPiece(reinterpret_cast<const char*>(buffer.data()),
buffer.length())) {}
NtlmBufferReader::NtlmBufferReader(base::StringPiece str)
: buffer_(str), cursor_(0) {}
NtlmBufferReader::NtlmBufferReader(const uint8_t* ptr, size_t len)
: NtlmBufferReader(
base::StringPiece(reinterpret_cast<const char*>(ptr), len)) {}
NtlmBufferReader::NtlmBufferReader(base::span<const uint8_t> buffer)
: buffer_(buffer) {}
NtlmBufferReader::~NtlmBufferReader() = default;
......@@ -59,25 +50,22 @@ bool NtlmBufferReader::ReadFlags(NegotiateFlags* flags) {
return true;
}
bool NtlmBufferReader::ReadBytes(uint8_t* buffer, size_t len) {
if (!CanRead(len))
bool NtlmBufferReader::ReadBytes(base::span<uint8_t> buffer) {
if (!CanRead(buffer.size()))
return false;
memcpy(reinterpret_cast<void*>(buffer),
reinterpret_cast<const void*>(GetBufferAtCursor()), len);
memcpy(buffer.data(), GetBufferAtCursor(), buffer.size());
AdvanceCursor(len);
AdvanceCursor(buffer.size());
return true;
}
bool NtlmBufferReader::ReadBytesFrom(const SecurityBuffer& sec_buf,
uint8_t* buffer) {
if (!CanReadFrom(sec_buf))
base::span<uint8_t> buffer) {
if (!CanReadFrom(sec_buf) || buffer.size() < sec_buf.length)
return false;
memcpy(reinterpret_cast<void*>(buffer),
reinterpret_cast<const void*>(GetBufferPtr() + sec_buf.offset),
sec_buf.length);
memcpy(buffer.data(), GetBufferPtr() + sec_buf.offset, sec_buf.length);
return true;
}
......@@ -87,7 +75,8 @@ bool NtlmBufferReader::ReadPayloadAsBufferReader(const SecurityBuffer& sec_buf,
if (!CanReadFrom(sec_buf))
return false;
*reader = NtlmBufferReader(GetBufferPtr() + sec_buf.offset, sec_buf.length);
*reader = NtlmBufferReader(
base::make_span(GetBufferPtr() + sec_buf.offset, sec_buf.length));
return true;
}
......@@ -139,7 +128,7 @@ bool NtlmBufferReader::ReadTargetInfo(size_t target_info_len,
return false;
// Take a copy of the payload in the AVPair.
pair.buffer.assign(GetBufferAtCursor(), pair.avlen);
pair.buffer.assign(GetBufferAtCursor(), GetBufferAtCursor() + pair.avlen);
if (pair.avid == TargetInfoAvId::kEol) {
// Terminator must have zero length.
if (pair.avlen != 0)
......
......@@ -11,7 +11,7 @@
#include <string>
#include <vector>
#include "base/strings/string_piece.h"
#include "base/containers/span.h"
#include "net/base/net_export.h"
#include "net/ntlm/ntlm_constants.h"
......@@ -49,15 +49,11 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
public:
NtlmBufferReader();
// |buffer| is not copied and must outlive the |NtlmBufferReader|.
explicit NtlmBufferReader(const Buffer& buffer);
explicit NtlmBufferReader(base::StringPiece buffer);
explicit NtlmBufferReader(base::span<const uint8_t> buffer);
// This class does not take ownership of |ptr|, so the caller must ensure
// that the buffer outlives the |NtlmBufferReader|.
NtlmBufferReader(const uint8_t* ptr, size_t len);
~NtlmBufferReader();
size_t GetLength() const { return buffer_.length(); }
size_t GetLength() const { return buffer_.size(); }
size_t GetCursor() const { return cursor_; }
bool IsEndOfBuffer() const { return cursor_ >= GetLength(); }
......@@ -92,14 +88,14 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
bool ReadFlags(NegotiateFlags* flags) WARN_UNUSED_RESULT;
// Reads |len| bytes and copies them into |buffer|.
bool ReadBytes(uint8_t* buffer, size_t len) WARN_UNUSED_RESULT;
bool ReadBytes(base::span<uint8_t> buffer) WARN_UNUSED_RESULT;
// Reads |sec_buf.length| bytes from offset |sec_buf.offset| and copies them
// into |buffer|. If the security buffer specifies a payload outside the
// buffer, then the call fails. Unlike the other Read* methods, this does
// not move the cursor.
bool ReadBytesFrom(const SecurityBuffer& sec_buf,
uint8_t* buffer) WARN_UNUSED_RESULT;
base::span<uint8_t> buffer) WARN_UNUSED_RESULT;
// Reads |sec_buf.length| bytes from offset |sec_buf.offset| and assigns
// |reader| an |NtlmBufferReader| representing the payload. If the security
......@@ -207,9 +203,7 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
void AdvanceCursor(size_t count) { SetCursor(GetCursor() + count); }
// Returns a constant pointer to the start of the buffer.
const uint8_t* GetBufferPtr() const {
return reinterpret_cast<const uint8_t*>(buffer_.data());
}
const uint8_t* GetBufferPtr() const { return buffer_.data(); }
// Returns a pointer to the underlying buffer at the current cursor
// position.
......@@ -221,8 +215,8 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
return *(GetBufferAtCursor());
}
base::StringPiece buffer_;
size_t cursor_;
base::span<const uint8_t> buffer_;
size_t cursor_ = 0;
};
} // namespace ntlm
......
This diff is collapsed.
......@@ -48,21 +48,15 @@ bool NtlmBufferWriter::WriteFlags(NegotiateFlags flags) {
return WriteUInt32(static_cast<uint32_t>(flags));
}
bool NtlmBufferWriter::WriteBytes(const uint8_t* buffer, size_t len) {
if (!CanWrite(len))
bool NtlmBufferWriter::WriteBytes(base::span<const uint8_t> bytes) {
if (!CanWrite(bytes.size()))
return false;
memcpy(reinterpret_cast<void*>(GetBufferPtrAtCursor()),
reinterpret_cast<const void*>(buffer), len);
AdvanceCursor(len);
memcpy(GetBufferPtrAtCursor(), bytes.data(), bytes.size());
AdvanceCursor(bytes.size());
return true;
}
bool NtlmBufferWriter::WriteBytes(const Buffer& bytes) {
return WriteBytes(bytes.data(), bytes.length());
}
bool NtlmBufferWriter::WriteZeros(size_t count) {
if (!CanWrite(count))
return false;
......@@ -105,8 +99,7 @@ bool NtlmBufferWriter::WriteAvPair(const AvPair& pair) {
}
bool NtlmBufferWriter::WriteUtf8String(const std::string& str) {
return WriteBytes(reinterpret_cast<const uint8_t*>(str.c_str()),
str.length());
return WriteBytes(base::as_bytes(base::make_span(str)));
}
bool NtlmBufferWriter::WriteUtf16AsUtf8String(const base::string16& str) {
......@@ -120,7 +113,7 @@ bool NtlmBufferWriter::WriteUtf8AsUtf16String(const std::string& str) {
}
bool NtlmBufferWriter::WriteUtf16String(const base::string16& str) {
size_t num_bytes = str.length() * 2;
size_t num_bytes = str.size() * 2;
if (!CanWrite(num_bytes))
return false;
......@@ -142,7 +135,7 @@ bool NtlmBufferWriter::WriteUtf16String(const base::string16& str) {
}
bool NtlmBufferWriter::WriteSignature() {
return WriteBytes(kSignature, kSignatureLen);
return WriteBytes(kSignature);
}
bool NtlmBufferWriter::WriteMessageType(MessageType message_type) {
......
......@@ -11,6 +11,7 @@
#include <memory>
#include <string>
#include "base/containers/span.h"
#include "base/strings/string16.h"
#include "base/strings/string_piece.h"
#include "net/base/net_export.h"
......@@ -49,8 +50,8 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
size_t GetLength() const { return buffer_.size(); }
size_t GetCursor() const { return cursor_; }
bool IsEndOfBuffer() const { return cursor_ >= GetLength(); }
const Buffer& GetBuffer() const { return buffer_; }
Buffer Pass() const { return std::move(buffer_); }
base::span<const uint8_t> GetBuffer() const { return buffer_; }
std::vector<uint8_t> Pass() const { return std::move(buffer_); }
// Returns true if there are |len| more bytes between the current cursor
// position and the end of the buffer.
......@@ -71,13 +72,9 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
// Writes flags as a 32 bit unsigned value (little endian).
bool WriteFlags(NegotiateFlags flags) WARN_UNUSED_RESULT;
// Writes |len| bytes from |buffer|. If there are not |len| more bytes in
// the buffer, it returns false.
bool WriteBytes(const uint8_t* buffer, size_t len) WARN_UNUSED_RESULT;
// Writes the bytes from the |Buffer|. If there are not enough
// Writes the bytes from the |buffer|. If there are not enough
// bytes in the buffer, it returns false.
bool WriteBytes(const Buffer& buffer) WARN_UNUSED_RESULT;
bool WriteBytes(base::span<const uint8_t> buffer) WARN_UNUSED_RESULT;
// Writes |count| bytes of zeros to the buffer. If there are not |count|
// more bytes in available in the buffer, it returns false.
......@@ -179,8 +176,8 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
void AdvanceCursor(size_t count) { SetCursor(GetCursor() + count); }
// Returns a pointer to the start of the buffer.
const uint8_t* GetBufferPtr() const { return &buffer_[0]; }
uint8_t* GetBufferPtr() { return &buffer_[0]; }
const uint8_t* GetBufferPtr() const { return buffer_.data(); }
uint8_t* GetBufferPtr() { return buffer_.data(); }
// Returns pointer into the buffer at the current cursor location.
const uint8_t* GetBufferPtrAtCursor() const {
......@@ -188,7 +185,7 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
}
uint8_t* GetBufferPtrAtCursor() { return GetBufferPtr() + GetCursor(); }
Buffer buffer_;
std::vector<uint8_t> buffer_;
size_t cursor_;
DISALLOW_COPY_AND_ASSIGN(NtlmBufferWriter);
......
......@@ -4,7 +4,7 @@
#include "net/ntlm/ntlm_buffer_writer.h"
#include "base/macros.h"
#include "base/stl_util.h"
#include "base/strings/utf_string_conversions.h"
#include "testing/gtest/include/gtest/gtest.h"
......@@ -30,7 +30,7 @@ TEST(NtlmBufferWriterTest, Initialization) {
NtlmBufferWriter writer(1);
ASSERT_EQ(1u, writer.GetLength());
ASSERT_EQ(1u, writer.GetBuffer().length());
ASSERT_EQ(1u, writer.GetBuffer().size());
ASSERT_EQ(0u, writer.GetCursor());
ASSERT_FALSE(writer.IsEndOfBuffer());
ASSERT_TRUE(writer.CanWrite(1));
......@@ -45,11 +45,11 @@ TEST(NtlmBufferWriterTest, Write16) {
ASSERT_TRUE(writer.WriteUInt16(value));
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_EQ(arraysize(expected), writer.GetLength());
ASSERT_EQ(base::size(expected), writer.GetLength());
ASSERT_FALSE(writer.WriteUInt16(value));
ASSERT_EQ(0,
memcmp(expected, writer.GetBuffer().data(), arraysize(expected)));
memcmp(expected, writer.GetBuffer().data(), base::size(expected)));
}
TEST(NtlmBufferWriterTest, Write16PastEob) {
......@@ -69,7 +69,7 @@ TEST(NtlmBufferWriterTest, Write32) {
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUInt32(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, Write32PastEob) {
......@@ -89,7 +89,7 @@ TEST(NtlmBufferWriterTest, Write64) {
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUInt64(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, Write64PastEob) {
......@@ -102,22 +102,22 @@ TEST(NtlmBufferWriterTest, Write64PastEob) {
TEST(NtlmBufferWriterTest, WriteBytes) {
uint8_t expected[8] = {0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11};
NtlmBufferWriter writer(arraysize(expected));
NtlmBufferWriter writer(base::size(expected));
ASSERT_TRUE(writer.WriteBytes(expected, arraysize(expected)));
ASSERT_EQ(0, memcmp(GetBufferPtr(writer), expected, arraysize(expected)));
ASSERT_TRUE(writer.WriteBytes(expected));
ASSERT_EQ(0, memcmp(GetBufferPtr(writer), expected, base::size(expected)));
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteBytes(expected, 1));
ASSERT_FALSE(writer.WriteBytes(base::make_span(expected, 1)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, WriteBytesPastEob) {
uint8_t buffer[8];
NtlmBufferWriter writer(arraysize(buffer) - 1);
NtlmBufferWriter writer(base::size(buffer) - 1);
ASSERT_FALSE(writer.WriteBytes(buffer, arraysize(buffer)));
ASSERT_FALSE(writer.WriteBytes(buffer));
}
TEST(NtlmBufferWriterTest, WriteSecurityBuffer) {
......@@ -131,7 +131,7 @@ TEST(NtlmBufferWriterTest, WriteSecurityBuffer) {
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteSecurityBuffer(SecurityBuffer(offset, length)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, WriteSecurityBufferPastEob) {
......@@ -151,7 +151,7 @@ TEST(NtlmBufferWriterTest, WriteNarrowString) {
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUtf8String(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, WriteAsciiStringPastEob) {
......@@ -172,7 +172,7 @@ TEST(NtlmBufferWriterTest, WriteUtf16String) {
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUtf16String(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, WriteUtf16StringPastEob) {
......@@ -193,7 +193,7 @@ TEST(NtlmBufferWriterTest, WriteUtf8AsUtf16String) {
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUtf8AsUtf16String(input));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, WriteSignature) {
......@@ -203,7 +203,7 @@ TEST(NtlmBufferWriterTest, WriteSignature) {
ASSERT_TRUE(writer.WriteSignature());
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, WriteSignaturePastEob) {
......@@ -232,12 +232,12 @@ TEST(NtlmBufferWriterTest, WriteMessageTypePastEob) {
TEST(NtlmBufferWriterTest, WriteAvPairHeader) {
const uint8_t expected[4] = {0x06, 0x00, 0x11, 0x22};
NtlmBufferWriter writer(4);
NtlmBufferWriter writer(base::size(expected));
ASSERT_TRUE(writer.WriteAvPairHeader(TargetInfoAvId::kFlags, 0x2211));
ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
}
TEST(NtlmBufferWriterTest, WriteAvPairHeaderPastEob) {
......
......@@ -19,31 +19,31 @@ namespace ntlm {
namespace {
// Parses the challenge message and returns the |challenge_flags| and
// |server_challenge| into the supplied buffer.
// |server_challenge| must contain at least 8 bytes.
bool ParseChallengeMessage(const Buffer& challenge_message,
bool ParseChallengeMessage(
base::span<const uint8_t> challenge_message,
NegotiateFlags* challenge_flags,
uint8_t* server_challenge) {
base::span<uint8_t, kChallengeLen> server_challenge) {
NtlmBufferReader challenge_reader(challenge_message);
return challenge_reader.MatchMessageHeader(MessageType::kChallenge) &&
challenge_reader.SkipSecurityBufferWithValidation() &&
challenge_reader.ReadFlags(challenge_flags) &&
challenge_reader.ReadBytes(server_challenge, kChallengeLen);
challenge_reader.ReadBytes(server_challenge);
}
// Parses the challenge message and extracts the information necessary to
// make an NTLMv2 response.
// |server_challenge| must contain at least 8 bytes.
bool ParseChallengeMessageV2(const Buffer& challenge_message,
bool ParseChallengeMessageV2(
base::span<const uint8_t> challenge_message,
NegotiateFlags* challenge_flags,
uint8_t* server_challenge,
base::span<uint8_t, kChallengeLen> server_challenge,
std::vector<AvPair>* av_pairs) {
NtlmBufferReader challenge_reader(challenge_message);
return challenge_reader.MatchMessageHeader(MessageType::kChallenge) &&
challenge_reader.SkipSecurityBufferWithValidation() &&
challenge_reader.ReadFlags(challenge_flags) &&
challenge_reader.ReadBytes(server_challenge, kChallengeLen) &&
challenge_reader.ReadBytes(server_challenge) &&
challenge_reader.SkipBytes(8) &&
// challenge_reader.ReadTargetInfoPayload(av_pairs);
(((*challenge_flags & NegotiateFlags::kTargetInfo) ==
......@@ -71,27 +71,24 @@ bool WriteAuthenticateMessage(NtlmBufferWriter* authenticate_writer,
}
// Writes the NTLMv1 LM Response and NTLM Response.
// |lm_response| must contain |kResponseLenV1| bytes.
// |ntlm_response| must contain |kResponseLenV1| bytes.
bool WriteResponsePayloads(NtlmBufferWriter* authenticate_writer,
const uint8_t* lm_response,
const uint8_t* ntlm_response) {
return authenticate_writer->WriteBytes(lm_response, kResponseLenV1) &&
authenticate_writer->WriteBytes(ntlm_response, kResponseLenV1);
bool WriteResponsePayloads(
NtlmBufferWriter* authenticate_writer,
base::span<const uint8_t, kResponseLenV1> lm_response,
base::span<const uint8_t, kResponseLenV1> ntlm_response) {
return authenticate_writer->WriteBytes(lm_response) &&
authenticate_writer->WriteBytes(ntlm_response);
}
// Writes the |lm_response| and writes the NTLMv2 response by concatenating
// |v2_proof|, |v2_proof_input|, |updated_target_info| and 4 zero bytes.
//
// |lm_response| must contain |kResponseLenV1| bytes.
// |v2_proof| must contain |kNtlmProofLenV2| bytes.
bool WriteResponsePayloadsV2(NtlmBufferWriter* authenticate_writer,
const uint8_t* lm_response,
const uint8_t* v2_proof,
const Buffer& v2_proof_input,
const Buffer& updated_target_info) {
return authenticate_writer->WriteBytes(lm_response, kResponseLenV1) &&
authenticate_writer->WriteBytes(v2_proof, kNtlmProofLenV2) &&
bool WriteResponsePayloadsV2(
NtlmBufferWriter* authenticate_writer,
base::span<const uint8_t, kResponseLenV1> lm_response,
base::span<const uint8_t, kNtlmProofLenV2> v2_proof,
base::span<const uint8_t> v2_proof_input,
base::span<const uint8_t> updated_target_info) {
return authenticate_writer->WriteBytes(lm_response) &&
authenticate_writer->WriteBytes(v2_proof) &&
authenticate_writer->WriteBytes(v2_proof_input) &&
authenticate_writer->WriteBytes(updated_target_info) &&
authenticate_writer->WriteUInt32(0);
......@@ -146,7 +143,7 @@ NtlmClient::NtlmClient(NtlmFeatures features)
NtlmClient::~NtlmClient() = default;
Buffer NtlmClient::GetNegotiateMessage() const {
std::vector<uint8_t> NtlmClient::GetNegotiateMessage() const {
return negotiate_message_;
}
......@@ -164,7 +161,7 @@ void NtlmClient::GenerateNegotiateMessage() {
negotiate_message_ = writer.Pass();
}
Buffer NtlmClient::GenerateAuthenticateMessage(
std::vector<uint8_t> NtlmClient::GenerateAuthenticateMessage(
const base::string16& domain,
const base::string16& username,
const base::string16& password,
......@@ -172,8 +169,8 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
const std::string& channel_bindings,
const std::string& spn,
uint64_t client_time,
const uint8_t* client_challenge,
const Buffer& server_challenge_message) const {
base::span<const uint8_t, kChallengeLen> client_challenge,
base::span<const uint8_t> server_challenge_message) const {
// Limit the size of strings that are accepted. As an absolute limit any
// field represented by a |SecurityBuffer| or |AvPair| must be less than
// UINT16_MAX bytes long. The strings are restricted to the maximum sizes
......@@ -188,8 +185,9 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
// [2] - https://technet.microsoft.com/en-us/library/cc512606.aspx
if (hostname.length() > kMaxFqdnLen || domain.length() > kMaxFqdnLen ||
username.length() > kMaxUsernameLen ||
password.length() > kMaxPasswordLen)
return Buffer();
password.length() > kMaxPasswordLen) {
return {};
}
NegotiateFlags challenge_flags;
uint8_t server_challenge[kChallengeLen];
......@@ -197,8 +195,8 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
uint8_t ntlm_response[kResponseLenV1];
// Response fields only for NTLMv2
Buffer updated_target_info;
Buffer v2_proof_input;
std::vector<uint8_t> updated_target_info;
std::vector<uint8_t> v2_proof_input;
uint8_t v2_proof[kNtlmProofLenV2];
uint8_t v2_session_key[kSessionKeyLenV2];
......@@ -206,7 +204,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
std::vector<AvPair> av_pairs;
if (!ParseChallengeMessageV2(server_challenge_message, &challenge_flags,
server_challenge, &av_pairs)) {
return Buffer();
return {};
}
uint64_t timestamp;
......@@ -229,7 +227,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
} else {
if (!ParseChallengeMessage(server_challenge_message, &challenge_flags,
server_challenge)) {
return Buffer();
return {};
}
// Calculate the responses for the authenticate message.
......@@ -306,7 +304,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
DCHECK(authenticate_writer.IsEndOfBuffer());
DCHECK_EQ(authenticate_message_len, authenticate_writer.GetLength());
Buffer auth_msg = authenticate_writer.Pass();
std::vector<uint8_t> auth_msg = authenticate_writer.Pass();
// Backfill the MIC if enabled.
if (IsMicEnabled()) {
......@@ -314,9 +312,10 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
// set to zeros.
DCHECK_LT(kMicOffsetV2 + kMicLenV2, authenticate_message_len);
uint8_t* mic_ptr = reinterpret_cast<uint8_t*>(&auth_msg[kMicOffsetV2]);
base::span<uint8_t, kMicLenV2> mic(
const_cast<uint8_t*>(auth_msg.data()) + kMicOffsetV2, kMicLenV2);
GenerateMicV2(v2_session_key, negotiate_message_, server_challenge_message,
auth_msg, mic_ptr);
auth_msg, mic);
}
return auth_msg;
......
......@@ -20,6 +20,7 @@
#include <memory>
#include <string>
#include "base/containers/span.h"
#include "base/strings/string16.h"
#include "base/strings/string_piece.h"
#include "net/base/net_export.h"
......@@ -50,11 +51,11 @@ class NET_EXPORT_PRIVATE NtlmClient {
bool IsEpaEnabled() const { return IsNtlmV2() && features_.enable_EPA; }
// Returns a |Buffer| containing the Negotiate message.
Buffer GetNegotiateMessage() const;
// Returns the Negotiate message.
std::vector<uint8_t> GetNegotiateMessage() const;
// Returns a |Buffer| containing the Authenticate message. If the method
// fails an empty |Buffer| is returned.
// Returns a the Authenticate message. If the method fails an empty vector
// is returned.
//
// |username| is treated case insensitively by NTLM however the mechanism
// to uppercase is not clearly defined. In this implementation the default
......@@ -81,12 +82,11 @@ class NET_EXPORT_PRIVATE NtlmClient {
// 100 nanosecond ticks since midnight Jan 01, 1601 (UTC). If the server does
// not send a timestamp, the client timestamp is used in the Proof Input
// instead.
// |client_challenge| must contain 8 bytes of random data.
// |server_challenge_message| is the full content of the challenge message
// sent by the server.
//
// [1] - https://technet.microsoft.com/en-us/library/jj852267(v=ws.11).aspx
Buffer GenerateAuthenticateMessage(
std::vector<uint8_t> GenerateAuthenticateMessage(
const base::string16& domain,
const base::string16& username,
const base::string16& password,
......@@ -94,19 +94,19 @@ class NET_EXPORT_PRIVATE NtlmClient {
const std::string& channel_bindings,
const std::string& spn,
uint64_t client_time,
const uint8_t* client_challenge,
const Buffer& server_challenge_message) const;
base::span<const uint8_t, kChallengeLen> client_challenge,
base::span<const uint8_t> server_challenge_message) const;
// Simplified method for NTLMv1 which does not require |channel_bindings|,
// |spn|, or |client_time|. See |GenerateAuthenticateMessage| for more
// details.
Buffer GenerateAuthenticateMessageV1(
std::vector<uint8_t> GenerateAuthenticateMessageV1(
const base::string16& domain,
const base::string16& username,
const base::string16& password,
const std::string& hostname,
const uint8_t* client_challenge,
const Buffer& server_challenge_message) const {
base::span<const uint8_t, 8> client_challenge,
base::span<const uint8_t> server_challenge_message) const {
DCHECK(!IsNtlmV2());
return GenerateAuthenticateMessage(
......@@ -150,7 +150,7 @@ class NET_EXPORT_PRIVATE NtlmClient {
NtlmFeatures features_;
NegotiateFlags negotiate_flags_;
Buffer negotiate_message_;
std::vector<uint8_t> negotiate_message_;
DISALLOW_COPY_AND_ASSIGN(NtlmClient);
};
......
......@@ -5,6 +5,7 @@
#include <stddef.h>
#include <stdint.h>
#include "base/containers/span.h"
#include "base/test/fuzzed_data_provider.h"
#include "net/ntlm/ntlm_client.h"
#include "net/ntlm/ntlm_test_data.h"
......@@ -21,8 +22,8 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
base::FuzzedDataProvider fdp(data, size);
bool is_v2 = fdp.ConsumeBool();
uint64_t client_time =
((uint64_t)fdp.ConsumeUint32InRange(0, 0xffffffffu) << 32) |
(uint64_t)fdp.ConsumeUint32InRange(0, 0xffffffffu);
(static_cast<uint64_t>(fdp.ConsumeUint32InRange(0, 0xffffffffu)) << 32) |
static_cast<uint64_t>(fdp.ConsumeUint32InRange(0, 0xffffffffu));
net::ntlm::NtlmClient client((net::ntlm::NtlmFeatures(is_v2)));
// Generate the input strings and challenge message. The strings will have a
......@@ -44,8 +45,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
client.GenerateAuthenticateMessage(
domain, username, password, hostname, channel_bindings, spn, client_time,
net::ntlm::test::kClientChallenge,
net::ntlm::Buffer(
reinterpret_cast<const uint8_t*>(challenge_msg_bytes.data()),
challenge_msg_bytes.size()));
base::as_bytes(base::make_span(challenge_msg_bytes)));
return 0;
}
This diff is collapsed.
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/ntlm/ntlm_constants.h"
namespace net {
namespace ntlm {
AvPair::AvPair() = default;
AvPair::AvPair(TargetInfoAvId avid, uint16_t avlen)
: avid(avid), avlen(avlen) {}
AvPair::AvPair(TargetInfoAvId avid, std::vector<uint8_t> buffer)
: buffer(std::move(buffer)), avid(avid) {
avlen = this->buffer.size();
}
AvPair::AvPair(const AvPair& other) = default;
AvPair::AvPair(AvPair&& other) = default;
AvPair::~AvPair() = default;
AvPair& AvPair::operator=(const AvPair& other) = default;
AvPair& AvPair::operator=(AvPair&& other) = default;
} // namespace ntlm
} // namespace net
......@@ -7,17 +7,15 @@
#include <stddef.h>
#include <stdint.h>
#include <string>
#include <type_traits>
#include "base/macros.h"
#include <vector>
#include "base/stl_util.h"
#include "net/base/net_export.h"
namespace net {
namespace ntlm {
using Buffer = std::basic_string<uint8_t>;
// A security buffer is a structure within an NTLM message that indicates
// the offset from the beginning of the message and the length of a payload
// that occurs later in the message. Within the raw message there is also
......@@ -134,15 +132,18 @@ constexpr inline TargetInfoAvFlags operator&(TargetInfoAvFlags lhs,
// other AvPairs the value of these 2 fields is undefined and the payload
// is in the |buffer| field. For these fields the payload is copied verbatim
// and it's content is not read or validated in any way.
struct AvPair {
AvPair() {}
AvPair(TargetInfoAvId avid, uint16_t avlen) : avid(avid), avlen(avlen) {}
AvPair(TargetInfoAvId avid, Buffer buffer)
: buffer(std::move(buffer)), avid(avid) {
avlen = this->buffer.size();
}
Buffer buffer;
struct NET_EXPORT_PRIVATE AvPair {
AvPair();
AvPair(TargetInfoAvId avid, uint16_t avlen);
AvPair(TargetInfoAvId avid, std::vector<uint8_t> buffer);
AvPair(const AvPair& other);
AvPair(AvPair&& other);
~AvPair();
AvPair& operator=(const AvPair& other);
AvPair& operator=(AvPair&& other);
std::vector<uint8_t> buffer;
uint64_t timestamp;
TargetInfoAvFlags flags;
TargetInfoAvId avid;
......@@ -150,7 +151,7 @@ struct AvPair {
};
static constexpr uint8_t kSignature[] = "NTLMSSP";
static constexpr size_t kSignatureLen = arraysize(kSignature);
static constexpr size_t kSignatureLen = base::size(kSignature);
static constexpr uint16_t kProofInputVersionV2 = 0x0101;
static constexpr size_t kSecurityBufferLen =
(2 * sizeof(uint16_t)) + sizeof(uint32_t);
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment