Commit f395befb authored by Cliff Smolinsky's avatar Cliff Smolinsky Committed by Commit Bot

Native Library improvements

This change adds LoadSystemLibrary to NativeLibrary to allow usage of
the NativeLibrary abstraction when loading system dlls and prevent dll
loading attacks. This change also moves ScopedNativeLibrary to base off
of ScopedGeneric so that the scoped class gets all the benefits of
ScopedGeneric without having to re-invent.

Bug: 1551709
Change-Id: I17a14678687bd7d167ab0dad1cff5dbce53c3313
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1560294
Commit-Queue: Cliff Smolinsky <cliffsmo@microsoft.com>
Reviewed-by: default avatarAntoine Labour <piman@chromium.org>
Reviewed-by: default avatarJoe Downing <joedow@chromium.org>
Reviewed-by: default avatarXiaohan Wang <xhwang@chromium.org>
Reviewed-by: default avatarBrandon Jones <bajones@chromium.org>
Reviewed-by: default avatarGreg Thompson <grt@chromium.org>
Reviewed-by: default avatarFrançois Doray <fdoray@chromium.org>
Cr-Commit-Position: refs/heads/master@{#650522}
parent 8526a5b8
......@@ -11,6 +11,7 @@
#include <string>
#include "base/base_export.h"
#include "base/files/file_path.h"
#include "base/strings/string_piece.h"
#include "build/build_config.h"
......@@ -22,8 +23,6 @@
namespace base {
class FilePath;
#if defined(OS_WIN)
using NativeLibrary = HMODULE;
#elif defined(OS_MACOSX)
......@@ -83,6 +82,15 @@ struct BASE_EXPORT NativeLibraryOptions {
BASE_EXPORT NativeLibrary LoadNativeLibrary(const FilePath& library_path,
NativeLibraryLoadError* error);
#if defined(OS_WIN)
// Loads a native library from the system directory using the appropriate flags.
// The function first checks to see if the library is already loaded and will
// get a handle if so. Blocking may occur if the library is not loaded and
// LoadLibrary must be called.
BASE_EXPORT NativeLibrary LoadSystemLibrary(FilePath::StringPieceType name,
NativeLibraryLoadError* error);
#endif
// Loads a native library from disk. Release it with UnloadNativeLibrary when
// you're done. Returns NULL on failure.
// If |error| is not NULL, it may be filled in on load error.
......
......@@ -8,6 +8,9 @@
#include "base/files/file_util.h"
#include "base/metrics/histogram_macros.h"
#include "base/optional.h"
#include "base/path_service.h"
#include "base/scoped_native_library.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
......@@ -15,9 +18,11 @@
namespace base {
using AddDllDirectory = HMODULE (*)(PCWSTR new_directory);
namespace {
// forward declare
HMODULE AddDllDirectory(PCWSTR new_directory);
// This enum is used to back an UMA histogram, and should therefore be treated
// as append-only.
enum LoadLibraryResult {
......@@ -56,8 +61,9 @@ bool AreSearchFlagsAvailable() {
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms684179(v=vs.85).aspx
// The LOAD_LIBRARY_SEARCH_* flags are used in the LoadNativeLibraryHelper
// method.
auto add_dll_dir_func = reinterpret_cast<AddDllDirectory>(
GetProcAddress(GetModuleHandle(L"kernel32.dll"), "AddDllDirectory"));
static const auto add_dll_dir_func =
reinterpret_cast<decltype(AddDllDirectory)*>(
GetProcAddress(GetModuleHandle(L"kernel32.dll"), "AddDllDirectory"));
return !!add_dll_dir_func;
}
......@@ -81,7 +87,8 @@ LoadLibraryResult GetLoadLibraryResult(bool are_search_flags_available,
NativeLibrary LoadNativeLibraryHelper(const FilePath& library_path,
NativeLibraryLoadError* error) {
// LoadLibrary() opens the file off disk.
// LoadLibrary() opens the file off disk and acquires the LoaderLock, hence
// must not be called from DllMain.
ScopedBlockingCall scoped_blocking_call(FROM_HERE, BlockingType::MAY_BLOCK);
HMODULE module = nullptr;
......@@ -137,6 +144,40 @@ NativeLibrary LoadNativeLibraryHelper(const FilePath& library_path,
return module;
}
NativeLibrary LoadSystemLibraryHelper(const FilePath& library_path,
NativeLibraryLoadError* error) {
NativeLibrary module;
BOOL module_found =
::GetModuleHandleEx(0, as_wcstr(library_path.value()), &module);
if (!module_found) {
// LoadLibrary() opens the file off disk and acquires the LoaderLock, hence
// must not be called from DllMain.
ScopedBlockingCall scoped_blocking_call(FROM_HERE, BlockingType::MAY_BLOCK);
bool are_search_flags_available = AreSearchFlagsAvailable();
// prefer LOAD_LIBRARY_SEARCH_SYSTEM32 to avoid DLL preloading attacks
DWORD flags = are_search_flags_available ? LOAD_LIBRARY_SEARCH_SYSTEM32
: LOAD_WITH_ALTERED_SEARCH_PATH;
module = ::LoadLibraryExW(as_wcstr(library_path.value()), nullptr, flags);
if (!module && error)
error->code = ::GetLastError();
LogLibrarayLoadResultToUMA(
GetLoadLibraryResult(are_search_flags_available, !!module));
}
return module;
}
Optional<FilePath> GetSystemLibraryName(FilePath::StringPieceType name) {
FilePath library_path;
// Use an absolute path to load the DLL to avoid DLL preloading attacks.
if (!base::PathService::Get(base::DIR_SYSTEM, &library_path))
return base::nullopt;
return make_optional(library_path.Append(name));
}
} // namespace
std::string NativeLibraryLoadError::ToString() const {
......@@ -167,4 +208,14 @@ std::string GetLoadableModuleName(StringPiece name) {
return GetNativeLibraryName(name);
}
NativeLibrary LoadSystemLibrary(FilePath::StringPieceType name,
NativeLibraryLoadError* error) {
Optional<FilePath> library_path = GetSystemLibraryName(name);
if (library_path)
return LoadSystemLibraryHelper(library_path.value(), error);
if (error)
error->code = ERROR_NOT_FOUND;
return nullptr;
}
} // namespace base
......@@ -119,7 +119,7 @@ class ScopedGeneric {
TrackAcquire(data_.generic);
}
~ScopedGeneric() {
virtual ~ScopedGeneric() {
CHECK(!receiving_) << "ScopedGeneric destroyed with active receiver";
FreeIfNecessary();
}
......
......@@ -6,38 +6,35 @@
namespace base {
ScopedNativeLibrary::ScopedNativeLibrary() : library_(nullptr) {}
void NativeLibraryTraits::Free(NativeLibrary library) {
UnloadNativeLibrary(library);
}
using BaseClass = ScopedGeneric<NativeLibrary, NativeLibraryTraits>;
ScopedNativeLibrary::ScopedNativeLibrary() : BaseClass(), error_() {}
ScopedNativeLibrary::~ScopedNativeLibrary() = default;
ScopedNativeLibrary::ScopedNativeLibrary(NativeLibrary library)
: library_(library) {
}
: BaseClass(library), error_() {}
ScopedNativeLibrary::ScopedNativeLibrary(const FilePath& library_path) {
library_ = base::LoadNativeLibrary(library_path, nullptr);
ScopedNativeLibrary::ScopedNativeLibrary(const FilePath& library_path)
: ScopedNativeLibrary() {
reset(LoadNativeLibrary(library_path, &error_));
}
ScopedNativeLibrary::~ScopedNativeLibrary() {
if (library_)
base::UnloadNativeLibrary(library_);
}
ScopedNativeLibrary::ScopedNativeLibrary(ScopedNativeLibrary&& scoped_library)
: BaseClass(scoped_library.release()), error_() {}
void* ScopedNativeLibrary::GetFunctionPointer(
const char* function_name) const {
if (!library_)
void* ScopedNativeLibrary::GetFunctionPointer(const char* function_name) const {
if (!is_valid())
return nullptr;
return base::GetFunctionPointerFromNativeLibrary(library_, function_name);
}
void ScopedNativeLibrary::Reset(NativeLibrary library) {
if (library_)
base::UnloadNativeLibrary(library_);
library_ = library;
return GetFunctionPointerFromNativeLibrary(get(), function_name);
}
NativeLibrary ScopedNativeLibrary::Release() {
NativeLibrary result = library_;
library_ = nullptr;
return result;
const NativeLibraryLoadError* ScopedNativeLibrary::GetError() const {
return &error_;
}
} // namespace base
......@@ -8,44 +8,53 @@
#include "base/base_export.h"
#include "base/macros.h"
#include "base/native_library.h"
#include "base/scoped_generic.h"
namespace base {
class FilePath;
struct BASE_EXPORT NativeLibraryTraits {
// It's assumed that this is a fast inline function with little-to-no
// penalty for duplicate calls. This must be a static function even
// for stateful traits.
static NativeLibrary InvalidValue() { return nullptr; }
// This free function will not be called if library == InvalidValue()!
static void Free(NativeLibrary library);
};
// A class which encapsulates a base::NativeLibrary object available only in a
// scope.
// This class automatically unloads the loaded library in its destructor.
class BASE_EXPORT ScopedNativeLibrary {
class BASE_EXPORT ScopedNativeLibrary
: public ScopedGeneric<NativeLibrary, NativeLibraryTraits> {
public:
// Initializes with a NULL library.
ScopedNativeLibrary();
~ScopedNativeLibrary() override;
// Takes ownership of the given library handle.
explicit ScopedNativeLibrary(NativeLibrary library);
// Opens the given library and manages its lifetime.
explicit ScopedNativeLibrary(const FilePath& library_path);
~ScopedNativeLibrary();
// Returns true if there's a valid library loaded.
bool is_valid() const { return !!library_; }
// Move constructor. Takes ownership of handle stored in |scoped_library|
ScopedNativeLibrary(ScopedNativeLibrary&& scoped_library);
NativeLibrary get() const { return library_; }
// Move assignment operator. Takes ownership of handle stored in
// |scoped_library|.
ScopedNativeLibrary& operator=(ScopedNativeLibrary&& scoped_library) =
default;
void* GetFunctionPointer(const char* function_name) const;
// Takes ownership of the given library handle. Any existing handle will
// be freed.
void Reset(NativeLibrary library);
// Returns the native library handle and removes it from this object. The
// caller must manage the lifetime of the handle.
NativeLibrary Release();
const NativeLibraryLoadError* GetError() const;
private:
NativeLibrary library_;
NativeLibraryLoadError error_;
DISALLOW_COPY_AND_ASSIGN(ScopedNativeLibrary);
};
......
......@@ -126,10 +126,8 @@ TEST(SafeBrowsingEnvironmentDataCollectionWinTest, RecordLspFeature) {
TEST(SafeBrowsingEnvironmentDataCollectionWinTest, VerifyLoadedModules) {
// Load the test modules.
std::vector<base::ScopedNativeLibrary> test_dlls(kTestDllNamesCount);
for (size_t i = 0; i < kTestDllNamesCount; ++i) {
test_dlls[i].Reset(
LoadNativeLibrary(base::FilePath(kTestDllNames[i]), NULL));
}
for (size_t i = 0; i < kTestDllNamesCount; ++i)
test_dlls[i] = base::ScopedNativeLibrary(base::FilePath(kTestDllNames[i]));
// Edit the first byte of the function exported by the first module. Calling
// GetModuleHandle so we do not increment the library ref count.
......
......@@ -95,7 +95,7 @@ class SafeBrowsingModuleVerifierWinTest : public testing::Test {
LoadNativeLibrary(base::FilePath(kTestDllNames[0]), NULL);
ASSERT_NE(static_cast<HMODULE>(NULL), mem_dll_handle)
<< "GLE=" << GetLastError();
mem_dll_handle_.Reset(mem_dll_handle);
mem_dll_handle_ = base::ScopedNativeLibrary(mem_dll_handle);
ASSERT_TRUE(mem_dll_handle_.is_valid());
}
......
......@@ -311,27 +311,26 @@ void PpapiThread::OnLoadPlugin(const base::FilePath& path,
base::ScopedNativeLibrary library;
if (!plugin_entry_points_.initialize_module) {
// Load the plugin from the specified library.
base::NativeLibraryLoadError error;
base::TimeDelta load_time;
{
TRACE_EVENT1("ppapi", "PpapiThread::LoadPlugin", "path",
path.MaybeAsASCII());
base::TimeTicks start = base::TimeTicks::Now();
library.Reset(base::LoadNativeLibrary(path, &error));
library = base::ScopedNativeLibrary(path);
load_time = base::TimeTicks::Now() - start;
}
if (!library.is_valid()) {
LOG(ERROR) << "Failed to load Pepper module from " << path.value()
<< " (error: " << error.ToString() << ")";
<< " (error: " << library.GetError()->ToString() << ")";
if (!base::PathExists(path)) {
ReportLoadResult(path, FILE_MISSING);
return;
}
ReportLoadResult(path, LOAD_FAILED);
// Report detailed reason for load failure.
ReportLoadErrorCode(path, error);
ReportLoadErrorCode(path, library.GetError());
return;
}
......@@ -452,7 +451,7 @@ void PpapiThread::OnLoadPlugin(const base::FilePath& path,
}
// Initialization succeeded, so keep the plugin DLL loaded.
library_.Reset(library.Release());
library_ = std::move(library);
ReportLoadResult(path, LOAD_SUCCESS);
}
......@@ -571,12 +570,12 @@ void PpapiThread::ReportLoadResult(const base::FilePath& path,
void PpapiThread::ReportLoadErrorCode(
const base::FilePath& path,
const base::NativeLibraryLoadError& error) {
const base::NativeLibraryLoadError* error) {
// Only report load error code on Windows because that's the only platform that
// has a numerical error value.
#if defined(OS_WIN)
base::UmaHistogramSparse(GetHistogramName(is_broker_, "LoadErrorCode", path),
error.code);
error->code);
#endif
}
......
......@@ -138,7 +138,7 @@ class PpapiThread : public ChildThreadImpl,
// Reports |error| to UMA when plugin load fails.
void ReportLoadErrorCode(const base::FilePath& path,
const base::NativeLibraryLoadError& error);
const base::NativeLibraryLoadError* error);
// Reports time to load the plugin.
void ReportLoadTime(const base::FilePath& path,
......
......@@ -94,8 +94,7 @@ GamepadSource GamepadPlatformDataFetcherWin::source() {
}
void GamepadPlatformDataFetcherWin::OnAddedToProvider() {
xinput_dll_.Reset(
base::LoadNativeLibrary(base::FilePath(XInputDllFileName()), nullptr));
xinput_dll_ = base::ScopedNativeLibrary(base::FilePath(XInputDllFileName()));
xinput_available_ = GetXInputDllFunctions();
}
......
......@@ -61,11 +61,11 @@ void ReportLoadResult(LoadResult load_result) {
LoadResult::kLoadResultCount);
}
void ReportLoadErrorCode(const base::NativeLibraryLoadError& error) {
void ReportLoadErrorCode(const base::NativeLibraryLoadError* error) {
// Only report load error code on Windows because that's the only platform that
// has a numerical error value.
#if defined(OS_WIN)
base::UmaHistogramSparse("Media.EME.CdmLoadErrorCode", error.code);
base::UmaHistogramSparse("Media.EME.CdmLoadErrorCode", error->code);
#endif
}
......@@ -128,16 +128,15 @@ bool CdmModule::Initialize(const base::FilePath& cdm_path) {
cdm_path_ = cdm_path;
// Load the CDM.
base::NativeLibraryLoadError error;
base::TimeTicks start = base::TimeTicks::Now();
library_.Reset(base::LoadNativeLibrary(cdm_path, &error));
library_ = base::ScopedNativeLibrary(cdm_path);
base::TimeDelta load_time = base::TimeTicks::Now() - start;
if (!library_.is_valid()) {
LOG(ERROR) << "CDM at " << cdm_path.value() << " could not be loaded.";
LOG(ERROR) << "Error: " << error.ToString();
LOG(ERROR) << "Error: " << library_.GetError()->ToString();
ReportLoadResult(base::PathExists(cdm_path) ? LoadResult::kLoadFailed
: LoadResult::kFileMissing);
ReportLoadErrorCode(error);
ReportLoadErrorCode(library_.GetError());
return false;
}
......@@ -162,7 +161,7 @@ bool CdmModule::Initialize(const base::FilePath& cdm_path) {
deinitialize_cdm_module_func_ = nullptr;
create_cdm_func_ = nullptr;
get_cdm_version_func_ = nullptr;
library_.Release();
library_.reset();
ReportLoadResult(LoadResult::kEntryPointMissing);
return false;
}
......
......@@ -41,9 +41,8 @@ void ExternalClearKeyTestHelper::LoadLibrary() {
ASSERT_TRUE(base::PathExists(library_path_)) << library_path_.value();
// Now load the CDM library.
base::NativeLibraryLoadError error;
library_.Reset(base::LoadNativeLibrary(library_path_, &error));
ASSERT_TRUE(library_.is_valid()) << error.ToString();
library_ = base::ScopedNativeLibrary(library_path_);
ASSERT_TRUE(library_.is_valid()) << library_.GetError()->ToString();
// Call INITIALIZE_CDM_MODULE()
typedef void (*InitializeCdmFunc)();
......
......@@ -133,7 +133,7 @@ std::unique_ptr<TouchInjectorWinDelegate> TouchInjectorWinDelegate::Create() {
}
return std::unique_ptr<TouchInjectorWinDelegate>(new TouchInjectorWinDelegate(
library.Release(), init_func, inject_touch_func));
library.release(), init_func, inject_touch_func));
}
TouchInjectorWinDelegate::TouchInjectorWinDelegate(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment