Commit 1cceb257 authored by Andres Pico's avatar Andres Pico Committed by Commit Bot

Ability to Prevent Premature COM Uninitialization

This cl adds the capability to prevent the premature uninitialization
of the COM library in ScopedCOMInitializer. Premature uninitialization
usually occurs in the presence of unbalanced CoInitialize/CoUnitialize
pairs. While we can prevent this from ocurring in first party-code,
there is no mechanism that protects us when executing third-party code
in a COM enabled thread such as in the case of the Quarantine process.

Bug: 1075487
Change-Id: Ibb3cf304c6bbabc126867de47e963a52c9409248
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2378270Reviewed-by: default avatarBruce Dawson <brucedawson@chromium.org>
Reviewed-by: default avatarAsanka Herath <asanka@chromium.org>
Reviewed-by: default avatarGreg Thompson <grt@chromium.org>
Commit-Queue: Andres Pico <anpico@microsoft.com>
Cr-Commit-Position: refs/heads/master@{#804589}
parent 77109ca5
......@@ -955,6 +955,8 @@ component("base") {
"win/atl.h",
"win/atl_throw.cc",
"win/atl_throw.h",
"win/com_init_balancer.cc",
"win/com_init_balancer.h",
"win/com_init_check_hook.cc",
"win/com_init_check_hook.h",
"win/com_init_util.cc",
......@@ -2988,6 +2990,7 @@ test("base_unittests") {
"threading/platform_thread_win_unittest.cc",
"time/time_win_unittest.cc",
"win/async_operation_unittest.cc",
"win/com_init_balancer_unittest.cc",
"win/com_init_check_hook_unittest.cc",
"win/com_init_util_unittest.cc",
"win/core_winrt_util_unittest.cc",
......
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <objbase.h>
#include "base/check_op.h"
#include "base/win/com_init_balancer.h"
namespace base {
namespace win {
namespace internal {
ComInitBalancer::ComInitBalancer(DWORD co_init) : co_init_(co_init) {
ULARGE_INTEGER spy_cookie = {};
HRESULT hr = ::CoRegisterInitializeSpy(this, &spy_cookie);
if (SUCCEEDED(hr))
spy_cookie_ = spy_cookie;
}
ComInitBalancer::~ComInitBalancer() {
DCHECK(!spy_cookie_.has_value());
}
void ComInitBalancer::Disable() {
if (spy_cookie_.has_value()) {
::CoRevokeInitializeSpy(spy_cookie_.value());
reference_count_ = 0;
spy_cookie_.reset();
}
}
DWORD ComInitBalancer::GetReferenceCountForTesting() const {
return reference_count_;
}
IFACEMETHODIMP
ComInitBalancer::PreInitialize(DWORD apartment_type, DWORD reference_count) {
return S_OK;
}
IFACEMETHODIMP
ComInitBalancer::PostInitialize(HRESULT result,
DWORD apartment_type,
DWORD new_reference_count) {
reference_count_ = new_reference_count;
return result;
}
IFACEMETHODIMP
ComInitBalancer::PreUninitialize(DWORD reference_count) {
if (reference_count == 1 && spy_cookie_.has_value()) {
// Increase the reference count to prevent premature and unbalanced
// uninitalization of the COM library.
::CoInitializeEx(nullptr, co_init_);
}
return S_OK;
}
IFACEMETHODIMP
ComInitBalancer::PostUninitialize(DWORD new_reference_count) {
reference_count_ = new_reference_count;
return S_OK;
}
} // namespace internal
} // namespace win
} // namespace base
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef BASE_WIN_COM_INIT_BALANCER_H_
#define BASE_WIN_COM_INIT_BALANCER_H_
#include <objidl.h>
#include <winnt.h>
#include <wrl/implements.h>
#include "base/base_export.h"
#include "base/optional.h"
#include "base/threading/thread_checker.h"
#include "base/win/windows_types.h"
namespace base {
namespace win {
namespace internal {
// Implementation class of the IInitializeSpy Interface that prevents premature
// uninitialization of the COM library, often caused by unbalanced
// CoInitialize/CoUninitialize pairs. The use of this class is encouraged in
// COM-supporting threads that execute third-party code.
//
// Disable() must be called before uninitializing the COM library in order to
// revoke the registered spy and allow for the successful uninitialization of
// the COM library.
class BASE_EXPORT ComInitBalancer
: public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
IInitializeSpy> {
public:
// Constructs a COM initialize balancer. |co_init| defines the apartment's
// concurrency model used by the balancer.
explicit ComInitBalancer(DWORD co_init);
ComInitBalancer(const ComInitBalancer&) = delete;
ComInitBalancer& operator=(const ComInitBalancer&) = delete;
~ComInitBalancer() override;
// Disables balancer by revoking the registered spy and consequently
// unblocking attempts to uninitialize the COM library.
void Disable();
DWORD GetReferenceCountForTesting() const;
private:
// IInitializeSpy:
IFACEMETHODIMP PreInitialize(DWORD apartment_type,
DWORD reference_count) override;
IFACEMETHODIMP PostInitialize(HRESULT result,
DWORD apartment_type,
DWORD new_reference_count) override;
IFACEMETHODIMP PreUninitialize(DWORD reference_count) override;
IFACEMETHODIMP PostUninitialize(DWORD new_reference_count) override;
const DWORD co_init_;
// The current apartment reference count set after the completion of the last
// call made to CoInitialize or CoUninitialize.
DWORD reference_count_ = 0;
base::Optional<ULARGE_INTEGER> spy_cookie_;
THREAD_CHECKER(thread_checker_);
};
} // namespace internal
} // namespace win
} // namespace base
#endif // BASE_WIN_COM_INIT_BALANCER_H_
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "base/win/com_init_balancer.h"
#include <shlobj.h>
#include <wrl/client.h>
#include "base/test/gtest_util.h"
#include "base/win/com_init_util.h"
#include "base/win/scoped_com_initializer.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace base {
namespace win {
using Microsoft::WRL::ComPtr;
TEST(TestComInitBalancer, BalancedPairsWithComBalancerEnabled) {
{
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer.Succeeded());
// Create COM object successfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(SUCCEEDED(hr));
}
// ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());
}
TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerEnabled) {
{
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer.Succeeded());
// Attempt to prematurely uninitialize the COM library.
::CoUninitialize();
::CoUninitialize();
// Assert COM is still initialized.
AssertComInitialized();
// Create COM object successfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(SUCCEEDED(hr));
}
// ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());
}
TEST(TestComInitBalancer, BalancedPairsWithComBalancerDisabled) {
{
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kAllow);
ASSERT_TRUE(com_initializer.Succeeded());
// Create COM object successfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(SUCCEEDED(hr));
}
// ScopedCOMInitializer has gone out of scope and COM has been uninitialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());
}
TEST(TestComInitBalancer, UnbalancedPairsWithComBalancerDisabled) {
// Assert COM has initialized correctly.
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kAllow);
ASSERT_TRUE(com_initializer.Succeeded());
// Attempt to prematurely uninitialize the COM library.
::CoUninitialize();
::CoUninitialize();
// Assert COM is not initialized.
EXPECT_DCHECK_DEATH(AssertComInitialized());
// Create COM object unsuccessfully.
ComPtr<IUnknown> shell_link;
HRESULT hr = ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
IID_PPV_ARGS(&shell_link));
EXPECT_TRUE(FAILED(hr));
EXPECT_EQ(CO_E_NOTINITIALIZED, hr);
}
TEST(TestComInitBalancer, OneRegisteredSpyRefCount) {
ScopedCOMInitializer com_initializer(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer.Succeeded());
// Reference count should be 1 after initialization.
EXPECT_EQ(DWORD(1), com_initializer.GetCOMBalancerReferenceCountForTesting());
// Attempt to prematurely uninitialize the COM library.
::CoUninitialize();
// Expect reference count to remain at 1.
EXPECT_EQ(DWORD(1), com_initializer.GetCOMBalancerReferenceCountForTesting());
}
TEST(TestComInitBalancer, ThreeRegisteredSpiesRefCount) {
ScopedCOMInitializer com_initializer_1(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ScopedCOMInitializer com_initializer_2(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ScopedCOMInitializer com_initializer_3(
ScopedCOMInitializer::Uninitialization::kBlockPremature);
ASSERT_TRUE(com_initializer_1.Succeeded());
ASSERT_TRUE(com_initializer_2.Succeeded());
ASSERT_TRUE(com_initializer_3.Succeeded());
// Reference count should be 3 after initialization.
EXPECT_EQ(DWORD(3),
com_initializer_1.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(3),
com_initializer_2.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(3),
com_initializer_3.GetCOMBalancerReferenceCountForTesting());
// Attempt to prematurely uninitialize the COM library.
::CoUninitialize(); // Reference count -> 2.
::CoUninitialize(); // Reference count -> 1.
::CoUninitialize();
// Expect reference count to remain at 1.
EXPECT_EQ(DWORD(1),
com_initializer_1.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(1),
com_initializer_2.GetCOMBalancerReferenceCountForTesting());
EXPECT_EQ(DWORD(1),
com_initializer_3.GetCOMBalancerReferenceCountForTesting());
}
} // namespace win
} // namespace base
......@@ -4,34 +4,53 @@
#include "base/win/scoped_com_initializer.h"
#include <wrl/implements.h>
#include "base/check_op.h"
namespace base {
namespace win {
ScopedCOMInitializer::ScopedCOMInitializer() {
Initialize(COINIT_APARTMENTTHREADED);
ScopedCOMInitializer::ScopedCOMInitializer(Uninitialization uninitialization) {
Initialize(COINIT_APARTMENTTHREADED, uninitialization);
}
ScopedCOMInitializer::ScopedCOMInitializer(SelectMTA mta) {
Initialize(COINIT_MULTITHREADED);
ScopedCOMInitializer::ScopedCOMInitializer(SelectMTA mta,
Uninitialization uninitialization) {
Initialize(COINIT_MULTITHREADED, uninitialization);
}
ScopedCOMInitializer::~ScopedCOMInitializer() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
if (Succeeded())
if (Succeeded()) {
if (com_balancer_) {
com_balancer_->Disable();
com_balancer_.Reset();
}
CoUninitialize();
}
}
bool ScopedCOMInitializer::Succeeded() const {
return SUCCEEDED(hr_);
}
void ScopedCOMInitializer::Initialize(COINIT init) {
DWORD ScopedCOMInitializer::GetCOMBalancerReferenceCountForTesting() const {
if (com_balancer_)
return com_balancer_->GetReferenceCountForTesting();
return 0;
}
void ScopedCOMInitializer::Initialize(COINIT init,
Uninitialization uninitialization) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
// COINIT_DISABLE_OLE1DDE is always added based on:
// https://docs.microsoft.com/en-us/windows/desktop/learnwin32/initializing-the-com-library
hr_ = CoInitializeEx(nullptr, init | COINIT_DISABLE_OLE1DDE);
if (uninitialization == Uninitialization::kBlockPremature) {
com_balancer_ = Microsoft::WRL::Details::Make<internal::ComInitBalancer>(
init | COINIT_DISABLE_OLE1DDE);
}
hr_ = ::CoInitializeEx(nullptr, init | COINIT_DISABLE_OLE1DDE);
DCHECK_NE(RPC_E_CHANGED_MODE, hr_) << "Invalid COM thread model change";
}
......
......@@ -6,10 +6,12 @@
#define BASE_WIN_SCOPED_COM_INITIALIZER_H_
#include <objbase.h>
#include <wrl/client.h>
#include "base/base_export.h"
#include "base/macros.h"
#include "base/threading/thread_checker.h"
#include "base/win/com_init_balancer.h"
#include "base/win/scoped_windows_thread_environment.h"
namespace base {
......@@ -18,6 +20,10 @@ namespace win {
// Initializes COM in the constructor (STA or MTA), and uninitializes COM in the
// destructor.
//
// It is strongly encouraged to block premature uninitialization of the COM
// libraries in threads that execute third-party code, as a way to protect
// against unbalanced CoInitialize/CoUninitialize pairs.
//
// WARNING: This should only be used once per thread, ideally scoped to a
// similar lifetime as the thread itself. You should not be using this in
// random utility functions that make COM calls -- instead ensure these
......@@ -27,21 +33,39 @@ class BASE_EXPORT ScopedCOMInitializer : public ScopedWindowsThreadEnvironment {
// Enum value provided to initialize the thread as an MTA instead of STA.
enum SelectMTA { kMTA };
// Constructor for STA initialization.
ScopedCOMInitializer();
// Enum values which enumerates uninitialization modes for the COM library.
enum class Uninitialization {
// Default value. Used in threads where no third-party code is executed.
kAllow,
// Blocks premature uninitialization of the COM libraries before going out
// of scope. Used in threads where third-party code is executed.
kBlockPremature,
};
// Constructor for MTA initialization.
explicit ScopedCOMInitializer(SelectMTA mta);
// Constructors for STA initialization.
explicit ScopedCOMInitializer(
Uninitialization uninitialization = Uninitialization::kAllow);
// Constructors for MTA initialization.
explicit ScopedCOMInitializer(
SelectMTA mta,
Uninitialization uninitialization = Uninitialization::kAllow);
~ScopedCOMInitializer() override;
// ScopedWindowsThreadEnvironment:
bool Succeeded() const override;
// Used for testing. Returns the COM balancer's apartment thread ref count.
DWORD GetCOMBalancerReferenceCountForTesting() const;
private:
void Initialize(COINIT init);
void Initialize(COINIT init, Uninitialization uninitialization);
HRESULT hr_;
HRESULT hr_ = S_OK;
Microsoft::WRL::ComPtr<internal::ComInitBalancer> com_balancer_;
THREAD_CHECKER(thread_checker_);
DISALLOW_COPY_AND_ASSIGN(ScopedCOMInitializer);
......
......@@ -13,7 +13,6 @@
#include "components/services/quarantine/quarantine.h"
#if defined(OS_WIN)
#include "base/win/scoped_com_initializer.h"
#include "components/services/quarantine/public/cpp/quarantine_features_win.h"
#endif // OS_WIN
......@@ -53,8 +52,6 @@ void QuarantineImpl::QuarantineFile(
if (base::FeatureList::IsEnabled(quarantine::kOutOfProcessQuarantine)) {
// In out of process case, we are running in a utility process,
// so directly call QuarantineFile and send the result.
base::win::ScopedCOMInitializer com_initializer;
QuarantineFileResult result = quarantine::QuarantineFile(
full_path, source_url, referrer_url, client_guid);
......
......@@ -7,10 +7,15 @@
#include <memory>
#include "build/build_config.h"
#include "components/services/quarantine/public/mojom/quarantine.mojom.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#if defined(OS_WIN)
#include "base/win/scoped_com_initializer.h"
#endif // OS_WIN
namespace quarantine {
class QuarantineImpl : public mojom::Quarantine {
......@@ -30,6 +35,11 @@ class QuarantineImpl : public mojom::Quarantine {
private:
mojo::Receiver<mojom::Quarantine> receiver_{this};
#if defined(OS_WIN)
base::win::ScopedCOMInitializer com_initializer_{
base::win::ScopedCOMInitializer::Uninitialization::kBlockPremature};
#endif // OS_WIN
DISALLOW_COPY_AND_ASSIGN(QuarantineImpl);
};
......
......@@ -14,6 +14,7 @@
#include "base/test/scoped_feature_list.h"
#include "base/test/test_file_util.h"
#include "base/test/test_reg_util_win.h"
#include "base/win/scoped_com_initializer.h"
#include "base/win/win_util.h"
#include "base/win/windows_version.h"
#include "components/services/quarantine/public/cpp/quarantine_features_win.h"
......@@ -158,6 +159,9 @@ class QuarantineWinTest : public ::testing::Test {
base::ScopedTempDir scoped_temp_dir_;
base::win::ScopedCOMInitializer com_initializer_{
base::win::ScopedCOMInitializer::Uninitialization::kBlockPremature};
// Due to caching, these sites zone must be set for all tests, so that the
// order the tests are run does not matter.
std::unique_ptr<ScopedZoneForSite> scoped_zone_for_trusted_site_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment