diff --git a/tools/shell/shell.cpp b/tools/shell/shell.cpp index 60aa2368105e..31df60f4fdb2 100644 --- a/tools/shell/shell.cpp +++ b/tools/shell/shell.cpp @@ -3626,48 +3626,54 @@ bool ShellState::OpenDatabase(const char **azArg, idx_t nArg) { } char *zNewFilename; /* Name of the database file to open */ idx_t iName = 1; /* Index in azArg[] of the filename */ - bool newFlag = false; /* True to delete file before opening */ - /* Close the existing database */ - close_db(db); - db = nullptr; - globalDb = nullptr; zDbFilename = string(); - openMode = SHELL_OPEN_UNSPEC; - openFlags = openFlags & ~(SQLITE_OPEN_NOFOLLOW); // don't overwrite settings loaded in the command line - szMax = 0; - /* Check for command-line arguments */ - for (idx_t iName = 1; iName < nArg && azArg[iName][0] == '-'; iName++) { - const char *z = azArg[iName]; - if (optionMatch(z, "new")) { - newFlag = true; - } else if (optionMatch(z, "readonly")) { - openMode = SHELL_OPEN_READONLY; - } else if (optionMatch(z, "nofollow")) { - openFlags |= SQLITE_OPEN_NOFOLLOW; - } else if (z[0] == '-') { - utf8_printf(stderr, "unknown option: %s\n", z); - return false; - } + bool read_only = false; + // 2 is the first non filename arg + for (idx_t iName = 2; iName < nArg && azArg[iName][0] == '-'; iName++) { + const char *z = azArg[iName]; + if (optionMatch(z, "readonly")) { + read_only = true; + } else if (z[0] == '-') { + utf8_printf(stderr, "unknown option: %s\n", z); + return false; + } } /* If a filename is specified, try to open it first */ zNewFilename = nArg > iName ? sqlite3_mprintf("%s", azArg[iName]) : 0; - if (zNewFilename || openMode == SHELL_OPEN_HEXDB) { - if (newFlag) { - shellDeleteFile(zNewFilename); - } + if (zNewFilename) { zDbFilename = zNewFilename; sqlite3_free(zNewFilename); - OpenDB(OPEN_DB_KEEPALIVE); - if (!db) { - utf8_printf(stderr, "Error: cannot open '%s'\n", zNewFilename); - } - } - if (!db) { - /* As a fall-back open a TEMP database */ - zDbFilename = string(); - OpenDB(0); + idx_t dot_placement = zDbFilename.find("."); + bool found_extension = dot_placement >= 0; + string zDbDBName; + if (!found_extension) { + zDbDBName = zDbFilename; + } else { + zDbDBName = zDbFilename.substr(0, dot_placement); + } + string attach = "ATTACH '" + zDbFilename + "' as " + zDbDBName; + if (read_only) { + attach += " (READ_ONLY)"; + } + attach += ";"; + const char *attach_sql = attach.c_str(); + string use = " USE " + zDbDBName + ";"; + const char *use_sql = use.c_str(); + char *zErrMsg = 0; + sqlite3_exec(db, attach_sql, NULL, NULL, &zErrMsg); + if (zErrMsg) { + utf8_printf(stderr, "%s\n", zErrMsg); + return false; + }; + sqlite3_exec(db, use_sql, NULL, NULL, &zErrMsg); + if (zErrMsg) { + utf8_printf(stderr, "%s\n", zErrMsg); + return false; + }; + return true; } - return true; + utf8_printf(stderr, "Valid Filename not provided\n"); + return false; } MetadataResult OpenDatabase(ShellState &state, const char **azArg, idx_t nArg) { diff --git a/tools/shell/tests/test_backwards_compatibility.py b/tools/shell/tests/test_backwards_compatibility.py index 98d5135d1790..c36f196ed021 100644 --- a/tools/shell/tests/test_backwards_compatibility.py +++ b/tools/shell/tests/test_backwards_compatibility.py @@ -10,7 +10,7 @@ def test_version_dev(shell): test = ( ShellTest(shell) - .statement(".open test/storage/bc/db_dev.db") + .statement("Attach 'test/storage/bc/db_dev.db' as db_dev;") ) result = test.run() result.check_stderr("older development version") @@ -18,7 +18,7 @@ def test_version_dev(shell): def test_version_0_3_1(shell): test = ( ShellTest(shell) - .statement(".open test/storage/bc/db_031.db") + .statement("Attach 'test/storage/bc/db_031.db' as db_031;") ) result = test.run() result.check_stderr("v0.3.1") @@ -26,7 +26,7 @@ def test_version_0_3_1(shell): def test_version_0_3_2(shell): test = ( ShellTest(shell) - .statement(".open test/storage/bc/db_032.db") + .statement("Attach 'test/storage/bc/db_032.db' as db_032;") ) result = test.run() result.check_stderr("v0.3.2") @@ -34,7 +34,7 @@ def test_version_0_3_2(shell): def test_version_0_4(shell): test = ( ShellTest(shell) - .statement(".open test/storage/bc/db_04.db") + .statement("Attach 'test/storage/bc/db_04.db' as db_04;") ) result = test.run() result.check_stderr("v0.4.0") @@ -42,7 +42,7 @@ def test_version_0_4(shell): def test_version_0_5_1(shell): test = ( ShellTest(shell) - .statement(".open test/storage/bc/db_051.db") + .statement("Attach 'test/storage/bc/db_051.db' as db_051;") ) result = test.run() result.check_stderr("v0.5.1") @@ -50,7 +50,7 @@ def test_version_0_5_1(shell): def test_version_0_6_0(shell): test = ( ShellTest(shell) - .statement(".open test/storage/bc/db_060.db") + .statement("Attach 'test/storage/bc/db_060.db' as db_060;") ) result = test.run() result.check_stderr("v0.6.0")