Commit 8ef30a35 authored by jrummell's avatar jrummell Committed by Commit bot

[eme] Reject CDM calls after connection error

Once the mojo connection is broken, all subsequent calls to the CDM
should fail.

BUG=671362
TEST=new tests pass

Review-Url: https://codereview.chromium.org/2561263002
Cr-Commit-Position: refs/heads/master@{#439032}
parent 550cea9c
...@@ -7,9 +7,12 @@ ...@@ -7,9 +7,12 @@
#include "base/logging.h" #include "base/logging.h"
using ::testing::_; using ::testing::_;
using ::testing::Invoke;
using ::testing::NotNull;
using ::testing::Return; using ::testing::Return;
using ::testing::SaveArg;
MATCHER(NotEmpty, "") {
return !arg.empty();
}
namespace media { namespace media {
...@@ -146,6 +149,37 @@ void MockCdmContext::set_cdm_id(int cdm_id) { ...@@ -146,6 +149,37 @@ void MockCdmContext::set_cdm_id(int cdm_id) {
cdm_id_ = cdm_id; cdm_id_ = cdm_id;
} }
MockCdmPromise::MockCdmPromise(bool expect_success) {
if (expect_success) {
EXPECT_CALL(*this, resolve());
EXPECT_CALL(*this, reject(_, _, _)).Times(0);
} else {
EXPECT_CALL(*this, resolve()).Times(0);
EXPECT_CALL(*this, reject(_, _, NotEmpty()));
}
}
MockCdmPromise::~MockCdmPromise() {
// The EXPECT calls will verify that the promise is in fact fulfilled.
MarkPromiseSettled();
}
MockCdmSessionPromise::MockCdmSessionPromise(bool expect_success,
std::string* new_session_id) {
if (expect_success) {
EXPECT_CALL(*this, resolve(_)).WillOnce(SaveArg<0>(new_session_id));
EXPECT_CALL(*this, reject(_, _, _)).Times(0);
} else {
EXPECT_CALL(*this, resolve(_)).Times(0);
EXPECT_CALL(*this, reject(_, _, NotEmpty()));
}
}
MockCdmSessionPromise::~MockCdmSessionPromise() {
// The EXPECT calls will verify that the promise is in fact fulfilled.
MarkPromiseSettled();
}
MockStreamParser::MockStreamParser() {} MockStreamParser::MockStreamParser() {}
MockStreamParser::~MockStreamParser() {} MockStreamParser::~MockStreamParser() {}
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "media/base/audio_renderer.h" #include "media/base/audio_renderer.h"
#include "media/base/cdm_context.h" #include "media/base/cdm_context.h"
#include "media/base/cdm_key_information.h" #include "media/base/cdm_key_information.h"
#include "media/base/cdm_promise.h"
#include "media/base/content_decryption_module.h" #include "media/base/content_decryption_module.h"
#include "media/base/decoder_buffer.h" #include "media/base/decoder_buffer.h"
#include "media/base/decryptor.h" #include "media/base/decryptor.h"
...@@ -414,6 +415,37 @@ class MockCdmContext : public CdmContext { ...@@ -414,6 +415,37 @@ class MockCdmContext : public CdmContext {
DISALLOW_COPY_AND_ASSIGN(MockCdmContext); DISALLOW_COPY_AND_ASSIGN(MockCdmContext);
}; };
class MockCdmPromise : public SimpleCdmPromise {
public:
// |expect_success| is true if resolve() should be called, false if reject()
// is expected.
explicit MockCdmPromise(bool expect_success);
~MockCdmPromise() override;
MOCK_METHOD0(resolve, void());
MOCK_METHOD3(reject,
void(CdmPromise::Exception, uint32_t, const std::string&));
private:
DISALLOW_COPY_AND_ASSIGN(MockCdmPromise);
};
class MockCdmSessionPromise : public NewSessionCdmPromise {
public:
// |expect_success| is true if resolve() should be called, false if reject()
// is expected. |new_session_id| is updated with the new session's ID on
// resolve().
MockCdmSessionPromise(bool expect_success, std::string* new_session_id);
~MockCdmSessionPromise() override;
MOCK_METHOD1(resolve, void(const std::string&));
MOCK_METHOD3(reject,
void(CdmPromise::Exception, uint32_t, const std::string&));
private:
DISALLOW_COPY_AND_ASSIGN(MockCdmSessionPromise);
};
class MockStreamParser : public StreamParser { class MockStreamParser : public StreamParser {
public: public:
MockStreamParser(); MockStreamParser();
......
...@@ -122,14 +122,14 @@ void MojoCdm::InitializeCdm(const std::string& key_system, ...@@ -122,14 +122,14 @@ void MojoCdm::InitializeCdm(const std::string& key_system,
base::Bind(&MojoCdm::OnCdmInitialized, base::Unretained(this))); base::Bind(&MojoCdm::OnCdmInitialized, base::Unretained(this)));
} }
// TODO(xhwang): Properly handle CDM calls after connection error.
// See http://crbug.com/671362
void MojoCdm::OnConnectionError(uint32_t custom_reason, void MojoCdm::OnConnectionError(uint32_t custom_reason,
const std::string& description) { const std::string& description) {
LOG(ERROR) << "Remote CDM connection error: custom_reason=" << custom_reason LOG(ERROR) << "Remote CDM connection error: custom_reason=" << custom_reason
<< ", description=\"" << description << "\""; << ", description=\"" << description << "\"";
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
remote_cdm_.reset();
// Handle initial connection error. // Handle initial connection error.
if (pending_init_promise_) { if (pending_init_promise_) {
DCHECK(!cdm_session_tracker_.HasRemainingSessions()); DCHECK(!cdm_session_tracker_.HasRemainingSessions());
...@@ -148,6 +148,12 @@ void MojoCdm::SetServerCertificate(const std::vector<uint8_t>& certificate, ...@@ -148,6 +148,12 @@ void MojoCdm::SetServerCertificate(const std::vector<uint8_t>& certificate,
DVLOG(2) << __func__; DVLOG(2) << __func__;
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
if (!remote_cdm_) {
promise->reject(media::CdmPromise::INVALID_STATE_ERROR, 0,
"CDM connection lost.");
return;
}
remote_cdm_->SetServerCertificate( remote_cdm_->SetServerCertificate(
certificate, base::Bind(&MojoCdm::OnSimpleCdmPromiseResult, certificate, base::Bind(&MojoCdm::OnSimpleCdmPromiseResult,
base::Unretained(this), base::Passed(&promise))); base::Unretained(this), base::Passed(&promise)));
...@@ -161,6 +167,12 @@ void MojoCdm::CreateSessionAndGenerateRequest( ...@@ -161,6 +167,12 @@ void MojoCdm::CreateSessionAndGenerateRequest(
DVLOG(2) << __func__; DVLOG(2) << __func__;
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
if (!remote_cdm_) {
promise->reject(media::CdmPromise::INVALID_STATE_ERROR, 0,
"CDM connection lost.");
return;
}
remote_cdm_->CreateSessionAndGenerateRequest( remote_cdm_->CreateSessionAndGenerateRequest(
session_type, init_data_type, init_data, session_type, init_data_type, init_data,
base::Bind(&MojoCdm::OnNewSessionCdmPromiseResult, base::Unretained(this), base::Bind(&MojoCdm::OnNewSessionCdmPromiseResult, base::Unretained(this),
...@@ -173,6 +185,12 @@ void MojoCdm::LoadSession(SessionType session_type, ...@@ -173,6 +185,12 @@ void MojoCdm::LoadSession(SessionType session_type,
DVLOG(2) << __func__; DVLOG(2) << __func__;
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
if (!remote_cdm_) {
promise->reject(media::CdmPromise::INVALID_STATE_ERROR, 0,
"CDM connection lost.");
return;
}
remote_cdm_->LoadSession( remote_cdm_->LoadSession(
session_type, session_id, session_type, session_id,
base::Bind(&MojoCdm::OnNewSessionCdmPromiseResult, base::Unretained(this), base::Bind(&MojoCdm::OnNewSessionCdmPromiseResult, base::Unretained(this),
...@@ -185,6 +203,12 @@ void MojoCdm::UpdateSession(const std::string& session_id, ...@@ -185,6 +203,12 @@ void MojoCdm::UpdateSession(const std::string& session_id,
DVLOG(2) << __func__; DVLOG(2) << __func__;
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
if (!remote_cdm_) {
promise->reject(media::CdmPromise::INVALID_STATE_ERROR, 0,
"CDM connection lost.");
return;
}
remote_cdm_->UpdateSession( remote_cdm_->UpdateSession(
session_id, response, session_id, response,
base::Bind(&MojoCdm::OnSimpleCdmPromiseResult, base::Unretained(this), base::Bind(&MojoCdm::OnSimpleCdmPromiseResult, base::Unretained(this),
...@@ -196,6 +220,12 @@ void MojoCdm::CloseSession(const std::string& session_id, ...@@ -196,6 +220,12 @@ void MojoCdm::CloseSession(const std::string& session_id,
DVLOG(2) << __func__; DVLOG(2) << __func__;
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
if (!remote_cdm_) {
promise->reject(media::CdmPromise::INVALID_STATE_ERROR, 0,
"CDM connection lost.");
return;
}
remote_cdm_->CloseSession( remote_cdm_->CloseSession(
session_id, base::Bind(&MojoCdm::OnSimpleCdmPromiseResult, session_id, base::Bind(&MojoCdm::OnSimpleCdmPromiseResult,
base::Unretained(this), base::Passed(&promise))); base::Unretained(this), base::Passed(&promise)));
...@@ -206,6 +236,12 @@ void MojoCdm::RemoveSession(const std::string& session_id, ...@@ -206,6 +236,12 @@ void MojoCdm::RemoveSession(const std::string& session_id,
DVLOG(2) << __func__; DVLOG(2) << __func__;
DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(thread_checker_.CalledOnValidThread());
if (!remote_cdm_) {
promise->reject(media::CdmPromise::INVALID_STATE_ERROR, 0,
"CDM connection lost.");
return;
}
remote_cdm_->RemoveSession( remote_cdm_->RemoveSession(
session_id, base::Bind(&MojoCdm::OnSimpleCdmPromiseResult, session_id, base::Bind(&MojoCdm::OnSimpleCdmPromiseResult,
base::Unretained(this), base::Passed(&promise))); base::Unretained(this), base::Passed(&promise)));
......
...@@ -20,8 +20,13 @@ ...@@ -20,8 +20,13 @@
#include "mojo/public/cpp/bindings/interface_request.h" #include "mojo/public/cpp/bindings/interface_request.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
using ::testing::_;
using ::testing::StrictMock; using ::testing::StrictMock;
MATCHER(NotEmpty, "") {
return !arg.empty();
}
namespace media { namespace media {
namespace { namespace {
...@@ -29,11 +34,18 @@ namespace { ...@@ -29,11 +34,18 @@ namespace {
const char kClearKeyKeySystem[] = "org.w3.clearkey"; const char kClearKeyKeySystem[] = "org.w3.clearkey";
const char kTestSecurityOrigin[] = "https://www.test.com"; const char kTestSecurityOrigin[] = "https://www.test.com";
// Random key ID used to create a session.
const uint8_t kKeyId[] = {
// base64 equivalent is AQIDBAUGBwgJCgsMDQ4PEA
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
};
} // namespace } // namespace
class MojoCdmTest : public ::testing::Test { class MojoCdmTest : public ::testing::Test {
public: public:
enum ExpectedResult { SUCCESS, CONNECTION_ERROR }; enum ExpectedResult { SUCCESS, CONNECTION_ERROR, FAILURE };
MojoCdmTest() MojoCdmTest()
: mojo_cdm_service_(base::MakeUnique<MojoCdmService>( : mojo_cdm_service_(base::MakeUnique<MojoCdmService>(
...@@ -43,12 +55,14 @@ class MojoCdmTest : public ::testing::Test { ...@@ -43,12 +55,14 @@ class MojoCdmTest : public ::testing::Test {
virtual ~MojoCdmTest() {} virtual ~MojoCdmTest() {}
void Initialize(ExpectedResult expected_result) { void Initialize(const std::string& key_system,
ExpectedResult expected_result) {
mojom::ContentDecryptionModulePtr remote_cdm; mojom::ContentDecryptionModulePtr remote_cdm;
auto cdm_request = mojo::GetProxy(&remote_cdm); auto cdm_request = mojo::GetProxy(&remote_cdm);
switch (expected_result) { switch (expected_result) {
case SUCCESS: case SUCCESS:
case FAILURE:
cdm_binding_.Bind(std::move(cdm_request)); cdm_binding_.Bind(std::move(cdm_request));
break; break;
case CONNECTION_ERROR: case CONNECTION_ERROR:
...@@ -56,7 +70,7 @@ class MojoCdmTest : public ::testing::Test { ...@@ -56,7 +70,7 @@ class MojoCdmTest : public ::testing::Test {
break; break;
} }
MojoCdm::Create(kClearKeyKeySystem, GURL(kTestSecurityOrigin), CdmConfig(), MojoCdm::Create(key_system, GURL(kTestSecurityOrigin), CdmConfig(),
std::move(remote_cdm), std::move(remote_cdm),
base::Bind(&MockCdmClient::OnSessionMessage, base::Bind(&MockCdmClient::OnSessionMessage,
base::Unretained(&cdm_client_)), base::Unretained(&cdm_client_)),
...@@ -70,18 +84,13 @@ class MojoCdmTest : public ::testing::Test { ...@@ -70,18 +84,13 @@ class MojoCdmTest : public ::testing::Test {
base::Unretained(this), expected_result)); base::Unretained(this), expected_result));
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
if (expected_result == SUCCESS) {
EXPECT_TRUE(mojo_cdm_);
} else {
EXPECT_FALSE(mojo_cdm_);
}
} }
void OnCdmCreated(ExpectedResult expected_result, void OnCdmCreated(ExpectedResult expected_result,
const scoped_refptr<ContentDecryptionModule>& cdm, const scoped_refptr<ContentDecryptionModule>& cdm,
const std::string& error_message) { const std::string& error_message) {
if (!cdm) { if (!cdm) {
EXPECT_NE(SUCCESS, expected_result);
DVLOG(1) << error_message; DVLOG(1) << error_message;
return; return;
} }
...@@ -90,6 +99,57 @@ class MojoCdmTest : public ::testing::Test { ...@@ -90,6 +99,57 @@ class MojoCdmTest : public ::testing::Test {
mojo_cdm_ = cdm; mojo_cdm_ = cdm;
} }
void ForceConnectionError() {
// If there is an existing session it will get closed when the connection
// is broken.
if (!session_id_.empty()) {
EXPECT_CALL(cdm_client_, OnSessionClosed(session_id_));
}
cdm_binding_.CloseWithReason(2, "Test closed connection.");
base::RunLoop().RunUntilIdle();
}
void SetServerCertificateAndExpect(const std::vector<uint8_t>& certificate,
ExpectedResult expected_result) {
mojo_cdm_->SetServerCertificate(
certificate,
base::MakeUnique<MockCdmPromise>(expected_result == SUCCESS));
base::RunLoop().RunUntilIdle();
}
void CreateSessionAndExpect(EmeInitDataType data_type,
const std::vector<uint8_t>& key_id,
ExpectedResult expected_result) {
if (expected_result == SUCCESS) {
EXPECT_CALL(cdm_client_, OnSessionMessage(NotEmpty(), _, _));
}
mojo_cdm_->CreateSessionAndGenerateRequest(
ContentDecryptionModule::SessionType::TEMPORARY_SESSION, data_type,
key_id, base::MakeUnique<MockCdmSessionPromise>(
expected_result == SUCCESS, &session_id_));
base::RunLoop().RunUntilIdle();
}
void CloseSessionAndExpect(ExpectedResult expected_result) {
DCHECK(!session_id_.empty()) << "CloseSessionAndExpect() must be called "
"after a successful "
"CreateSessionAndExpect()";
if (expected_result == SUCCESS) {
EXPECT_CALL(cdm_client_, OnSessionClosed(session_id_));
}
mojo_cdm_->CloseSession(session_id_, base::MakeUnique<MockCdmPromise>(
expected_result == SUCCESS));
base::RunLoop().RunUntilIdle();
}
// Fixture members. // Fixture members.
base::TestMessageLoop message_loop_; base::TestMessageLoop message_loop_;
...@@ -104,18 +164,55 @@ class MojoCdmTest : public ::testing::Test { ...@@ -104,18 +164,55 @@ class MojoCdmTest : public ::testing::Test {
mojo::Binding<mojom::ContentDecryptionModule> cdm_binding_; mojo::Binding<mojom::ContentDecryptionModule> cdm_binding_;
scoped_refptr<ContentDecryptionModule> mojo_cdm_; scoped_refptr<ContentDecryptionModule> mojo_cdm_;
// |session_id_| is the latest successful result of calling CreateSession().
std::string session_id_;
private: private:
DISALLOW_COPY_AND_ASSIGN(MojoCdmTest); DISALLOW_COPY_AND_ASSIGN(MojoCdmTest);
}; };
TEST_F(MojoCdmTest, Create_Success) { TEST_F(MojoCdmTest, Create_Success) {
Initialize(SUCCESS); Initialize(kClearKeyKeySystem, SUCCESS);
} }
TEST_F(MojoCdmTest, Create_ConnectionError) { TEST_F(MojoCdmTest, Create_ConnectionError) {
Initialize(CONNECTION_ERROR); Initialize(kClearKeyKeySystem, CONNECTION_ERROR);
}
TEST_F(MojoCdmTest, Create_Failure) {
// This fails as DefaultCdmFactory only supports Clear Key.
Initialize("org.random.cdm", FAILURE);
}
TEST_F(MojoCdmTest, SetServerCertificate_AfterConnectionError) {
Initialize(kClearKeyKeySystem, SUCCESS);
ForceConnectionError();
SetServerCertificateAndExpect({0, 1, 2}, FAILURE);
}
TEST_F(MojoCdmTest, CreateSessionAndGenerateRequest_AfterConnectionError) {
std::vector<uint8_t> key_id(kKeyId, kKeyId + arraysize(kKeyId));
Initialize(kClearKeyKeySystem, SUCCESS);
ForceConnectionError();
CreateSessionAndExpect(EmeInitDataType::WEBM, key_id, FAILURE);
} }
// TODO(xhwang): Add more test cases! TEST_F(MojoCdmTest, CloseSession_Success) {
std::vector<uint8_t> key_id(kKeyId, kKeyId + arraysize(kKeyId));
Initialize(kClearKeyKeySystem, SUCCESS);
CreateSessionAndExpect(EmeInitDataType::WEBM, key_id, SUCCESS);
CloseSessionAndExpect(SUCCESS);
}
TEST_F(MojoCdmTest, CloseSession_AfterConnectionError) {
std::vector<uint8_t> key_id(kKeyId, kKeyId + arraysize(kKeyId));
Initialize(kClearKeyKeySystem, SUCCESS);
CreateSessionAndExpect(EmeInitDataType::WEBM, key_id, SUCCESS);
ForceConnectionError();
CloseSessionAndExpect(FAILURE);
}
} // namespace media } // namespace media
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment