Commit 34ac4b2f authored by Ryan Sleevi's avatar Ryan Sleevi Committed by Commit Bot

Switch //net/ntlm to use base::span<const uint8_t>

//net/ntlm makes extensive use of uint8_t* pointers with various
fixed size requirements. Convert these into defined-extent
base::span<>'s so that the compiler can do more work for us to
ensure that all the buffers are appropriately sized.

This also switches uses of std::basic_string<uint8_t> and
base::StringPiece into std::vector<uint8_t> / base::span<uint8_t>,
as these are all working on 'data' buffers.

Bug: 837308
Change-Id: Iae3f1933b2c948d77c841e6f9bc8ce04bfddc67d
Reviewed-on: https://chromium-review.googlesource.com/1056013
Commit-Queue: Ryan Sleevi <rsleevi@chromium.org>
Reviewed-by: default avatarAsanka Herath <asanka@chromium.org>
Cr-Commit-Position: refs/heads/master@{#558579}
parent 3e49d5cb
...@@ -921,6 +921,7 @@ component("net") { ...@@ -921,6 +921,7 @@ component("net") {
"ntlm/ntlm_buffer_writer.h", "ntlm/ntlm_buffer_writer.h",
"ntlm/ntlm_client.cc", "ntlm/ntlm_client.cc",
"ntlm/ntlm_client.h", "ntlm/ntlm_client.h",
"ntlm/ntlm_constants.cc",
"ntlm/ntlm_constants.h", "ntlm/ntlm_constants.h",
"proxy_resolution/dhcp_pac_file_adapter_fetcher_win.cc", "proxy_resolution/dhcp_pac_file_adapter_fetcher_win.cc",
"proxy_resolution/dhcp_pac_file_adapter_fetcher_win.h", "proxy_resolution/dhcp_pac_file_adapter_fetcher_win.h",
......
...@@ -80,9 +80,8 @@ int HttpAuthHandlerNTLM::GenerateAuthTokenImpl( ...@@ -80,9 +80,8 @@ int HttpAuthHandlerNTLM::GenerateAuthTokenImpl(
} }
} }
ntlm::Buffer next_token = GetNextToken( std::vector<uint8_t> next_token =
ntlm::Buffer(reinterpret_cast<const uint8_t*>(decoded_auth_data.data()), GetNextToken(base::as_bytes(base::make_span(decoded_auth_data)));
decoded_auth_data.size()));
if (next_token.empty()) if (next_token.empty())
return ERR_UNEXPECTED; return ERR_UNEXPECTED;
......
...@@ -28,7 +28,9 @@ ...@@ -28,7 +28,9 @@
#endif #endif
#include <string> #include <string>
#include <vector>
#include "base/containers/span.h"
#include "base/strings/string16.h" #include "base/strings/string16.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/http/http_auth_handler.h" #include "net/http/http_auth_handler.h"
...@@ -149,7 +151,7 @@ class NET_EXPORT_PRIVATE HttpAuthHandlerNTLM : public HttpAuthHandler { ...@@ -149,7 +151,7 @@ class NET_EXPORT_PRIVATE HttpAuthHandlerNTLM : public HttpAuthHandler {
// Given an input token received from the server, generate the next output // Given an input token received from the server, generate the next output
// token to be sent to the server. // token to be sent to the server.
ntlm::Buffer GetNextToken(const ntlm::Buffer& in_token); std::vector<uint8_t> GetNextToken(base::span<const uint8_t> in_token);
#endif #endif
// Parse the challenge, saving the results into this instance. // Parse the challenge, saving the results into this instance.
......
...@@ -89,7 +89,8 @@ HttpAuthHandlerNTLM::Factory::Factory() = default; ...@@ -89,7 +89,8 @@ HttpAuthHandlerNTLM::Factory::Factory() = default;
HttpAuthHandlerNTLM::Factory::~Factory() = default; HttpAuthHandlerNTLM::Factory::~Factory() = default;
ntlm::Buffer HttpAuthHandlerNTLM::GetNextToken(const ntlm::Buffer& in_token) { std::vector<uint8_t> HttpAuthHandlerNTLM::GetNextToken(
base::span<const uint8_t> in_token) {
// If in_token is non-empty, then assume it contains a challenge message, // If in_token is non-empty, then assume it contains a challenge message,
// and generate the Authenticate message in reply. Otherwise return the // and generate the Authenticate message in reply. Otherwise return the
// Negotiate message. // Negotiate message.
...@@ -99,7 +100,7 @@ ntlm::Buffer HttpAuthHandlerNTLM::GetNextToken(const ntlm::Buffer& in_token) { ...@@ -99,7 +100,7 @@ ntlm::Buffer HttpAuthHandlerNTLM::GetNextToken(const ntlm::Buffer& in_token) {
std::string hostname = get_host_name_proc_(); std::string hostname = get_host_name_proc_();
if (hostname.empty()) if (hostname.empty())
return ntlm::Buffer(); return {};
uint8_t client_challenge[8]; uint8_t client_challenge[8];
generate_random_proc_(client_challenge, 8); generate_random_proc_(client_challenge, 8);
uint64_t client_time = get_ms_time_proc_(); uint64_t client_time = get_ms_time_proc_();
......
...@@ -5,8 +5,10 @@ ...@@ -5,8 +5,10 @@
#include <string> #include <string>
#include "base/base64.h" #include "base/base64.h"
#include "base/containers/span.h"
#include "base/strings/string_util.h" #include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h" #include "base/strings/utf_string_conversions.h"
#include "build/build_config.h"
#include "net/base/test_completion_callback.h" #include "net/base/test_completion_callback.h"
#include "net/dns/mock_host_resolver.h" #include "net/dns/mock_host_resolver.h"
#include "net/http/http_auth_challenge_tokenizer.h" #include "net/http/http_auth_challenge_tokenizer.h"
...@@ -52,19 +54,16 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest { ...@@ -52,19 +54,16 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
&auth_handler_); &auth_handler_);
} }
std::string CreateNtlmAuthHeader(ntlm::Buffer message) { std::string CreateNtlmAuthHeader(base::span<const uint8_t> buffer) {
std::string output; std::string output;
base::Base64Encode( base::Base64Encode(
base::StringPiece(reinterpret_cast<const char*>(message.data()), base::StringPiece(reinterpret_cast<const char*>(buffer.data()),
message.size()), buffer.size()),
&output); &output);
return "NTLM " + output; return "NTLM " + output;
} }
std::string CreateNtlmAuthHeader(const uint8_t* buffer, size_t length) {
return CreateNtlmAuthHeader(ntlm::Buffer(buffer, length));
}
HttpAuth::AuthorizationResult HandleAnotherChallenge( HttpAuth::AuthorizationResult HandleAnotherChallenge(
const std::string& challenge) { const std::string& challenge) {
...@@ -85,10 +84,10 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest { ...@@ -85,10 +84,10 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
} }
bool ReadBytesPayload(ntlm::NtlmBufferReader* reader, bool ReadBytesPayload(ntlm::NtlmBufferReader* reader,
uint8_t* buffer, base::span<uint8_t> buffer) {
size_t len) {
ntlm::SecurityBuffer sec_buf; ntlm::SecurityBuffer sec_buf;
return reader->ReadSecurityBuffer(&sec_buf) && (sec_buf.length == len) && return reader->ReadSecurityBuffer(&sec_buf) &&
(sec_buf.length == buffer.size()) &&
reader->ReadBytesFrom(sec_buf, buffer); reader->ReadBytesFrom(sec_buf, buffer);
} }
...@@ -99,11 +98,13 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest { ...@@ -99,11 +98,13 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
if (!reader->ReadSecurityBuffer(&sec_buf)) if (!reader->ReadSecurityBuffer(&sec_buf))
return false; return false;
std::unique_ptr<uint8_t[]> raw(new uint8_t[sec_buf.length]); if (!reader->ReadBytesFrom(
if (!reader->ReadBytesFrom(sec_buf, raw.get())) sec_buf,
base::as_writable_bytes(base::make_span(
base::WriteInto(str, sec_buf.length + 1), sec_buf.length)))) {
return false; return false;
}
str->assign(reinterpret_cast<const char*>(raw.get()), sec_buf.length);
return true; return true;
} }
...@@ -116,17 +117,17 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest { ...@@ -116,17 +117,17 @@ class HttpAuthHandlerNtlmPortableTest : public PlatformTest {
EXPECT_TRUE(reader->ReadSecurityBuffer(&sec_buf)); EXPECT_TRUE(reader->ReadSecurityBuffer(&sec_buf));
EXPECT_EQ(0, sec_buf.length % 2); EXPECT_EQ(0, sec_buf.length % 2);
std::unique_ptr<uint8_t[]> raw(new uint8_t[sec_buf.length]); std::vector<uint8_t> raw(sec_buf.length);
EXPECT_TRUE(reader->ReadBytesFrom(sec_buf, raw.get())); EXPECT_TRUE(reader->ReadBytesFrom(sec_buf, raw));
#ifdef IS_BIG_ENDIAN #if defined(ARCH_CPU_BIG_ENDIAN)
for (size_t i = 0; i < sec_buf.length; i += 2) { for (size_t i = 0; i < raw.size(); i += 2) {
std::swap(raw[i], raw[i + 1]); std::swap(raw[i], raw[i + 1]);
} }
#endif #endif
str->assign(reinterpret_cast<const base::char16*>(raw.get()), str->assign(reinterpret_cast<const base::char16*>(raw.data()),
sec_buf.length / 2); raw.size() / 2);
} }
int GetGenerateAuthTokenResult() { int GetGenerateAuthTokenResult() {
...@@ -223,8 +224,7 @@ TEST_F(HttpAuthHandlerNtlmPortableTest, NtlmV1AuthenticationSuccess) { ...@@ -223,8 +224,7 @@ TEST_F(HttpAuthHandlerNtlmPortableTest, NtlmV1AuthenticationSuccess) {
std::string token; std::string token;
ASSERT_EQ(HttpAuth::AUTHORIZATION_RESULT_ACCEPT, ASSERT_EQ(HttpAuth::AUTHORIZATION_RESULT_ACCEPT,
HandleAnotherChallenge( HandleAnotherChallenge(
CreateNtlmAuthHeader(ntlm::test::kChallengeMsgV1, CreateNtlmAuthHeader(ntlm::test::kChallengeMsgV1)));
arraysize(ntlm::test::kChallengeMsgV1))));
ASSERT_EQ(OK, GenerateAuthToken(&token)); ASSERT_EQ(OK, GenerateAuthToken(&token));
// Validate the authenticate message // Validate the authenticate message
......
This diff is collapsed.
This diff is collapsed.
...@@ -11,19 +11,10 @@ ...@@ -11,19 +11,10 @@
namespace net { namespace net {
namespace ntlm { namespace ntlm {
NtlmBufferReader::NtlmBufferReader() : NtlmBufferReader(nullptr, 0) {} NtlmBufferReader::NtlmBufferReader() {}
NtlmBufferReader::NtlmBufferReader(const Buffer& buffer) NtlmBufferReader::NtlmBufferReader(base::span<const uint8_t> buffer)
: NtlmBufferReader( : buffer_(buffer) {}
base::StringPiece(reinterpret_cast<const char*>(buffer.data()),
buffer.length())) {}
NtlmBufferReader::NtlmBufferReader(base::StringPiece str)
: buffer_(str), cursor_(0) {}
NtlmBufferReader::NtlmBufferReader(const uint8_t* ptr, size_t len)
: NtlmBufferReader(
base::StringPiece(reinterpret_cast<const char*>(ptr), len)) {}
NtlmBufferReader::~NtlmBufferReader() = default; NtlmBufferReader::~NtlmBufferReader() = default;
...@@ -59,25 +50,22 @@ bool NtlmBufferReader::ReadFlags(NegotiateFlags* flags) { ...@@ -59,25 +50,22 @@ bool NtlmBufferReader::ReadFlags(NegotiateFlags* flags) {
return true; return true;
} }
bool NtlmBufferReader::ReadBytes(uint8_t* buffer, size_t len) { bool NtlmBufferReader::ReadBytes(base::span<uint8_t> buffer) {
if (!CanRead(len)) if (!CanRead(buffer.size()))
return false; return false;
memcpy(reinterpret_cast<void*>(buffer), memcpy(buffer.data(), GetBufferAtCursor(), buffer.size());
reinterpret_cast<const void*>(GetBufferAtCursor()), len);
AdvanceCursor(len); AdvanceCursor(buffer.size());
return true; return true;
} }
bool NtlmBufferReader::ReadBytesFrom(const SecurityBuffer& sec_buf, bool NtlmBufferReader::ReadBytesFrom(const SecurityBuffer& sec_buf,
uint8_t* buffer) { base::span<uint8_t> buffer) {
if (!CanReadFrom(sec_buf)) if (!CanReadFrom(sec_buf) || buffer.size() < sec_buf.length)
return false; return false;
memcpy(reinterpret_cast<void*>(buffer), memcpy(buffer.data(), GetBufferPtr() + sec_buf.offset, sec_buf.length);
reinterpret_cast<const void*>(GetBufferPtr() + sec_buf.offset),
sec_buf.length);
return true; return true;
} }
...@@ -87,7 +75,8 @@ bool NtlmBufferReader::ReadPayloadAsBufferReader(const SecurityBuffer& sec_buf, ...@@ -87,7 +75,8 @@ bool NtlmBufferReader::ReadPayloadAsBufferReader(const SecurityBuffer& sec_buf,
if (!CanReadFrom(sec_buf)) if (!CanReadFrom(sec_buf))
return false; return false;
*reader = NtlmBufferReader(GetBufferPtr() + sec_buf.offset, sec_buf.length); *reader = NtlmBufferReader(
base::make_span(GetBufferPtr() + sec_buf.offset, sec_buf.length));
return true; return true;
} }
...@@ -139,7 +128,7 @@ bool NtlmBufferReader::ReadTargetInfo(size_t target_info_len, ...@@ -139,7 +128,7 @@ bool NtlmBufferReader::ReadTargetInfo(size_t target_info_len,
return false; return false;
// Take a copy of the payload in the AVPair. // Take a copy of the payload in the AVPair.
pair.buffer.assign(GetBufferAtCursor(), pair.avlen); pair.buffer.assign(GetBufferAtCursor(), GetBufferAtCursor() + pair.avlen);
if (pair.avid == TargetInfoAvId::kEol) { if (pair.avid == TargetInfoAvId::kEol) {
// Terminator must have zero length. // Terminator must have zero length.
if (pair.avlen != 0) if (pair.avlen != 0)
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "base/strings/string_piece.h" #include "base/containers/span.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
#include "net/ntlm/ntlm_constants.h" #include "net/ntlm/ntlm_constants.h"
...@@ -49,15 +49,11 @@ class NET_EXPORT_PRIVATE NtlmBufferReader { ...@@ -49,15 +49,11 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
public: public:
NtlmBufferReader(); NtlmBufferReader();
// |buffer| is not copied and must outlive the |NtlmBufferReader|. // |buffer| is not copied and must outlive the |NtlmBufferReader|.
explicit NtlmBufferReader(const Buffer& buffer); explicit NtlmBufferReader(base::span<const uint8_t> buffer);
explicit NtlmBufferReader(base::StringPiece buffer);
// This class does not take ownership of |ptr|, so the caller must ensure
// that the buffer outlives the |NtlmBufferReader|.
NtlmBufferReader(const uint8_t* ptr, size_t len);
~NtlmBufferReader(); ~NtlmBufferReader();
size_t GetLength() const { return buffer_.length(); } size_t GetLength() const { return buffer_.size(); }
size_t GetCursor() const { return cursor_; } size_t GetCursor() const { return cursor_; }
bool IsEndOfBuffer() const { return cursor_ >= GetLength(); } bool IsEndOfBuffer() const { return cursor_ >= GetLength(); }
...@@ -92,14 +88,14 @@ class NET_EXPORT_PRIVATE NtlmBufferReader { ...@@ -92,14 +88,14 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
bool ReadFlags(NegotiateFlags* flags) WARN_UNUSED_RESULT; bool ReadFlags(NegotiateFlags* flags) WARN_UNUSED_RESULT;
// Reads |len| bytes and copies them into |buffer|. // Reads |len| bytes and copies them into |buffer|.
bool ReadBytes(uint8_t* buffer, size_t len) WARN_UNUSED_RESULT; bool ReadBytes(base::span<uint8_t> buffer) WARN_UNUSED_RESULT;
// Reads |sec_buf.length| bytes from offset |sec_buf.offset| and copies them // Reads |sec_buf.length| bytes from offset |sec_buf.offset| and copies them
// into |buffer|. If the security buffer specifies a payload outside the // into |buffer|. If the security buffer specifies a payload outside the
// buffer, then the call fails. Unlike the other Read* methods, this does // buffer, then the call fails. Unlike the other Read* methods, this does
// not move the cursor. // not move the cursor.
bool ReadBytesFrom(const SecurityBuffer& sec_buf, bool ReadBytesFrom(const SecurityBuffer& sec_buf,
uint8_t* buffer) WARN_UNUSED_RESULT; base::span<uint8_t> buffer) WARN_UNUSED_RESULT;
// Reads |sec_buf.length| bytes from offset |sec_buf.offset| and assigns // Reads |sec_buf.length| bytes from offset |sec_buf.offset| and assigns
// |reader| an |NtlmBufferReader| representing the payload. If the security // |reader| an |NtlmBufferReader| representing the payload. If the security
...@@ -207,9 +203,7 @@ class NET_EXPORT_PRIVATE NtlmBufferReader { ...@@ -207,9 +203,7 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
void AdvanceCursor(size_t count) { SetCursor(GetCursor() + count); } void AdvanceCursor(size_t count) { SetCursor(GetCursor() + count); }
// Returns a constant pointer to the start of the buffer. // Returns a constant pointer to the start of the buffer.
const uint8_t* GetBufferPtr() const { const uint8_t* GetBufferPtr() const { return buffer_.data(); }
return reinterpret_cast<const uint8_t*>(buffer_.data());
}
// Returns a pointer to the underlying buffer at the current cursor // Returns a pointer to the underlying buffer at the current cursor
// position. // position.
...@@ -221,8 +215,8 @@ class NET_EXPORT_PRIVATE NtlmBufferReader { ...@@ -221,8 +215,8 @@ class NET_EXPORT_PRIVATE NtlmBufferReader {
return *(GetBufferAtCursor()); return *(GetBufferAtCursor());
} }
base::StringPiece buffer_; base::span<const uint8_t> buffer_;
size_t cursor_; size_t cursor_ = 0;
}; };
} // namespace ntlm } // namespace ntlm
......
This diff is collapsed.
...@@ -48,21 +48,15 @@ bool NtlmBufferWriter::WriteFlags(NegotiateFlags flags) { ...@@ -48,21 +48,15 @@ bool NtlmBufferWriter::WriteFlags(NegotiateFlags flags) {
return WriteUInt32(static_cast<uint32_t>(flags)); return WriteUInt32(static_cast<uint32_t>(flags));
} }
bool NtlmBufferWriter::WriteBytes(const uint8_t* buffer, size_t len) { bool NtlmBufferWriter::WriteBytes(base::span<const uint8_t> bytes) {
if (!CanWrite(len)) if (!CanWrite(bytes.size()))
return false; return false;
memcpy(reinterpret_cast<void*>(GetBufferPtrAtCursor()), memcpy(GetBufferPtrAtCursor(), bytes.data(), bytes.size());
reinterpret_cast<const void*>(buffer), len); AdvanceCursor(bytes.size());
AdvanceCursor(len);
return true; return true;
} }
bool NtlmBufferWriter::WriteBytes(const Buffer& bytes) {
return WriteBytes(bytes.data(), bytes.length());
}
bool NtlmBufferWriter::WriteZeros(size_t count) { bool NtlmBufferWriter::WriteZeros(size_t count) {
if (!CanWrite(count)) if (!CanWrite(count))
return false; return false;
...@@ -105,8 +99,7 @@ bool NtlmBufferWriter::WriteAvPair(const AvPair& pair) { ...@@ -105,8 +99,7 @@ bool NtlmBufferWriter::WriteAvPair(const AvPair& pair) {
} }
bool NtlmBufferWriter::WriteUtf8String(const std::string& str) { bool NtlmBufferWriter::WriteUtf8String(const std::string& str) {
return WriteBytes(reinterpret_cast<const uint8_t*>(str.c_str()), return WriteBytes(base::as_bytes(base::make_span(str)));
str.length());
} }
bool NtlmBufferWriter::WriteUtf16AsUtf8String(const base::string16& str) { bool NtlmBufferWriter::WriteUtf16AsUtf8String(const base::string16& str) {
...@@ -120,7 +113,7 @@ bool NtlmBufferWriter::WriteUtf8AsUtf16String(const std::string& str) { ...@@ -120,7 +113,7 @@ bool NtlmBufferWriter::WriteUtf8AsUtf16String(const std::string& str) {
} }
bool NtlmBufferWriter::WriteUtf16String(const base::string16& str) { bool NtlmBufferWriter::WriteUtf16String(const base::string16& str) {
size_t num_bytes = str.length() * 2; size_t num_bytes = str.size() * 2;
if (!CanWrite(num_bytes)) if (!CanWrite(num_bytes))
return false; return false;
...@@ -142,7 +135,7 @@ bool NtlmBufferWriter::WriteUtf16String(const base::string16& str) { ...@@ -142,7 +135,7 @@ bool NtlmBufferWriter::WriteUtf16String(const base::string16& str) {
} }
bool NtlmBufferWriter::WriteSignature() { bool NtlmBufferWriter::WriteSignature() {
return WriteBytes(kSignature, kSignatureLen); return WriteBytes(kSignature);
} }
bool NtlmBufferWriter::WriteMessageType(MessageType message_type) { bool NtlmBufferWriter::WriteMessageType(MessageType message_type) {
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "base/containers/span.h"
#include "base/strings/string16.h" #include "base/strings/string16.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
...@@ -49,8 +50,8 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter { ...@@ -49,8 +50,8 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
size_t GetLength() const { return buffer_.size(); } size_t GetLength() const { return buffer_.size(); }
size_t GetCursor() const { return cursor_; } size_t GetCursor() const { return cursor_; }
bool IsEndOfBuffer() const { return cursor_ >= GetLength(); } bool IsEndOfBuffer() const { return cursor_ >= GetLength(); }
const Buffer& GetBuffer() const { return buffer_; } base::span<const uint8_t> GetBuffer() const { return buffer_; }
Buffer Pass() const { return std::move(buffer_); } std::vector<uint8_t> Pass() const { return std::move(buffer_); }
// Returns true if there are |len| more bytes between the current cursor // Returns true if there are |len| more bytes between the current cursor
// position and the end of the buffer. // position and the end of the buffer.
...@@ -71,13 +72,9 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter { ...@@ -71,13 +72,9 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
// Writes flags as a 32 bit unsigned value (little endian). // Writes flags as a 32 bit unsigned value (little endian).
bool WriteFlags(NegotiateFlags flags) WARN_UNUSED_RESULT; bool WriteFlags(NegotiateFlags flags) WARN_UNUSED_RESULT;
// Writes |len| bytes from |buffer|. If there are not |len| more bytes in // Writes the bytes from the |buffer|. If there are not enough
// the buffer, it returns false.
bool WriteBytes(const uint8_t* buffer, size_t len) WARN_UNUSED_RESULT;
// Writes the bytes from the |Buffer|. If there are not enough
// bytes in the buffer, it returns false. // bytes in the buffer, it returns false.
bool WriteBytes(const Buffer& buffer) WARN_UNUSED_RESULT; bool WriteBytes(base::span<const uint8_t> buffer) WARN_UNUSED_RESULT;
// Writes |count| bytes of zeros to the buffer. If there are not |count| // Writes |count| bytes of zeros to the buffer. If there are not |count|
// more bytes in available in the buffer, it returns false. // more bytes in available in the buffer, it returns false.
...@@ -179,8 +176,8 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter { ...@@ -179,8 +176,8 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
void AdvanceCursor(size_t count) { SetCursor(GetCursor() + count); } void AdvanceCursor(size_t count) { SetCursor(GetCursor() + count); }
// Returns a pointer to the start of the buffer. // Returns a pointer to the start of the buffer.
const uint8_t* GetBufferPtr() const { return &buffer_[0]; } const uint8_t* GetBufferPtr() const { return buffer_.data(); }
uint8_t* GetBufferPtr() { return &buffer_[0]; } uint8_t* GetBufferPtr() { return buffer_.data(); }
// Returns pointer into the buffer at the current cursor location. // Returns pointer into the buffer at the current cursor location.
const uint8_t* GetBufferPtrAtCursor() const { const uint8_t* GetBufferPtrAtCursor() const {
...@@ -188,7 +185,7 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter { ...@@ -188,7 +185,7 @@ class NET_EXPORT_PRIVATE NtlmBufferWriter {
} }
uint8_t* GetBufferPtrAtCursor() { return GetBufferPtr() + GetCursor(); } uint8_t* GetBufferPtrAtCursor() { return GetBufferPtr() + GetCursor(); }
Buffer buffer_; std::vector<uint8_t> buffer_;
size_t cursor_; size_t cursor_;
DISALLOW_COPY_AND_ASSIGN(NtlmBufferWriter); DISALLOW_COPY_AND_ASSIGN(NtlmBufferWriter);
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "net/ntlm/ntlm_buffer_writer.h" #include "net/ntlm/ntlm_buffer_writer.h"
#include "base/macros.h" #include "base/stl_util.h"
#include "base/strings/utf_string_conversions.h" #include "base/strings/utf_string_conversions.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -30,7 +30,7 @@ TEST(NtlmBufferWriterTest, Initialization) { ...@@ -30,7 +30,7 @@ TEST(NtlmBufferWriterTest, Initialization) {
NtlmBufferWriter writer(1); NtlmBufferWriter writer(1);
ASSERT_EQ(1u, writer.GetLength()); ASSERT_EQ(1u, writer.GetLength());
ASSERT_EQ(1u, writer.GetBuffer().length()); ASSERT_EQ(1u, writer.GetBuffer().size());
ASSERT_EQ(0u, writer.GetCursor()); ASSERT_EQ(0u, writer.GetCursor());
ASSERT_FALSE(writer.IsEndOfBuffer()); ASSERT_FALSE(writer.IsEndOfBuffer());
ASSERT_TRUE(writer.CanWrite(1)); ASSERT_TRUE(writer.CanWrite(1));
...@@ -45,11 +45,11 @@ TEST(NtlmBufferWriterTest, Write16) { ...@@ -45,11 +45,11 @@ TEST(NtlmBufferWriterTest, Write16) {
ASSERT_TRUE(writer.WriteUInt16(value)); ASSERT_TRUE(writer.WriteUInt16(value));
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_EQ(arraysize(expected), writer.GetLength()); ASSERT_EQ(base::size(expected), writer.GetLength());
ASSERT_FALSE(writer.WriteUInt16(value)); ASSERT_FALSE(writer.WriteUInt16(value));
ASSERT_EQ(0, ASSERT_EQ(0,
memcmp(expected, writer.GetBuffer().data(), arraysize(expected))); memcmp(expected, writer.GetBuffer().data(), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, Write16PastEob) { TEST(NtlmBufferWriterTest, Write16PastEob) {
...@@ -69,7 +69,7 @@ TEST(NtlmBufferWriterTest, Write32) { ...@@ -69,7 +69,7 @@ TEST(NtlmBufferWriterTest, Write32) {
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUInt32(value)); ASSERT_FALSE(writer.WriteUInt32(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, Write32PastEob) { TEST(NtlmBufferWriterTest, Write32PastEob) {
...@@ -89,7 +89,7 @@ TEST(NtlmBufferWriterTest, Write64) { ...@@ -89,7 +89,7 @@ TEST(NtlmBufferWriterTest, Write64) {
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUInt64(value)); ASSERT_FALSE(writer.WriteUInt64(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, Write64PastEob) { TEST(NtlmBufferWriterTest, Write64PastEob) {
...@@ -102,22 +102,22 @@ TEST(NtlmBufferWriterTest, Write64PastEob) { ...@@ -102,22 +102,22 @@ TEST(NtlmBufferWriterTest, Write64PastEob) {
TEST(NtlmBufferWriterTest, WriteBytes) { TEST(NtlmBufferWriterTest, WriteBytes) {
uint8_t expected[8] = {0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11}; uint8_t expected[8] = {0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11};
NtlmBufferWriter writer(arraysize(expected)); NtlmBufferWriter writer(base::size(expected));
ASSERT_TRUE(writer.WriteBytes(expected, arraysize(expected))); ASSERT_TRUE(writer.WriteBytes(expected));
ASSERT_EQ(0, memcmp(GetBufferPtr(writer), expected, arraysize(expected))); ASSERT_EQ(0, memcmp(GetBufferPtr(writer), expected, base::size(expected)));
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteBytes(expected, 1)); ASSERT_FALSE(writer.WriteBytes(base::make_span(expected, 1)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, WriteBytesPastEob) { TEST(NtlmBufferWriterTest, WriteBytesPastEob) {
uint8_t buffer[8]; uint8_t buffer[8];
NtlmBufferWriter writer(arraysize(buffer) - 1); NtlmBufferWriter writer(base::size(buffer) - 1);
ASSERT_FALSE(writer.WriteBytes(buffer, arraysize(buffer))); ASSERT_FALSE(writer.WriteBytes(buffer));
} }
TEST(NtlmBufferWriterTest, WriteSecurityBuffer) { TEST(NtlmBufferWriterTest, WriteSecurityBuffer) {
...@@ -131,7 +131,7 @@ TEST(NtlmBufferWriterTest, WriteSecurityBuffer) { ...@@ -131,7 +131,7 @@ TEST(NtlmBufferWriterTest, WriteSecurityBuffer) {
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteSecurityBuffer(SecurityBuffer(offset, length))); ASSERT_FALSE(writer.WriteSecurityBuffer(SecurityBuffer(offset, length)));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, WriteSecurityBufferPastEob) { TEST(NtlmBufferWriterTest, WriteSecurityBufferPastEob) {
...@@ -151,7 +151,7 @@ TEST(NtlmBufferWriterTest, WriteNarrowString) { ...@@ -151,7 +151,7 @@ TEST(NtlmBufferWriterTest, WriteNarrowString) {
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUtf8String(value)); ASSERT_FALSE(writer.WriteUtf8String(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, WriteAsciiStringPastEob) { TEST(NtlmBufferWriterTest, WriteAsciiStringPastEob) {
...@@ -172,7 +172,7 @@ TEST(NtlmBufferWriterTest, WriteUtf16String) { ...@@ -172,7 +172,7 @@ TEST(NtlmBufferWriterTest, WriteUtf16String) {
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUtf16String(value)); ASSERT_FALSE(writer.WriteUtf16String(value));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, WriteUtf16StringPastEob) { TEST(NtlmBufferWriterTest, WriteUtf16StringPastEob) {
...@@ -193,7 +193,7 @@ TEST(NtlmBufferWriterTest, WriteUtf8AsUtf16String) { ...@@ -193,7 +193,7 @@ TEST(NtlmBufferWriterTest, WriteUtf8AsUtf16String) {
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_FALSE(writer.WriteUtf8AsUtf16String(input)); ASSERT_FALSE(writer.WriteUtf8AsUtf16String(input));
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, WriteSignature) { TEST(NtlmBufferWriterTest, WriteSignature) {
...@@ -203,7 +203,7 @@ TEST(NtlmBufferWriterTest, WriteSignature) { ...@@ -203,7 +203,7 @@ TEST(NtlmBufferWriterTest, WriteSignature) {
ASSERT_TRUE(writer.WriteSignature()); ASSERT_TRUE(writer.WriteSignature());
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, WriteSignaturePastEob) { TEST(NtlmBufferWriterTest, WriteSignaturePastEob) {
...@@ -232,12 +232,12 @@ TEST(NtlmBufferWriterTest, WriteMessageTypePastEob) { ...@@ -232,12 +232,12 @@ TEST(NtlmBufferWriterTest, WriteMessageTypePastEob) {
TEST(NtlmBufferWriterTest, WriteAvPairHeader) { TEST(NtlmBufferWriterTest, WriteAvPairHeader) {
const uint8_t expected[4] = {0x06, 0x00, 0x11, 0x22}; const uint8_t expected[4] = {0x06, 0x00, 0x11, 0x22};
NtlmBufferWriter writer(4); NtlmBufferWriter writer(base::size(expected));
ASSERT_TRUE(writer.WriteAvPairHeader(TargetInfoAvId::kFlags, 0x2211)); ASSERT_TRUE(writer.WriteAvPairHeader(TargetInfoAvId::kFlags, 0x2211));
ASSERT_TRUE(writer.IsEndOfBuffer()); ASSERT_TRUE(writer.IsEndOfBuffer());
ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), arraysize(expected))); ASSERT_EQ(0, memcmp(expected, GetBufferPtr(writer), base::size(expected)));
} }
TEST(NtlmBufferWriterTest, WriteAvPairHeaderPastEob) { TEST(NtlmBufferWriterTest, WriteAvPairHeaderPastEob) {
......
...@@ -19,31 +19,31 @@ namespace ntlm { ...@@ -19,31 +19,31 @@ namespace ntlm {
namespace { namespace {
// Parses the challenge message and returns the |challenge_flags| and // Parses the challenge message and returns the |challenge_flags| and
// |server_challenge| into the supplied buffer. // |server_challenge| into the supplied buffer.
// |server_challenge| must contain at least 8 bytes. bool ParseChallengeMessage(
bool ParseChallengeMessage(const Buffer& challenge_message, base::span<const uint8_t> challenge_message,
NegotiateFlags* challenge_flags, NegotiateFlags* challenge_flags,
uint8_t* server_challenge) { base::span<uint8_t, kChallengeLen> server_challenge) {
NtlmBufferReader challenge_reader(challenge_message); NtlmBufferReader challenge_reader(challenge_message);
return challenge_reader.MatchMessageHeader(MessageType::kChallenge) && return challenge_reader.MatchMessageHeader(MessageType::kChallenge) &&
challenge_reader.SkipSecurityBufferWithValidation() && challenge_reader.SkipSecurityBufferWithValidation() &&
challenge_reader.ReadFlags(challenge_flags) && challenge_reader.ReadFlags(challenge_flags) &&
challenge_reader.ReadBytes(server_challenge, kChallengeLen); challenge_reader.ReadBytes(server_challenge);
} }
// Parses the challenge message and extracts the information necessary to // Parses the challenge message and extracts the information necessary to
// make an NTLMv2 response. // make an NTLMv2 response.
// |server_challenge| must contain at least 8 bytes. bool ParseChallengeMessageV2(
bool ParseChallengeMessageV2(const Buffer& challenge_message, base::span<const uint8_t> challenge_message,
NegotiateFlags* challenge_flags, NegotiateFlags* challenge_flags,
uint8_t* server_challenge, base::span<uint8_t, kChallengeLen> server_challenge,
std::vector<AvPair>* av_pairs) { std::vector<AvPair>* av_pairs) {
NtlmBufferReader challenge_reader(challenge_message); NtlmBufferReader challenge_reader(challenge_message);
return challenge_reader.MatchMessageHeader(MessageType::kChallenge) && return challenge_reader.MatchMessageHeader(MessageType::kChallenge) &&
challenge_reader.SkipSecurityBufferWithValidation() && challenge_reader.SkipSecurityBufferWithValidation() &&
challenge_reader.ReadFlags(challenge_flags) && challenge_reader.ReadFlags(challenge_flags) &&
challenge_reader.ReadBytes(server_challenge, kChallengeLen) && challenge_reader.ReadBytes(server_challenge) &&
challenge_reader.SkipBytes(8) && challenge_reader.SkipBytes(8) &&
// challenge_reader.ReadTargetInfoPayload(av_pairs); // challenge_reader.ReadTargetInfoPayload(av_pairs);
(((*challenge_flags & NegotiateFlags::kTargetInfo) == (((*challenge_flags & NegotiateFlags::kTargetInfo) ==
...@@ -71,27 +71,24 @@ bool WriteAuthenticateMessage(NtlmBufferWriter* authenticate_writer, ...@@ -71,27 +71,24 @@ bool WriteAuthenticateMessage(NtlmBufferWriter* authenticate_writer,
} }
// Writes the NTLMv1 LM Response and NTLM Response. // Writes the NTLMv1 LM Response and NTLM Response.
// |lm_response| must contain |kResponseLenV1| bytes. bool WriteResponsePayloads(
// |ntlm_response| must contain |kResponseLenV1| bytes. NtlmBufferWriter* authenticate_writer,
bool WriteResponsePayloads(NtlmBufferWriter* authenticate_writer, base::span<const uint8_t, kResponseLenV1> lm_response,
const uint8_t* lm_response, base::span<const uint8_t, kResponseLenV1> ntlm_response) {
const uint8_t* ntlm_response) { return authenticate_writer->WriteBytes(lm_response) &&
return authenticate_writer->WriteBytes(lm_response, kResponseLenV1) && authenticate_writer->WriteBytes(ntlm_response);
authenticate_writer->WriteBytes(ntlm_response, kResponseLenV1);
} }
// Writes the |lm_response| and writes the NTLMv2 response by concatenating // Writes the |lm_response| and writes the NTLMv2 response by concatenating
// |v2_proof|, |v2_proof_input|, |updated_target_info| and 4 zero bytes. // |v2_proof|, |v2_proof_input|, |updated_target_info| and 4 zero bytes.
// bool WriteResponsePayloadsV2(
// |lm_response| must contain |kResponseLenV1| bytes. NtlmBufferWriter* authenticate_writer,
// |v2_proof| must contain |kNtlmProofLenV2| bytes. base::span<const uint8_t, kResponseLenV1> lm_response,
bool WriteResponsePayloadsV2(NtlmBufferWriter* authenticate_writer, base::span<const uint8_t, kNtlmProofLenV2> v2_proof,
const uint8_t* lm_response, base::span<const uint8_t> v2_proof_input,
const uint8_t* v2_proof, base::span<const uint8_t> updated_target_info) {
const Buffer& v2_proof_input, return authenticate_writer->WriteBytes(lm_response) &&
const Buffer& updated_target_info) { authenticate_writer->WriteBytes(v2_proof) &&
return authenticate_writer->WriteBytes(lm_response, kResponseLenV1) &&
authenticate_writer->WriteBytes(v2_proof, kNtlmProofLenV2) &&
authenticate_writer->WriteBytes(v2_proof_input) && authenticate_writer->WriteBytes(v2_proof_input) &&
authenticate_writer->WriteBytes(updated_target_info) && authenticate_writer->WriteBytes(updated_target_info) &&
authenticate_writer->WriteUInt32(0); authenticate_writer->WriteUInt32(0);
...@@ -146,7 +143,7 @@ NtlmClient::NtlmClient(NtlmFeatures features) ...@@ -146,7 +143,7 @@ NtlmClient::NtlmClient(NtlmFeatures features)
NtlmClient::~NtlmClient() = default; NtlmClient::~NtlmClient() = default;
Buffer NtlmClient::GetNegotiateMessage() const { std::vector<uint8_t> NtlmClient::GetNegotiateMessage() const {
return negotiate_message_; return negotiate_message_;
} }
...@@ -164,7 +161,7 @@ void NtlmClient::GenerateNegotiateMessage() { ...@@ -164,7 +161,7 @@ void NtlmClient::GenerateNegotiateMessage() {
negotiate_message_ = writer.Pass(); negotiate_message_ = writer.Pass();
} }
Buffer NtlmClient::GenerateAuthenticateMessage( std::vector<uint8_t> NtlmClient::GenerateAuthenticateMessage(
const base::string16& domain, const base::string16& domain,
const base::string16& username, const base::string16& username,
const base::string16& password, const base::string16& password,
...@@ -172,8 +169,8 @@ Buffer NtlmClient::GenerateAuthenticateMessage( ...@@ -172,8 +169,8 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
const std::string& channel_bindings, const std::string& channel_bindings,
const std::string& spn, const std::string& spn,
uint64_t client_time, uint64_t client_time,
const uint8_t* client_challenge, base::span<const uint8_t, kChallengeLen> client_challenge,
const Buffer& server_challenge_message) const { base::span<const uint8_t> server_challenge_message) const {
// Limit the size of strings that are accepted. As an absolute limit any // Limit the size of strings that are accepted. As an absolute limit any
// field represented by a |SecurityBuffer| or |AvPair| must be less than // field represented by a |SecurityBuffer| or |AvPair| must be less than
// UINT16_MAX bytes long. The strings are restricted to the maximum sizes // UINT16_MAX bytes long. The strings are restricted to the maximum sizes
...@@ -188,8 +185,9 @@ Buffer NtlmClient::GenerateAuthenticateMessage( ...@@ -188,8 +185,9 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
// [2] - https://technet.microsoft.com/en-us/library/cc512606.aspx // [2] - https://technet.microsoft.com/en-us/library/cc512606.aspx
if (hostname.length() > kMaxFqdnLen || domain.length() > kMaxFqdnLen || if (hostname.length() > kMaxFqdnLen || domain.length() > kMaxFqdnLen ||
username.length() > kMaxUsernameLen || username.length() > kMaxUsernameLen ||
password.length() > kMaxPasswordLen) password.length() > kMaxPasswordLen) {
return Buffer(); return {};
}
NegotiateFlags challenge_flags; NegotiateFlags challenge_flags;
uint8_t server_challenge[kChallengeLen]; uint8_t server_challenge[kChallengeLen];
...@@ -197,8 +195,8 @@ Buffer NtlmClient::GenerateAuthenticateMessage( ...@@ -197,8 +195,8 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
uint8_t ntlm_response[kResponseLenV1]; uint8_t ntlm_response[kResponseLenV1];
// Response fields only for NTLMv2 // Response fields only for NTLMv2
Buffer updated_target_info; std::vector<uint8_t> updated_target_info;
Buffer v2_proof_input; std::vector<uint8_t> v2_proof_input;
uint8_t v2_proof[kNtlmProofLenV2]; uint8_t v2_proof[kNtlmProofLenV2];
uint8_t v2_session_key[kSessionKeyLenV2]; uint8_t v2_session_key[kSessionKeyLenV2];
...@@ -206,7 +204,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage( ...@@ -206,7 +204,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
std::vector<AvPair> av_pairs; std::vector<AvPair> av_pairs;
if (!ParseChallengeMessageV2(server_challenge_message, &challenge_flags, if (!ParseChallengeMessageV2(server_challenge_message, &challenge_flags,
server_challenge, &av_pairs)) { server_challenge, &av_pairs)) {
return Buffer(); return {};
} }
uint64_t timestamp; uint64_t timestamp;
...@@ -229,7 +227,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage( ...@@ -229,7 +227,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
} else { } else {
if (!ParseChallengeMessage(server_challenge_message, &challenge_flags, if (!ParseChallengeMessage(server_challenge_message, &challenge_flags,
server_challenge)) { server_challenge)) {
return Buffer(); return {};
} }
// Calculate the responses for the authenticate message. // Calculate the responses for the authenticate message.
...@@ -306,7 +304,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage( ...@@ -306,7 +304,7 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
DCHECK(authenticate_writer.IsEndOfBuffer()); DCHECK(authenticate_writer.IsEndOfBuffer());
DCHECK_EQ(authenticate_message_len, authenticate_writer.GetLength()); DCHECK_EQ(authenticate_message_len, authenticate_writer.GetLength());
Buffer auth_msg = authenticate_writer.Pass(); std::vector<uint8_t> auth_msg = authenticate_writer.Pass();
// Backfill the MIC if enabled. // Backfill the MIC if enabled.
if (IsMicEnabled()) { if (IsMicEnabled()) {
...@@ -314,9 +312,10 @@ Buffer NtlmClient::GenerateAuthenticateMessage( ...@@ -314,9 +312,10 @@ Buffer NtlmClient::GenerateAuthenticateMessage(
// set to zeros. // set to zeros.
DCHECK_LT(kMicOffsetV2 + kMicLenV2, authenticate_message_len); DCHECK_LT(kMicOffsetV2 + kMicLenV2, authenticate_message_len);
uint8_t* mic_ptr = reinterpret_cast<uint8_t*>(&auth_msg[kMicOffsetV2]); base::span<uint8_t, kMicLenV2> mic(
const_cast<uint8_t*>(auth_msg.data()) + kMicOffsetV2, kMicLenV2);
GenerateMicV2(v2_session_key, negotiate_message_, server_challenge_message, GenerateMicV2(v2_session_key, negotiate_message_, server_challenge_message,
auth_msg, mic_ptr); auth_msg, mic);
} }
return auth_msg; return auth_msg;
...@@ -381,4 +380,4 @@ size_t NtlmClient::GetNtlmResponseLength(size_t updated_target_info_len) const { ...@@ -381,4 +380,4 @@ size_t NtlmClient::GetNtlmResponseLength(size_t updated_target_info_len) const {
} }
} // namespace ntlm } // namespace ntlm
} // namespace net } // namespace net
\ No newline at end of file
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "base/containers/span.h"
#include "base/strings/string16.h" #include "base/strings/string16.h"
#include "base/strings/string_piece.h" #include "base/strings/string_piece.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
...@@ -50,11 +51,11 @@ class NET_EXPORT_PRIVATE NtlmClient { ...@@ -50,11 +51,11 @@ class NET_EXPORT_PRIVATE NtlmClient {
bool IsEpaEnabled() const { return IsNtlmV2() && features_.enable_EPA; } bool IsEpaEnabled() const { return IsNtlmV2() && features_.enable_EPA; }
// Returns a |Buffer| containing the Negotiate message. // Returns the Negotiate message.
Buffer GetNegotiateMessage() const; std::vector<uint8_t> GetNegotiateMessage() const;
// Returns a |Buffer| containing the Authenticate message. If the method // Returns a the Authenticate message. If the method fails an empty vector
// fails an empty |Buffer| is returned. // is returned.
// //
// |username| is treated case insensitively by NTLM however the mechanism // |username| is treated case insensitively by NTLM however the mechanism
// to uppercase is not clearly defined. In this implementation the default // to uppercase is not clearly defined. In this implementation the default
...@@ -81,12 +82,11 @@ class NET_EXPORT_PRIVATE NtlmClient { ...@@ -81,12 +82,11 @@ class NET_EXPORT_PRIVATE NtlmClient {
// 100 nanosecond ticks since midnight Jan 01, 1601 (UTC). If the server does // 100 nanosecond ticks since midnight Jan 01, 1601 (UTC). If the server does
// not send a timestamp, the client timestamp is used in the Proof Input // not send a timestamp, the client timestamp is used in the Proof Input
// instead. // instead.
// |client_challenge| must contain 8 bytes of random data.
// |server_challenge_message| is the full content of the challenge message // |server_challenge_message| is the full content of the challenge message
// sent by the server. // sent by the server.
// //
// [1] - https://technet.microsoft.com/en-us/library/jj852267(v=ws.11).aspx // [1] - https://technet.microsoft.com/en-us/library/jj852267(v=ws.11).aspx
Buffer GenerateAuthenticateMessage( std::vector<uint8_t> GenerateAuthenticateMessage(
const base::string16& domain, const base::string16& domain,
const base::string16& username, const base::string16& username,
const base::string16& password, const base::string16& password,
...@@ -94,19 +94,19 @@ class NET_EXPORT_PRIVATE NtlmClient { ...@@ -94,19 +94,19 @@ class NET_EXPORT_PRIVATE NtlmClient {
const std::string& channel_bindings, const std::string& channel_bindings,
const std::string& spn, const std::string& spn,
uint64_t client_time, uint64_t client_time,
const uint8_t* client_challenge, base::span<const uint8_t, kChallengeLen> client_challenge,
const Buffer& server_challenge_message) const; base::span<const uint8_t> server_challenge_message) const;
// Simplified method for NTLMv1 which does not require |channel_bindings|, // Simplified method for NTLMv1 which does not require |channel_bindings|,
// |spn|, or |client_time|. See |GenerateAuthenticateMessage| for more // |spn|, or |client_time|. See |GenerateAuthenticateMessage| for more
// details. // details.
Buffer GenerateAuthenticateMessageV1( std::vector<uint8_t> GenerateAuthenticateMessageV1(
const base::string16& domain, const base::string16& domain,
const base::string16& username, const base::string16& username,
const base::string16& password, const base::string16& password,
const std::string& hostname, const std::string& hostname,
const uint8_t* client_challenge, base::span<const uint8_t, 8> client_challenge,
const Buffer& server_challenge_message) const { base::span<const uint8_t> server_challenge_message) const {
DCHECK(!IsNtlmV2()); DCHECK(!IsNtlmV2());
return GenerateAuthenticateMessage( return GenerateAuthenticateMessage(
...@@ -150,7 +150,7 @@ class NET_EXPORT_PRIVATE NtlmClient { ...@@ -150,7 +150,7 @@ class NET_EXPORT_PRIVATE NtlmClient {
NtlmFeatures features_; NtlmFeatures features_;
NegotiateFlags negotiate_flags_; NegotiateFlags negotiate_flags_;
Buffer negotiate_message_; std::vector<uint8_t> negotiate_message_;
DISALLOW_COPY_AND_ASSIGN(NtlmClient); DISALLOW_COPY_AND_ASSIGN(NtlmClient);
}; };
...@@ -158,4 +158,4 @@ class NET_EXPORT_PRIVATE NtlmClient { ...@@ -158,4 +158,4 @@ class NET_EXPORT_PRIVATE NtlmClient {
} // namespace ntlm } // namespace ntlm
} // namespace net } // namespace net
#endif // NET_BASE_NTLM_CLIENT_H_ #endif // NET_BASE_NTLM_CLIENT_H_
\ No newline at end of file
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include "base/containers/span.h"
#include "base/test/fuzzed_data_provider.h" #include "base/test/fuzzed_data_provider.h"
#include "net/ntlm/ntlm_client.h" #include "net/ntlm/ntlm_client.h"
#include "net/ntlm/ntlm_test_data.h" #include "net/ntlm/ntlm_test_data.h"
...@@ -21,8 +22,8 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { ...@@ -21,8 +22,8 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
base::FuzzedDataProvider fdp(data, size); base::FuzzedDataProvider fdp(data, size);
bool is_v2 = fdp.ConsumeBool(); bool is_v2 = fdp.ConsumeBool();
uint64_t client_time = uint64_t client_time =
((uint64_t)fdp.ConsumeUint32InRange(0, 0xffffffffu) << 32) | (static_cast<uint64_t>(fdp.ConsumeUint32InRange(0, 0xffffffffu)) << 32) |
(uint64_t)fdp.ConsumeUint32InRange(0, 0xffffffffu); static_cast<uint64_t>(fdp.ConsumeUint32InRange(0, 0xffffffffu));
net::ntlm::NtlmClient client((net::ntlm::NtlmFeatures(is_v2))); net::ntlm::NtlmClient client((net::ntlm::NtlmFeatures(is_v2)));
// Generate the input strings and challenge message. The strings will have a // Generate the input strings and challenge message. The strings will have a
...@@ -44,8 +45,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { ...@@ -44,8 +45,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
client.GenerateAuthenticateMessage( client.GenerateAuthenticateMessage(
domain, username, password, hostname, channel_bindings, spn, client_time, domain, username, password, hostname, channel_bindings, spn, client_time,
net::ntlm::test::kClientChallenge, net::ntlm::test::kClientChallenge,
net::ntlm::Buffer( base::as_bytes(base::make_span(challenge_msg_bytes)));
reinterpret_cast<const uint8_t*>(challenge_msg_bytes.data()),
challenge_msg_bytes.size()));
return 0; return 0;
} }
\ No newline at end of file
This diff is collapsed.
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/ntlm/ntlm_constants.h"
namespace net {
namespace ntlm {
AvPair::AvPair() = default;
AvPair::AvPair(TargetInfoAvId avid, uint16_t avlen)
: avid(avid), avlen(avlen) {}
AvPair::AvPair(TargetInfoAvId avid, std::vector<uint8_t> buffer)
: buffer(std::move(buffer)), avid(avid) {
avlen = this->buffer.size();
}
AvPair::AvPair(const AvPair& other) = default;
AvPair::AvPair(AvPair&& other) = default;
AvPair::~AvPair() = default;
AvPair& AvPair::operator=(const AvPair& other) = default;
AvPair& AvPair::operator=(AvPair&& other) = default;
} // namespace ntlm
} // namespace net
...@@ -7,17 +7,15 @@ ...@@ -7,17 +7,15 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <string>
#include <type_traits>
#include "base/macros.h" #include <vector>
#include "base/stl_util.h"
#include "net/base/net_export.h" #include "net/base/net_export.h"
namespace net { namespace net {
namespace ntlm { namespace ntlm {
using Buffer = std::basic_string<uint8_t>;
// A security buffer is a structure within an NTLM message that indicates // A security buffer is a structure within an NTLM message that indicates
// the offset from the beginning of the message and the length of a payload // the offset from the beginning of the message and the length of a payload
// that occurs later in the message. Within the raw message there is also // that occurs later in the message. Within the raw message there is also
...@@ -134,15 +132,18 @@ constexpr inline TargetInfoAvFlags operator&(TargetInfoAvFlags lhs, ...@@ -134,15 +132,18 @@ constexpr inline TargetInfoAvFlags operator&(TargetInfoAvFlags lhs,
// other AvPairs the value of these 2 fields is undefined and the payload // other AvPairs the value of these 2 fields is undefined and the payload
// is in the |buffer| field. For these fields the payload is copied verbatim // is in the |buffer| field. For these fields the payload is copied verbatim
// and it's content is not read or validated in any way. // and it's content is not read or validated in any way.
struct AvPair { struct NET_EXPORT_PRIVATE AvPair {
AvPair() {} AvPair();
AvPair(TargetInfoAvId avid, uint16_t avlen) : avid(avid), avlen(avlen) {} AvPair(TargetInfoAvId avid, uint16_t avlen);
AvPair(TargetInfoAvId avid, Buffer buffer) AvPair(TargetInfoAvId avid, std::vector<uint8_t> buffer);
: buffer(std::move(buffer)), avid(avid) { AvPair(const AvPair& other);
avlen = this->buffer.size(); AvPair(AvPair&& other);
} ~AvPair();
Buffer buffer; AvPair& operator=(const AvPair& other);
AvPair& operator=(AvPair&& other);
std::vector<uint8_t> buffer;
uint64_t timestamp; uint64_t timestamp;
TargetInfoAvFlags flags; TargetInfoAvFlags flags;
TargetInfoAvId avid; TargetInfoAvId avid;
...@@ -150,7 +151,7 @@ struct AvPair { ...@@ -150,7 +151,7 @@ struct AvPair {
}; };
static constexpr uint8_t kSignature[] = "NTLMSSP"; static constexpr uint8_t kSignature[] = "NTLMSSP";
static constexpr size_t kSignatureLen = arraysize(kSignature); static constexpr size_t kSignatureLen = base::size(kSignature);
static constexpr uint16_t kProofInputVersionV2 = 0x0101; static constexpr uint16_t kProofInputVersionV2 = 0x0101;
static constexpr size_t kSecurityBufferLen = static constexpr size_t kSecurityBufferLen =
(2 * sizeof(uint16_t)) + sizeof(uint32_t); (2 * sizeof(uint16_t)) + sizeof(uint32_t);
...@@ -186,4 +187,4 @@ static constexpr NegotiateFlags kNegotiateMessageFlags = ...@@ -186,4 +187,4 @@ static constexpr NegotiateFlags kNegotiateMessageFlags =
} // namespace ntlm } // namespace ntlm
} // namespace net } // namespace net
#endif // NET_BASE_NTLM_CONSTANTS_H_ #endif // NET_BASE_NTLM_CONSTANTS_H_
\ No newline at end of file
...@@ -572,4 +572,4 @@ constexpr uint8_t kExpectedAuthenticateMsgEmptyChannelBindingsV2[] = { ...@@ -572,4 +572,4 @@ constexpr uint8_t kExpectedAuthenticateMsgEmptyChannelBindingsV2[] = {
} // namespace ntlm } // namespace ntlm
} // namespace net } // namespace net
#endif // NET_BASE_NTLM_TEST_DATA_H_ #endif // NET_BASE_NTLM_TEST_DATA_H_
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment