Commit 1bb72b7e authored by Xiaohan Wang's avatar Xiaohan Wang Committed by Commit Bot

media: Handle duplication session ID in MediaFoundationCdm

In normal cases the session ID returned by the MediaFoundation CDM
(IMFContentDecryptionModule) should never be the same. In the rare case
duplicate session IDs are returned, the current implementation will
crash.

This CL fixes the implementation and adds a unit test.

NOPRESUBMIT=true

Bug: 999747
Change-Id: I319c33517dd6452cb45ee30d3c1c02a41b0c6ccf
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2278416Reviewed-by: default avatarJohn Rummell <jrummell@chromium.org>
Commit-Queue: Xiaohan Wang <xhwang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#785817}
parent 544e0953
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "media/cdm/win/media_foundation_cdm.h" #include "media/cdm/win/media_foundation_cdm.h"
#include "base/bind.h"
#include "base/logging.h" #include "base/logging.h"
#include "media/base/bind_to_current_loop.h" #include "media/base/bind_to_current_loop.h"
#include "media/base/cdm_promise.h" #include "media/base/cdm_promise.h"
...@@ -176,10 +177,11 @@ void MediaFoundationCdm::CreateSessionAndGenerateRequest( ...@@ -176,10 +177,11 @@ void MediaFoundationCdm::CreateSessionAndGenerateRequest(
int session_token = next_session_token_++; int session_token = next_session_token_++;
// Keep a raw pointer since the |promise| will be moved to the callback. // Keep a raw pointer since the |promise| will be moved to the callback.
// Use base::Unretained() is safe because |session| is owned by |this|.
auto* raw_promise = promise.get(); auto* raw_promise = promise.get();
auto session_id_cb = base::BindOnce(&MediaFoundationCdm::OnSessionId, auto session_id_cb =
weak_factory_.GetWeakPtr(), session_token, base::BindOnce(&MediaFoundationCdm::OnSessionId, base::Unretained(this),
std::move(promise)); session_token, std::move(promise));
if (FAILED(session->GenerateRequest(init_data_type, init_data, if (FAILED(session->GenerateRequest(init_data_type, init_data,
std::move(session_id_cb)))) { std::move(session_id_cb)))) {
...@@ -285,7 +287,7 @@ bool MediaFoundationCdm::GetMediaFoundationCdmProxy( ...@@ -285,7 +287,7 @@ bool MediaFoundationCdm::GetMediaFoundationCdmProxy(
return true; return true;
} }
void MediaFoundationCdm::OnSessionId( bool MediaFoundationCdm::OnSessionId(
int session_token, int session_token,
std::unique_ptr<NewSessionCdmPromise> promise, std::unique_ptr<NewSessionCdmPromise> promise,
const std::string& session_id) { const std::string& session_id) {
...@@ -301,11 +303,12 @@ void MediaFoundationCdm::OnSessionId( ...@@ -301,11 +303,12 @@ void MediaFoundationCdm::OnSessionId(
if (session_id.empty() || sessions_.count(session_id)) { if (session_id.empty() || sessions_.count(session_id)) {
promise->reject(Exception::INVALID_STATE_ERROR, 0, promise->reject(Exception::INVALID_STATE_ERROR, 0,
"Empty or duplicate session ID"); "Empty or duplicate session ID");
return; return false;
} }
sessions_.emplace(session_id, std::move(session)); sessions_.emplace(session_id, std::move(session));
promise->resolve(session_id); promise->resolve(session_id);
return true;
} }
MediaFoundationCdmSession* MediaFoundationCdm::GetSession( MediaFoundationCdmSession* MediaFoundationCdm::GetSession(
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "base/memory/weak_ptr.h"
#include "media/base/cdm_context.h" #include "media/base/cdm_context.h"
#include "media/base/content_decryption_module.h" #include "media/base/content_decryption_module.h"
#include "media/base/media_export.h" #include "media/base/media_export.h"
...@@ -64,7 +63,8 @@ class MEDIA_EXPORT MediaFoundationCdm : public ContentDecryptionModule, ...@@ -64,7 +63,8 @@ class MEDIA_EXPORT MediaFoundationCdm : public ContentDecryptionModule,
private: private:
~MediaFoundationCdm() final; ~MediaFoundationCdm() final;
void OnSessionId(int session_token, // Returns whether the |session_id| is accepted by the |this|.
bool OnSessionId(int session_token,
std::unique_ptr<NewSessionCdmPromise> promise, std::unique_ptr<NewSessionCdmPromise> promise,
const std::string& session_id); const std::string& session_id);
...@@ -89,9 +89,6 @@ class MEDIA_EXPORT MediaFoundationCdm : public ContentDecryptionModule, ...@@ -89,9 +89,6 @@ class MEDIA_EXPORT MediaFoundationCdm : public ContentDecryptionModule,
std::map<std::string, std::unique_ptr<MediaFoundationCdmSession>> sessions_; std::map<std::string, std::unique_ptr<MediaFoundationCdmSession>> sessions_;
Microsoft::WRL::ComPtr<IMFCdmProxy> cdm_proxy_; Microsoft::WRL::ComPtr<IMFCdmProxy> cdm_proxy_;
// NOTE: Weak pointers must be invalidated before all other member variables.
base::WeakPtrFactory<MediaFoundationCdm> weak_factory_{this};
}; };
} // namespace media } // namespace media
......
...@@ -225,15 +225,17 @@ void MediaFoundationCdmSession::OnSessionMessage( ...@@ -225,15 +225,17 @@ void MediaFoundationCdmSession::OnSessionMessage(
const std::vector<uint8_t>& message) { const std::vector<uint8_t>& message) {
DVLOG_FUNC(2); DVLOG_FUNC(2);
if (session_id_.empty()) if (session_id_.empty() && !session_id_cb_) {
SetSessionId(); DLOG(ERROR) << "Unexpected session message";
return;
}
// Empty |session_id_| will be treated as failure by the caller. // If |session_id_| has not been set, set it now.
if (session_id_cb_) if (session_id_.empty() && !SetSessionId())
std::move(session_id_cb_).Run(session_id_); return;
if (!session_id_.empty()) DCHECK(!session_id_.empty());
session_message_cb_.Run(session_id_, message_type, message); session_message_cb_.Run(session_id_, message_type, message);
} }
void MediaFoundationCdmSession::OnSessionKeysChange() { void MediaFoundationCdmSession::OnSessionKeysChange() {
...@@ -265,16 +267,33 @@ void MediaFoundationCdmSession::OnSessionKeysChange() { ...@@ -265,16 +267,33 @@ void MediaFoundationCdmSession::OnSessionKeysChange() {
} }
} }
void MediaFoundationCdmSession::SetSessionId() { bool MediaFoundationCdmSession::SetSessionId() {
DCHECK(session_id_.empty()); DCHECK(session_id_.empty() && session_id_cb_);
base::win::ScopedCoMem<wchar_t> session_id; base::win::ScopedCoMem<wchar_t> session_id;
HRESULT hr = mf_cdm_session_->GetSessionId(&session_id); HRESULT hr = mf_cdm_session_->GetSessionId(&session_id);
if (FAILED(hr) || !session_id) if (FAILED(hr) || !session_id) {
return; bool success = std::move(session_id_cb_).Run("");
DCHECK(!success) << "Empty session ID should not be accepted";
return false;
}
auto session_id_str = base::UTF16ToUTF8(session_id.get());
if (session_id_str.empty()) {
bool success = std::move(session_id_cb_).Run("");
DCHECK(!success) << "Empty session ID should not be accepted";
return false;
}
bool success = std::move(session_id_cb_).Run(session_id_str);
if (!success) {
DLOG(ERROR) << "Session ID " << session_id_str << " rejected";
return false;
}
session_id_ = base::UTF16ToUTF8(session_id.get()); DVLOG_FUNC(1) << "session_id_=" << session_id_str;
DVLOG_FUNC(1) << "session_id_=" << session_id_; session_id_ = session_id_str;
return true;
} }
} // namespace media } // namespace media
...@@ -35,17 +35,24 @@ class MEDIA_EXPORT MediaFoundationCdmSession { ...@@ -35,17 +35,24 @@ class MEDIA_EXPORT MediaFoundationCdmSession {
// EME MediaKeySession methods. Returns S_OK on success, otherwise forwards // EME MediaKeySession methods. Returns S_OK on success, otherwise forwards
// the HRESULT from IMFContentDecryptionModuleSession. // the HRESULT from IMFContentDecryptionModuleSession.
// Note on GenerateRequest():
// - Returns S_OK, which has two cases: // Callback to pass the session ID to the caller. The return value indicates
// * If |session_id_| is successfully set, |session_id_cb| will be run // whether the session ID is accepted by the caller. If returns false, the
// followed by the session message. // session ID is rejected by the caller (e.g. empty of duplicate session IDs),
// * Otherwise, |session_id_cb| will be run with an empty session ID to // and |this| could be destructed immediately by the caller.
// indicate error. No session message in this case. using SessionIdCB = base::OnceCallback<bool(const std::string&)>;
// - Otherwise, no callbacks will be run.
using SessionIdCB = base::OnceCallback<void(const std::string&)>; // Creates session ID and generates requests. Returns an error HRESULT on
// immediate failure, in which case no callbacks will be run. Otherwise,
// returns S_OK, with the following two cases:
// - If |session_id_| is successfully set, |session_id_cb| will be run with
// |session_id_| followed by the session message via |session_message_cb_|.
// - Otherwise, |session_id_cb| will be run with an empty session ID to
// indicate error. No session message in this case.
HRESULT GenerateRequest(EmeInitDataType init_data_type, HRESULT GenerateRequest(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data, const std::vector<uint8_t>& init_data,
SessionIdCB session_id_cb); SessionIdCB session_id_cb);
HRESULT Load(const std::string& session_id); HRESULT Load(const std::string& session_id);
HRESULT Update(const std::vector<uint8_t>& response); HRESULT Update(const std::vector<uint8_t>& response);
HRESULT Close(); HRESULT Close();
...@@ -57,8 +64,9 @@ class MEDIA_EXPORT MediaFoundationCdmSession { ...@@ -57,8 +64,9 @@ class MEDIA_EXPORT MediaFoundationCdmSession {
const std::vector<uint8_t>& message); const std::vector<uint8_t>& message);
void OnSessionKeysChange(); void OnSessionKeysChange();
// Sets |session_id_|, which could be empty on failure. // Sets |session_id_| and returns whether the operation succeeded.
void SetSessionId(); // Note: |this| could already been destructed if false is returned.
bool SetSessionId();
// Callbacks for firing session events. // Callbacks for firing session events.
SessionMessageCB session_message_cb_; SessionMessageCB session_message_cb_;
......
...@@ -81,7 +81,7 @@ class MediaFoundationCdmSessionTest : public testing::Test { ...@@ -81,7 +81,7 @@ class MediaFoundationCdmSessionTest : public testing::Test {
})); }));
COM_EXPECT_CALL(mf_cdm_session_, GetSessionId(_)) COM_EXPECT_CALL(mf_cdm_session_, GetSessionId(_))
.WillOnce(DoAll(SetArgPointee<0>(session_id), Return(S_OK))); .WillOnce(DoAll(SetArgPointee<0>(session_id), Return(S_OK)));
EXPECT_CALL(session_id_cb, Run(_)); EXPECT_CALL(session_id_cb, Run(_)).WillOnce(Return(true));
EXPECT_CALL(cdm_client_, EXPECT_CALL(cdm_client_,
OnSessionMessage(_, CdmMessageType::LICENSE_REQUEST, OnSessionMessage(_, CdmMessageType::LICENSE_REQUEST,
license_request)); license_request));
......
...@@ -63,7 +63,8 @@ class MediaFoundationCdmTest : public testing::Test { ...@@ -63,7 +63,8 @@ class MediaFoundationCdmTest : public testing::Test {
void SetGenerateRequestExpectations( void SetGenerateRequestExpectations(
ComPtr<MockMFCdmSession> mf_cdm_session, ComPtr<MockMFCdmSession> mf_cdm_session,
const char* session_id, const char* session_id,
IMFContentDecryptionModuleSessionCallbacks** mf_cdm_session_callbacks) { IMFContentDecryptionModuleSessionCallbacks** mf_cdm_session_callbacks,
bool expect_message = true) {
std::vector<uint8_t> license_request = StringToVector("request"); std::vector<uint8_t> license_request = StringToVector("request");
// Session ID to return. Will be released by |mf_cdm_session_|. // Session ID to return. Will be released by |mf_cdm_session_|.
...@@ -85,9 +86,11 @@ class MediaFoundationCdmTest : public testing::Test { ...@@ -85,9 +86,11 @@ class MediaFoundationCdmTest : public testing::Test {
COM_EXPECT_CALL(mf_cdm_session, GetSessionId(_)) COM_EXPECT_CALL(mf_cdm_session, GetSessionId(_))
.WillOnce(DoAll(SetArgPointee<0>(mf_session_id), Return(S_OK))); .WillOnce(DoAll(SetArgPointee<0>(mf_session_id), Return(S_OK)));
EXPECT_CALL(cdm_client_, if (expect_message) {
OnSessionMessage(session_id, CdmMessageType::LICENSE_REQUEST, EXPECT_CALL(cdm_client_,
license_request)); OnSessionMessage(session_id, CdmMessageType::LICENSE_REQUEST,
license_request));
}
} }
void CreateSessionAndGenerateRequest() { void CreateSessionAndGenerateRequest() {
...@@ -222,6 +225,46 @@ TEST_F(MediaFoundationCdmTest, ...@@ -222,6 +225,46 @@ TEST_F(MediaFoundationCdmTest,
EXPECT_TRUE(session_id_.empty()); EXPECT_TRUE(session_id_.empty());
} }
// Duplicate session IDs cause session creation failure.
TEST_F(MediaFoundationCdmTest,
CreateSessionAndGenerateRequest_DuplicateSessionId) {
std::vector<uint8_t> init_data = StringToVector("init_data");
auto mf_cdm_session_1 = MakeComPtr<MockMFCdmSession>();
auto mf_cdm_session_2 = MakeComPtr<MockMFCdmSession>();
ComPtr<IMFContentDecryptionModuleSessionCallbacks> mf_cdm_session_callbacks_1;
ComPtr<IMFContentDecryptionModuleSessionCallbacks> mf_cdm_session_callbacks_2;
COM_EXPECT_CALL(mf_cdm_,
CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
.WillOnce(DoAll(SaveComPtr<1>(&mf_cdm_session_callbacks_1),
SetComPointee<2>(mf_cdm_session_1.Get()), Return(S_OK)))
.WillOnce(DoAll(SaveComPtr<1>(&mf_cdm_session_callbacks_2),
SetComPointee<2>(mf_cdm_session_2.Get()), Return(S_OK)));
// In both sessions we return kSessionId. Session 1 succeeds. Session 2 fails
// because of duplicate session ID.
SetGenerateRequestExpectations(mf_cdm_session_1, kSessionId,
&mf_cdm_session_callbacks_1);
SetGenerateRequestExpectations(mf_cdm_session_2, kSessionId,
&mf_cdm_session_callbacks_2,
/*expect_message=*/false);
std::string session_id_1;
std::string session_id_2;
cdm_->CreateSessionAndGenerateRequest(
CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data,
std::make_unique<MockCdmSessionPromise>(/*expect_success=*/true,
&session_id_1));
cdm_->CreateSessionAndGenerateRequest(
CdmSessionType::kTemporary, EmeInitDataType::WEBM, init_data,
std::make_unique<MockCdmSessionPromise>(/*expect_success=*/false,
&session_id_2));
task_environment_.RunUntilIdle();
EXPECT_EQ(session_id_1, kSessionId);
EXPECT_TRUE(session_id_2.empty());
}
// LoadSession() is not implemented. // LoadSession() is not implemented.
TEST_F(MediaFoundationCdmTest, LoadSession) { TEST_F(MediaFoundationCdmTest, LoadSession) {
cdm_->LoadSession(CdmSessionType::kPersistentLicense, kSessionId, cdm_->LoadSession(CdmSessionType::kPersistentLicense, kSessionId,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment