Commit eff1fa52 authored by shess@chromium.org's avatar shess@chromium.org

Put debugging assertions into sql::Statement.

Pulls out the core of gbillock's http://codereview.chromium.org/8283002/
- Move NOTREACHED and similar checks into the sql:: implementation code.
- Add malformed SQL checks to Connection::Execute.
- Add SQL-checking convenience methods to Connection.

The general idea is that the sql:: framework assumes valid statements,
rather than having client code contain scattered ad-hoc (and thus
inconsistent) checks.

This version puts back Statement operator overloading and loosy-goosy
Execute() calls to allow other code to be updated in small batches.

R=gbillock@chromium.org,jhawkins@chromium.org,dhollowa@chromium.org 
BUG=none
TEST=sql_unittests,unit_tests:*Table*.*


Review URL: http://codereview.chromium.org/8899012

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@114118 0039d316-1c4b-4281-b951-d872f2087c98
parent d2528e60
...@@ -132,7 +132,7 @@ void Connection::Close() { ...@@ -132,7 +132,7 @@ void Connection::Close() {
void Connection::Preload() { void Connection::Preload() {
if (!db_) { if (!db_) {
NOTREACHED(); DLOG(FATAL) << "Cannot preload null db";
return; return;
} }
...@@ -142,7 +142,7 @@ void Connection::Preload() { ...@@ -142,7 +142,7 @@ void Connection::Preload() {
if (!DoesTableExist("meta")) if (!DoesTableExist("meta"))
return; return;
Statement dummy(GetUniqueStatement("SELECT * FROM meta")); Statement dummy(GetUniqueStatement("SELECT * FROM meta"));
if (!dummy || !dummy.Step()) if (!dummy.Step())
return; return;
#if !defined(USE_SYSTEM_SQLITE) #if !defined(USE_SYSTEM_SQLITE)
...@@ -166,7 +166,7 @@ bool Connection::BeginTransaction() { ...@@ -166,7 +166,7 @@ bool Connection::BeginTransaction() {
needs_rollback_ = false; needs_rollback_ = false;
Statement begin(GetCachedStatement(SQL_FROM_HERE, "BEGIN TRANSACTION")); Statement begin(GetCachedStatement(SQL_FROM_HERE, "BEGIN TRANSACTION"));
if (!begin || !begin.Run()) if (!begin.Run())
return false; return false;
} }
transaction_nesting_++; transaction_nesting_++;
...@@ -175,7 +175,7 @@ bool Connection::BeginTransaction() { ...@@ -175,7 +175,7 @@ bool Connection::BeginTransaction() {
void Connection::RollbackTransaction() { void Connection::RollbackTransaction() {
if (!transaction_nesting_) { if (!transaction_nesting_) {
NOTREACHED() << "Rolling back a nonexistent transaction"; DLOG(FATAL) << "Rolling back a nonexistent transaction";
return; return;
} }
...@@ -192,7 +192,7 @@ void Connection::RollbackTransaction() { ...@@ -192,7 +192,7 @@ void Connection::RollbackTransaction() {
bool Connection::CommitTransaction() { bool Connection::CommitTransaction() {
if (!transaction_nesting_) { if (!transaction_nesting_) {
NOTREACHED() << "Rolling back a nonexistent transaction"; DLOG(FATAL) << "Rolling back a nonexistent transaction";
return false; return false;
} }
transaction_nesting_--; transaction_nesting_--;
...@@ -208,15 +208,22 @@ bool Connection::CommitTransaction() { ...@@ -208,15 +208,22 @@ bool Connection::CommitTransaction() {
} }
Statement commit(GetCachedStatement(SQL_FROM_HERE, "COMMIT")); Statement commit(GetCachedStatement(SQL_FROM_HERE, "COMMIT"));
if (!commit)
return false;
return commit.Run(); return commit.Run();
} }
bool Connection::Execute(const char* sql) { int Connection::ExecuteAndReturnErrorCode(const char* sql) {
if (!db_) if (!db_)
return false; return false;
return sqlite3_exec(db_, sql, NULL, NULL, NULL) == SQLITE_OK; return sqlite3_exec(db_, sql, NULL, NULL, NULL);
}
bool Connection::Execute(const char* sql) {
int error = ExecuteAndReturnErrorCode(sql);
// TODO(shess,gbillock): DLOG(FATAL) once Execute() clients are
// converted.
if (error == SQLITE_ERROR)
DLOG(ERROR) << "SQL Error in " << sql << ", " << GetErrorMessage();
return error == SQLITE_OK;
} }
bool Connection::ExecuteWithTimeout(const char* sql, base::TimeDelta timeout) { bool Connection::ExecuteWithTimeout(const char* sql, base::TimeDelta timeout) {
...@@ -225,7 +232,7 @@ bool Connection::ExecuteWithTimeout(const char* sql, base::TimeDelta timeout) { ...@@ -225,7 +232,7 @@ bool Connection::ExecuteWithTimeout(const char* sql, base::TimeDelta timeout) {
ScopedBusyTimeout busy_timeout(db_); ScopedBusyTimeout busy_timeout(db_);
busy_timeout.SetTimeout(timeout); busy_timeout.SetTimeout(timeout);
return sqlite3_exec(db_, sql, NULL, NULL, NULL) == SQLITE_OK; return Execute(sql);
} }
bool Connection::HasCachedStatement(const StatementID& id) const { bool Connection::HasCachedStatement(const StatementID& id) const {
...@@ -259,22 +266,28 @@ scoped_refptr<Connection::StatementRef> Connection::GetUniqueStatement( ...@@ -259,22 +266,28 @@ scoped_refptr<Connection::StatementRef> Connection::GetUniqueStatement(
sqlite3_stmt* stmt = NULL; sqlite3_stmt* stmt = NULL;
if (sqlite3_prepare_v2(db_, sql, -1, &stmt, NULL) != SQLITE_OK) { if (sqlite3_prepare_v2(db_, sql, -1, &stmt, NULL) != SQLITE_OK) {
// Treat this as non-fatal, it can occur in a number of valid cases, and // This is evidence of a syntax error in the incoming SQL.
// callers should be doing their own error handling. DLOG(FATAL) << "SQL compile error " << GetErrorMessage();
DLOG(WARNING) << "SQL compile error " << GetErrorMessage();
return new StatementRef(this, NULL); return new StatementRef(this, NULL);
} }
return new StatementRef(this, stmt); return new StatementRef(this, stmt);
} }
bool Connection::IsSQLValid(const char* sql) {
sqlite3_stmt* stmt = NULL;
if (sqlite3_prepare_v2(db_, sql, -1, &stmt, NULL) != SQLITE_OK)
return false;
sqlite3_finalize(stmt);
return true;
}
bool Connection::DoesTableExist(const char* table_name) const { bool Connection::DoesTableExist(const char* table_name) const {
// GetUniqueStatement can't be const since statements may modify the // GetUniqueStatement can't be const since statements may modify the
// database, but we know ours doesn't modify it, so the cast is safe. // database, but we know ours doesn't modify it, so the cast is safe.
Statement statement(const_cast<Connection*>(this)->GetUniqueStatement( Statement statement(const_cast<Connection*>(this)->GetUniqueStatement(
"SELECT name FROM sqlite_master " "SELECT name FROM sqlite_master "
"WHERE type='table' AND name=?")); "WHERE type='table' AND name=?"));
if (!statement)
return false;
statement.BindString(0, table_name); statement.BindString(0, table_name);
return statement.Step(); // Table exists if any row was returned. return statement.Step(); // Table exists if any row was returned.
} }
...@@ -288,8 +301,6 @@ bool Connection::DoesColumnExist(const char* table_name, ...@@ -288,8 +301,6 @@ bool Connection::DoesColumnExist(const char* table_name,
// Our SQL is non-mutating, so this cast is OK. // Our SQL is non-mutating, so this cast is OK.
Statement statement(const_cast<Connection*>(this)->GetUniqueStatement( Statement statement(const_cast<Connection*>(this)->GetUniqueStatement(
sql.c_str())); sql.c_str()));
if (!statement)
return false;
while (statement.Step()) { while (statement.Step()) {
if (!statement.ColumnString(1).compare(column_name)) if (!statement.ColumnString(1).compare(column_name))
...@@ -300,7 +311,7 @@ bool Connection::DoesColumnExist(const char* table_name, ...@@ -300,7 +311,7 @@ bool Connection::DoesColumnExist(const char* table_name,
int64 Connection::GetLastInsertRowId() const { int64 Connection::GetLastInsertRowId() const {
if (!db_) { if (!db_) {
NOTREACHED(); DLOG(FATAL) << "Illegal use of connection without a db";
return 0; return 0;
} }
return sqlite3_last_insert_rowid(db_); return sqlite3_last_insert_rowid(db_);
...@@ -308,7 +319,7 @@ int64 Connection::GetLastInsertRowId() const { ...@@ -308,7 +319,7 @@ int64 Connection::GetLastInsertRowId() const {
int Connection::GetLastChangeCount() const { int Connection::GetLastChangeCount() const {
if (!db_) { if (!db_) {
NOTREACHED(); DLOG(FATAL) << "Illegal use of connection without a db";
return 0; return 0;
} }
return sqlite3_changes(db_); return sqlite3_changes(db_);
...@@ -339,7 +350,7 @@ const char* Connection::GetErrorMessage() const { ...@@ -339,7 +350,7 @@ const char* Connection::GetErrorMessage() const {
bool Connection::OpenInternal(const std::string& file_name) { bool Connection::OpenInternal(const std::string& file_name) {
if (db_) { if (db_) {
NOTREACHED() << "sql::Connection is already open."; DLOG(FATAL) << "sql::Connection is already open.";
return false; return false;
} }
...@@ -370,7 +381,7 @@ bool Connection::OpenInternal(const std::string& file_name) { ...@@ -370,7 +381,7 @@ bool Connection::OpenInternal(const std::string& file_name) {
// which requests exclusive locking but doesn't get it is almost // which requests exclusive locking but doesn't get it is almost
// certain to be ill-tested. // certain to be ill-tested.
if (!Execute("PRAGMA locking_mode=EXCLUSIVE")) if (!Execute("PRAGMA locking_mode=EXCLUSIVE"))
NOTREACHED() << "Could not set locking mode: " << GetErrorMessage(); DLOG(FATAL) << "Could not set locking mode: " << GetErrorMessage();
} }
const base::TimeDelta kBusyTimeout = const base::TimeDelta kBusyTimeout =
...@@ -384,17 +395,17 @@ bool Connection::OpenInternal(const std::string& file_name) { ...@@ -384,17 +395,17 @@ bool Connection::OpenInternal(const std::string& file_name) {
DCHECK_LE(page_size_, kSqliteMaxPageSize); DCHECK_LE(page_size_, kSqliteMaxPageSize);
const std::string sql = StringPrintf("PRAGMA page_size=%d", page_size_); const std::string sql = StringPrintf("PRAGMA page_size=%d", page_size_);
if (!ExecuteWithTimeout(sql.c_str(), kBusyTimeout)) if (!ExecuteWithTimeout(sql.c_str(), kBusyTimeout))
NOTREACHED() << "Could not set page size: " << GetErrorMessage(); DLOG(FATAL) << "Could not set page size: " << GetErrorMessage();
} }
if (cache_size_ != 0) { if (cache_size_ != 0) {
const std::string sql = StringPrintf("PRAGMA cache_size=%d", cache_size_); const std::string sql = StringPrintf("PRAGMA cache_size=%d", cache_size_);
if (!ExecuteWithTimeout(sql.c_str(), kBusyTimeout)) if (!ExecuteWithTimeout(sql.c_str(), kBusyTimeout))
NOTREACHED() << "Could not set cache size: " << GetErrorMessage(); DLOG(FATAL) << "Could not set cache size: " << GetErrorMessage();
} }
if (!ExecuteWithTimeout("PRAGMA secure_delete=ON", kBusyTimeout)) { if (!ExecuteWithTimeout("PRAGMA secure_delete=ON", kBusyTimeout)) {
NOTREACHED() << "Could not enable secure_delete: " << GetErrorMessage(); DLOG(FATAL) << "Could not enable secure_delete: " << GetErrorMessage();
Close(); Close();
return false; return false;
} }
...@@ -404,8 +415,7 @@ bool Connection::OpenInternal(const std::string& file_name) { ...@@ -404,8 +415,7 @@ bool Connection::OpenInternal(const std::string& file_name) {
void Connection::DoRollback() { void Connection::DoRollback() {
Statement rollback(GetCachedStatement(SQL_FROM_HERE, "ROLLBACK")); Statement rollback(GetCachedStatement(SQL_FROM_HERE, "ROLLBACK"));
if (rollback) rollback.Run();
rollback.Run();
} }
void Connection::StatementRefCreated(StatementRef* ref) { void Connection::StatementRefCreated(StatementRef* ref) {
...@@ -416,7 +426,7 @@ void Connection::StatementRefCreated(StatementRef* ref) { ...@@ -416,7 +426,7 @@ void Connection::StatementRefCreated(StatementRef* ref) {
void Connection::StatementRefDeleted(StatementRef* ref) { void Connection::StatementRefDeleted(StatementRef* ref) {
StatementRefSet::iterator i = open_statements_.find(ref); StatementRefSet::iterator i = open_statements_.find(ref);
if (i == open_statements_.end()) if (i == open_statements_.end())
NOTREACHED(); DLOG(FATAL) << "Could not find statement";
else else
open_statements_.erase(i); open_statements_.erase(i);
} }
...@@ -436,7 +446,7 @@ int Connection::OnSqliteError(int err, sql::Statement *stmt) { ...@@ -436,7 +446,7 @@ int Connection::OnSqliteError(int err, sql::Statement *stmt) {
if (error_delegate_.get()) if (error_delegate_.get())
return error_delegate_->OnError(err, this, stmt); return error_delegate_->OnError(err, this, stmt);
// The default handling is to assert on debug and to ignore on release. // The default handling is to assert on debug and to ignore on release.
NOTREACHED() << GetErrorMessage(); DLOG(FATAL) << GetErrorMessage();
return err; return err;
} }
......
...@@ -204,8 +204,12 @@ class SQL_EXPORT Connection { ...@@ -204,8 +204,12 @@ class SQL_EXPORT Connection {
// Executes the given SQL string, returning true on success. This is // Executes the given SQL string, returning true on success. This is
// normally used for simple, 1-off statements that don't take any bound // normally used for simple, 1-off statements that don't take any bound
// parameters and don't return any data (e.g. CREATE TABLE). // parameters and don't return any data (e.g. CREATE TABLE).
// This will DCHECK if the |sql| contains errors.
bool Execute(const char* sql); bool Execute(const char* sql);
// Like Execute(), but returns the error code given by SQLite.
int ExecuteAndReturnErrorCode(const char* sql);
// Returns true if we have a statement with the given identifier already // Returns true if we have a statement with the given identifier already
// cached. This is normally not necessary to call, but can be useful if the // cached. This is normally not necessary to call, but can be useful if the
// caller has to dynamically build up SQL to avoid doing so if it's already // caller has to dynamically build up SQL to avoid doing so if it's already
...@@ -217,8 +221,10 @@ class SQL_EXPORT Connection { ...@@ -217,8 +221,10 @@ class SQL_EXPORT Connection {
// keeping commonly-used ones around for future use is important for // keeping commonly-used ones around for future use is important for
// performance. // performance.
// //
// The SQL may have an error, so the caller must check validity of the // If the |sql| has an error, an invalid, inert StatementRef is returned (and
// statement before using it. // the code will crash in debug). The caller must deal with this eventuality,
// either by checking validity of the |sql| before calling, by correctly
// handling the return of an inert statement, or both.
// //
// The StatementID and the SQL must always correspond to one-another. The // The StatementID and the SQL must always correspond to one-another. The
// ID is the lookup into the cache, so crazy things will happen if you use // ID is the lookup into the cache, so crazy things will happen if you use
...@@ -236,6 +242,10 @@ class SQL_EXPORT Connection { ...@@ -236,6 +242,10 @@ class SQL_EXPORT Connection {
scoped_refptr<StatementRef> GetCachedStatement(const StatementID& id, scoped_refptr<StatementRef> GetCachedStatement(const StatementID& id,
const char* sql); const char* sql);
// Used to check a |sql| statement for syntactic validity. If the statement is
// valid SQL, returns true.
bool IsSQLValid(const char* sql);
// Returns a non-cached statement for the given SQL. Use this for SQL that // Returns a non-cached statement for the given SQL. Use this for SQL that
// is only executed once or only rarely (there is overhead associated with // is only executed once or only rarely (there is overhead associated with
// keeping a statement cached). // keeping a statement cached).
...@@ -274,7 +284,7 @@ class SQL_EXPORT Connection { ...@@ -274,7 +284,7 @@ class SQL_EXPORT Connection {
const char* GetErrorMessage() const; const char* GetErrorMessage() const;
private: private:
// Statement access StatementRef which we don't want to expose to erverybody // Statement accesses StatementRef which we don't want to expose to everybody
// (they should go through Statement). // (they should go through Statement).
friend class Statement; friend class Statement;
......
...@@ -35,10 +35,21 @@ TEST_F(SQLConnectionTest, Execute) { ...@@ -35,10 +35,21 @@ TEST_F(SQLConnectionTest, Execute) {
EXPECT_EQ(SQLITE_OK, db().GetErrorCode()); EXPECT_EQ(SQLITE_OK, db().GetErrorCode());
// Invalid statement should fail. // Invalid statement should fail.
ASSERT_FALSE(db().Execute("CREATE TAB foo (a, b")); ASSERT_EQ(SQLITE_ERROR,
db().ExecuteAndReturnErrorCode("CREATE TAB foo (a, b"));
EXPECT_EQ(SQLITE_ERROR, db().GetErrorCode()); EXPECT_EQ(SQLITE_ERROR, db().GetErrorCode());
} }
TEST_F(SQLConnectionTest, ExecuteWithErrorCode) {
ASSERT_EQ(SQLITE_OK,
db().ExecuteAndReturnErrorCode("CREATE TABLE foo (a, b)"));
ASSERT_EQ(SQLITE_ERROR,
db().ExecuteAndReturnErrorCode("CREATE TABLE TABLE"));
ASSERT_EQ(SQLITE_ERROR,
db().ExecuteAndReturnErrorCode(
"INSERT INTO foo(a, b) VALUES (1, 2, 3, 4)"));
}
TEST_F(SQLConnectionTest, CachedStatement) { TEST_F(SQLConnectionTest, CachedStatement) {
sql::StatementID id1("foo", 12); sql::StatementID id1("foo", 12);
...@@ -48,7 +59,7 @@ TEST_F(SQLConnectionTest, CachedStatement) { ...@@ -48,7 +59,7 @@ TEST_F(SQLConnectionTest, CachedStatement) {
// Create a new cached statement. // Create a new cached statement.
{ {
sql::Statement s(db().GetCachedStatement(id1, "SELECT a FROM foo")); sql::Statement s(db().GetCachedStatement(id1, "SELECT a FROM foo"));
ASSERT_FALSE(!s); // Test ! operator for validity. ASSERT_TRUE(s.is_valid());
ASSERT_TRUE(s.Step()); ASSERT_TRUE(s.Step());
EXPECT_EQ(12, s.ColumnInt(0)); EXPECT_EQ(12, s.ColumnInt(0));
...@@ -61,7 +72,7 @@ TEST_F(SQLConnectionTest, CachedStatement) { ...@@ -61,7 +72,7 @@ TEST_F(SQLConnectionTest, CachedStatement) {
// Get the same statement using different SQL. This should ignore our // Get the same statement using different SQL. This should ignore our
// SQL and use the cached one (so it will be valid). // SQL and use the cached one (so it will be valid).
sql::Statement s(db().GetCachedStatement(id1, "something invalid(")); sql::Statement s(db().GetCachedStatement(id1, "something invalid("));
ASSERT_FALSE(!s); // Test ! operator for validity. ASSERT_TRUE(s.is_valid());
ASSERT_TRUE(s.Step()); ASSERT_TRUE(s.Step());
EXPECT_EQ(12, s.ColumnInt(0)); EXPECT_EQ(12, s.ColumnInt(0));
...@@ -71,6 +82,12 @@ TEST_F(SQLConnectionTest, CachedStatement) { ...@@ -71,6 +82,12 @@ TEST_F(SQLConnectionTest, CachedStatement) {
EXPECT_FALSE(db().HasCachedStatement(SQL_FROM_HERE)); EXPECT_FALSE(db().HasCachedStatement(SQL_FROM_HERE));
} }
TEST_F(SQLConnectionTest, IsSQLValidTest) {
ASSERT_TRUE(db().Execute("CREATE TABLE foo (a, b)"));
ASSERT_TRUE(db().IsSQLValid("SELECT a FROM foo"));
ASSERT_FALSE(db().IsSQLValid("SELECT no_exist FROM foo"));
}
TEST_F(SQLConnectionTest, DoesStuffExist) { TEST_F(SQLConnectionTest, DoesStuffExist) {
// Test DoesTableExist. // Test DoesTableExist.
EXPECT_FALSE(db().DoesTableExist("foo")); EXPECT_FALSE(db().DoesTableExist("foo"));
...@@ -103,4 +120,3 @@ TEST_F(SQLConnectionTest, GetLastInsertRowId) { ...@@ -103,4 +120,3 @@ TEST_F(SQLConnectionTest, GetLastInsertRowId) {
ASSERT_TRUE(s.Step()); ASSERT_TRUE(s.Step());
EXPECT_EQ(12, s.ColumnInt(0)); EXPECT_EQ(12, s.ColumnInt(0));
} }
...@@ -51,8 +51,7 @@ void MetaTable::Reset() { ...@@ -51,8 +51,7 @@ void MetaTable::Reset() {
bool MetaTable::SetValue(const char* key, const std::string& value) { bool MetaTable::SetValue(const char* key, const std::string& value) {
Statement s; Statement s;
if (!PrepareSetStatement(&s, key)) PrepareSetStatement(&s, key);
return false;
s.BindString(1, value); s.BindString(1, value);
return s.Run(); return s.Run();
} }
...@@ -68,9 +67,7 @@ bool MetaTable::GetValue(const char* key, std::string* value) { ...@@ -68,9 +67,7 @@ bool MetaTable::GetValue(const char* key, std::string* value) {
bool MetaTable::SetValue(const char* key, int value) { bool MetaTable::SetValue(const char* key, int value) {
Statement s; Statement s;
if (!PrepareSetStatement(&s, key)) PrepareSetStatement(&s, key);
return false;
s.BindInt(1, value); s.BindInt(1, value);
return s.Run(); return s.Run();
} }
...@@ -86,8 +83,7 @@ bool MetaTable::GetValue(const char* key, int* value) { ...@@ -86,8 +83,7 @@ bool MetaTable::GetValue(const char* key, int* value) {
bool MetaTable::SetValue(const char* key, int64 value) { bool MetaTable::SetValue(const char* key, int64 value) {
Statement s; Statement s;
if (!PrepareSetStatement(&s, key)) PrepareSetStatement(&s, key);
return false;
s.BindInt64(1, value); s.BindInt64(1, value);
return s.Run(); return s.Run();
} }
...@@ -123,26 +119,17 @@ int MetaTable::GetCompatibleVersionNumber() { ...@@ -123,26 +119,17 @@ int MetaTable::GetCompatibleVersionNumber() {
return version; return version;
} }
bool MetaTable::PrepareSetStatement(Statement* statement, const char* key) { void MetaTable::PrepareSetStatement(Statement* statement, const char* key) {
DCHECK(db_ && statement); DCHECK(db_ && statement);
statement->Assign(db_->GetCachedStatement(SQL_FROM_HERE, statement->Assign(db_->GetCachedStatement(SQL_FROM_HERE,
"INSERT OR REPLACE INTO meta (key,value) VALUES (?,?)")); "INSERT OR REPLACE INTO meta (key,value) VALUES (?,?)"));
if (!statement->is_valid()) {
NOTREACHED() << db_->GetErrorMessage();
return false;
}
statement->BindCString(0, key); statement->BindCString(0, key);
return true;
} }
bool MetaTable::PrepareGetStatement(Statement* statement, const char* key) { bool MetaTable::PrepareGetStatement(Statement* statement, const char* key) {
DCHECK(db_ && statement); DCHECK(db_ && statement);
statement->Assign(db_->GetCachedStatement(SQL_FROM_HERE, statement->Assign(db_->GetCachedStatement(SQL_FROM_HERE,
"SELECT value FROM meta WHERE key=?")); "SELECT value FROM meta WHERE key=?"));
if (!statement->is_valid()) {
NOTREACHED() << db_->GetErrorMessage();
return false;
}
statement->BindCString(0, key); statement->BindCString(0, key);
if (!statement->Step()) if (!statement->Step())
return false; return false;
......
...@@ -71,7 +71,7 @@ class SQL_EXPORT MetaTable { ...@@ -71,7 +71,7 @@ class SQL_EXPORT MetaTable {
private: private:
// Conveniences to prepare the two types of statements used by // Conveniences to prepare the two types of statements used by
// MetaTableHelper. // MetaTableHelper.
bool PrepareSetStatement(Statement* statement, const char* key); void PrepareSetStatement(Statement* statement, const char* key);
bool PrepareGetStatement(Statement* statement, const char* key); bool PrepareGetStatement(Statement* statement, const char* key);
Connection* db_; Connection* db_;
......
...@@ -81,7 +81,8 @@ class SQLiteFeaturesTest : public testing::Test { ...@@ -81,7 +81,8 @@ class SQLiteFeaturesTest : public testing::Test {
// Do not include fts1 support, it is not useful, and nobody is // Do not include fts1 support, it is not useful, and nobody is
// looking at it. // looking at it.
TEST_F(SQLiteFeaturesTest, NoFTS1) { TEST_F(SQLiteFeaturesTest, NoFTS1) {
ASSERT_FALSE(db().Execute("CREATE VIRTUAL TABLE foo USING fts1(x)")); ASSERT_EQ(SQLITE_ERROR, db().ExecuteAndReturnErrorCode(
"CREATE VIRTUAL TABLE foo USING fts1(x)"));
} }
// fts2 is used for older history files, so we're signed on for // fts2 is used for older history files, so we're signed on for
......
...@@ -35,15 +35,23 @@ void Statement::Assign(scoped_refptr<Connection::StatementRef> ref) { ...@@ -35,15 +35,23 @@ void Statement::Assign(scoped_refptr<Connection::StatementRef> ref) {
ref_ = ref; ref_ = ref;
} }
bool Statement::Run() { bool Statement::CheckValid() const {
if (!is_valid()) if (!is_valid())
DLOG(FATAL) << "Cannot call mutating statements on an invalid statement.";
return is_valid();
}
bool Statement::Run() {
if (!CheckValid())
return false; return false;
return CheckError(sqlite3_step(ref_->stmt())) == SQLITE_DONE; return CheckError(sqlite3_step(ref_->stmt())) == SQLITE_DONE;
} }
bool Statement::Step() { bool Statement::Step() {
if (!is_valid()) if (!CheckValid())
return false; return false;
return CheckError(sqlite3_step(ref_->stmt())) == SQLITE_ROW; return CheckError(sqlite3_step(ref_->stmt())) == SQLITE_ROW;
} }
...@@ -55,21 +63,22 @@ void Statement::Reset() { ...@@ -55,21 +63,22 @@ void Statement::Reset() {
sqlite3_clear_bindings(ref_->stmt()); sqlite3_clear_bindings(ref_->stmt());
sqlite3_reset(ref_->stmt()); sqlite3_reset(ref_->stmt());
} }
succeeded_ = false; succeeded_ = false;
} }
bool Statement::Succeeded() const { bool Statement::Succeeded() const {
if (!is_valid()) if (!is_valid())
return false; return false;
return succeeded_; return succeeded_;
} }
bool Statement::BindNull(int col) { bool Statement::BindNull(int col) {
if (is_valid()) { if (!is_valid())
int err = CheckError(sqlite3_bind_null(ref_->stmt(), col + 1)); return false;
return err == SQLITE_OK;
} return CheckOk(sqlite3_bind_null(ref_->stmt(), col + 1));
return false;
} }
bool Statement::BindBool(int col, bool val) { bool Statement::BindBool(int col, bool val) {
...@@ -77,45 +86,43 @@ bool Statement::BindBool(int col, bool val) { ...@@ -77,45 +86,43 @@ bool Statement::BindBool(int col, bool val) {
} }
bool Statement::BindInt(int col, int val) { bool Statement::BindInt(int col, int val) {
if (is_valid()) { if (!is_valid())
int err = CheckError(sqlite3_bind_int(ref_->stmt(), col + 1, val)); return false;
return err == SQLITE_OK;
} return CheckOk(sqlite3_bind_int(ref_->stmt(), col + 1, val));
return false;
} }
bool Statement::BindInt64(int col, int64 val) { bool Statement::BindInt64(int col, int64 val) {
if (is_valid()) { if (!is_valid())
int err = CheckError(sqlite3_bind_int64(ref_->stmt(), col + 1, val)); return false;
return err == SQLITE_OK;
} return CheckOk(sqlite3_bind_int64(ref_->stmt(), col + 1, val));
return false;
} }
bool Statement::BindDouble(int col, double val) { bool Statement::BindDouble(int col, double val) {
if (is_valid()) { if (!is_valid())
int err = CheckError(sqlite3_bind_double(ref_->stmt(), col + 1, val)); return false;
return err == SQLITE_OK;
} return CheckOk(sqlite3_bind_double(ref_->stmt(), col + 1, val));
return false;
} }
bool Statement::BindCString(int col, const char* val) { bool Statement::BindCString(int col, const char* val) {
if (is_valid()) { if (!is_valid())
int err = CheckError(sqlite3_bind_text(ref_->stmt(), col + 1, val, -1, return false;
SQLITE_TRANSIENT));
return err == SQLITE_OK; return CheckOk(
} sqlite3_bind_text(ref_->stmt(), col + 1, val, -1, SQLITE_TRANSIENT));
return false;
} }
bool Statement::BindString(int col, const std::string& val) { bool Statement::BindString(int col, const std::string& val) {
if (is_valid()) { if (!is_valid())
int err = CheckError(sqlite3_bind_text(ref_->stmt(), col + 1, val.data(), return false;
val.size(), SQLITE_TRANSIENT));
return err == SQLITE_OK; return CheckOk(sqlite3_bind_text(ref_->stmt(),
} col + 1,
return false; val.data(),
val.size(),
SQLITE_TRANSIENT));
} }
bool Statement::BindString16(int col, const string16& value) { bool Statement::BindString16(int col, const string16& value) {
...@@ -123,19 +130,17 @@ bool Statement::BindString16(int col, const string16& value) { ...@@ -123,19 +130,17 @@ bool Statement::BindString16(int col, const string16& value) {
} }
bool Statement::BindBlob(int col, const void* val, int val_len) { bool Statement::BindBlob(int col, const void* val, int val_len) {
if (is_valid()) { if (!is_valid())
int err = CheckError(sqlite3_bind_blob(ref_->stmt(), col + 1, return false;
val, val_len, SQLITE_TRANSIENT));
return err == SQLITE_OK; return CheckOk(
} sqlite3_bind_blob(ref_->stmt(), col + 1, val, val_len, SQLITE_TRANSIENT));
return false;
} }
int Statement::ColumnCount() const { int Statement::ColumnCount() const {
if (!is_valid()) { if (!is_valid())
NOTREACHED();
return 0; return 0;
}
return sqlite3_column_count(ref_->stmt()); return sqlite3_column_count(ref_->stmt());
} }
...@@ -155,34 +160,30 @@ bool Statement::ColumnBool(int col) const { ...@@ -155,34 +160,30 @@ bool Statement::ColumnBool(int col) const {
} }
int Statement::ColumnInt(int col) const { int Statement::ColumnInt(int col) const {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return 0; return 0;
}
return sqlite3_column_int(ref_->stmt(), col); return sqlite3_column_int(ref_->stmt(), col);
} }
int64 Statement::ColumnInt64(int col) const { int64 Statement::ColumnInt64(int col) const {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return 0; return 0;
}
return sqlite3_column_int64(ref_->stmt(), col); return sqlite3_column_int64(ref_->stmt(), col);
} }
double Statement::ColumnDouble(int col) const { double Statement::ColumnDouble(int col) const {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return 0; return 0;
}
return sqlite3_column_double(ref_->stmt(), col); return sqlite3_column_double(ref_->stmt(), col);
} }
std::string Statement::ColumnString(int col) const { std::string Statement::ColumnString(int col) const {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return ""; return "";
}
const char* str = reinterpret_cast<const char*>( const char* str = reinterpret_cast<const char*>(
sqlite3_column_text(ref_->stmt(), col)); sqlite3_column_text(ref_->stmt(), col));
int len = sqlite3_column_bytes(ref_->stmt(), col); int len = sqlite3_column_bytes(ref_->stmt(), col);
...@@ -194,36 +195,31 @@ std::string Statement::ColumnString(int col) const { ...@@ -194,36 +195,31 @@ std::string Statement::ColumnString(int col) const {
} }
string16 Statement::ColumnString16(int col) const { string16 Statement::ColumnString16(int col) const {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return string16(); return string16();
}
std::string s = ColumnString(col); std::string s = ColumnString(col);
return !s.empty() ? UTF8ToUTF16(s) : string16(); return !s.empty() ? UTF8ToUTF16(s) : string16();
} }
int Statement::ColumnByteLength(int col) const { int Statement::ColumnByteLength(int col) const {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return 0; return 0;
}
return sqlite3_column_bytes(ref_->stmt(), col); return sqlite3_column_bytes(ref_->stmt(), col);
} }
const void* Statement::ColumnBlob(int col) const { const void* Statement::ColumnBlob(int col) const {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return NULL; return NULL;
}
return sqlite3_column_blob(ref_->stmt(), col); return sqlite3_column_blob(ref_->stmt(), col);
} }
bool Statement::ColumnBlobAsString(int col, std::string* blob) { bool Statement::ColumnBlobAsString(int col, std::string* blob) {
if (!is_valid()) { if (!CheckValid())
NOTREACHED();
return false; return false;
}
const void* p = ColumnBlob(col); const void* p = ColumnBlob(col);
size_t len = ColumnByteLength(col); size_t len = ColumnByteLength(col);
blob->resize(len); blob->resize(len);
...@@ -234,12 +230,11 @@ bool Statement::ColumnBlobAsString(int col, std::string* blob) { ...@@ -234,12 +230,11 @@ bool Statement::ColumnBlobAsString(int col, std::string* blob) {
return true; return true;
} }
void Statement::ColumnBlobAsVector(int col, std::vector<char>* val) const { bool Statement::ColumnBlobAsVector(int col, std::vector<char>* val) const {
val->clear(); val->clear();
if (!is_valid()) {
NOTREACHED(); if (!CheckValid())
return; return false;
}
const void* data = sqlite3_column_blob(ref_->stmt(), col); const void* data = sqlite3_column_blob(ref_->stmt(), col);
int len = sqlite3_column_bytes(ref_->stmt(), col); int len = sqlite3_column_bytes(ref_->stmt(), col);
...@@ -247,18 +242,23 @@ void Statement::ColumnBlobAsVector(int col, std::vector<char>* val) const { ...@@ -247,18 +242,23 @@ void Statement::ColumnBlobAsVector(int col, std::vector<char>* val) const {
val->resize(len); val->resize(len);
memcpy(&(*val)[0], data, len); memcpy(&(*val)[0], data, len);
} }
return true;
} }
void Statement::ColumnBlobAsVector( bool Statement::ColumnBlobAsVector(
int col, int col,
std::vector<unsigned char>* val) const { std::vector<unsigned char>* val) const {
ColumnBlobAsVector(col, reinterpret_cast< std::vector<char>* >(val)); return ColumnBlobAsVector(col, reinterpret_cast< std::vector<char>* >(val));
} }
const char* Statement::GetSQLStatement() { const char* Statement::GetSQLStatement() {
return sqlite3_sql(ref_->stmt()); return sqlite3_sql(ref_->stmt());
} }
bool Statement::CheckOk(int err) const {
return err == SQLITE_OK;
}
int Statement::CheckError(int err) { int Statement::CheckError(int err) {
// Please don't add DCHECKs here, OnSqliteError() already has them. // Please don't add DCHECKs here, OnSqliteError() already has them.
succeeded_ = (err == SQLITE_OK || err == SQLITE_ROW || err == SQLITE_DONE); succeeded_ = (err == SQLITE_OK || err == SQLITE_ROW || err == SQLITE_DONE);
......
...@@ -29,13 +29,16 @@ enum ColType { ...@@ -29,13 +29,16 @@ enum ColType {
// Normal usage: // Normal usage:
// sql::Statement s(connection_.GetUniqueStatement(...)); // sql::Statement s(connection_.GetUniqueStatement(...));
// if (!s) // You should check for errors before using the statement.
// return false;
//
// s.BindInt(0, a); // s.BindInt(0, a);
// if (s.Step()) // if (s.Step())
// return s.ColumnString(0); // return s.ColumnString(0);
// //
// If there are errors getting the statement, the statement will be inert; no
// mutating or database-access methods will work. If you need to check for
// validity, use:
// if (!s.is_valid())
// return false;
//
// Step() and Run() just return true to signal success. If you want to handle // Step() and Run() just return true to signal success. If you want to handle
// specific errors such as database corruption, install an error handler in // specific errors such as database corruption, install an error handler in
// in the connection object using set_error_delegate(). // in the connection object using set_error_delegate().
...@@ -61,6 +64,7 @@ class SQL_EXPORT Statement { ...@@ -61,6 +64,7 @@ class SQL_EXPORT Statement {
// These operators allow conveniently checking if the statement is valid // These operators allow conveniently checking if the statement is valid
// or not. See the pattern above for an example. // or not. See the pattern above for an example.
// TODO(shess,gbillock): Remove these once clients are converted.
operator bool() const { return is_valid(); } operator bool() const { return is_valid(); }
bool operator!() const { return !is_valid(); } bool operator!() const { return !is_valid(); }
...@@ -96,7 +100,7 @@ class SQL_EXPORT Statement { ...@@ -96,7 +100,7 @@ class SQL_EXPORT Statement {
// Binding ------------------------------------------------------------------- // Binding -------------------------------------------------------------------
// These all take a 0-based argument index and return true on failure. You // These all take a 0-based argument index and return true on success. You
// may not always care about the return value (they'll DCHECK if they fail). // may not always care about the return value (they'll DCHECK if they fail).
// The main thing you may want to check is when binding large blobs or // The main thing you may want to check is when binding large blobs or
// strings there may be out of memory. // strings there may be out of memory.
...@@ -137,8 +141,8 @@ class SQL_EXPORT Statement { ...@@ -137,8 +141,8 @@ class SQL_EXPORT Statement {
int ColumnByteLength(int col) const; int ColumnByteLength(int col) const;
const void* ColumnBlob(int col) const; const void* ColumnBlob(int col) const;
bool ColumnBlobAsString(int col, std::string* blob); bool ColumnBlobAsString(int col, std::string* blob);
void ColumnBlobAsVector(int col, std::vector<char>* val) const; bool ColumnBlobAsVector(int col, std::vector<char>* val) const;
void ColumnBlobAsVector(int col, std::vector<unsigned char>* val) const; bool ColumnBlobAsVector(int col, std::vector<unsigned char>* val) const;
// Diagnostics -------------------------------------------------------------- // Diagnostics --------------------------------------------------------------
...@@ -152,6 +156,24 @@ class SQL_EXPORT Statement { ...@@ -152,6 +156,24 @@ class SQL_EXPORT Statement {
// enhanced in the future to do the notification. // enhanced in the future to do the notification.
int CheckError(int err); int CheckError(int err);
// Contraction for checking an error code against SQLITE_OK. Does not set the
// succeeded flag.
bool CheckOk(int err) const;
// Should be called by all mutating methods to check that the statement is
// valid. Returns true if the statement is valid. DCHECKS and returns false
// if it is not.
// The reason for this is to handle two specific cases in which a Statement
// may be invalid. The first case is that the programmer made an SQL error.
// Those cases need to be DCHECKed so that we are guaranteed to find them
// before release. The second case is that the computer has an error (probably
// out of disk space) which is prohibiting the correct operation of the
// database. Our testing apparatus should not exhibit this defect, but release
// situations may. Therefore, the code is handling disjoint situations in
// release and test. In test, we're ensuring correct SQL. In release, we're
// ensuring that contracts are honored in error edge cases.
bool CheckValid() const;
// The actual sqlite statement. This may be unique to us, or it may be cached // The actual sqlite statement. This may be unique to us, or it may be cached
// by the connection, which is why it's refcounted. This pointer is // by the connection, which is why it's refcounted. This pointer is
// guaranteed non-NULL. // guaranteed non-NULL.
......
...@@ -70,13 +70,9 @@ class SQLStatementTest : public testing::Test { ...@@ -70,13 +70,9 @@ class SQLStatementTest : public testing::Test {
TEST_F(SQLStatementTest, Assign) { TEST_F(SQLStatementTest, Assign) {
sql::Statement s; sql::Statement s;
EXPECT_FALSE(s); // bool conversion operator.
EXPECT_TRUE(!s); // ! operator.
EXPECT_FALSE(s.is_valid()); EXPECT_FALSE(s.is_valid());
s.Assign(db().GetUniqueStatement("CREATE TABLE foo (a, b)")); s.Assign(db().GetUniqueStatement("CREATE TABLE foo (a, b)"));
EXPECT_TRUE(s);
EXPECT_FALSE(!s);
EXPECT_TRUE(s.is_valid()); EXPECT_TRUE(s.is_valid());
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment