"""Tests for applying SQL templates to a SQLite database with SqliteFormat.

The tests in this module share one working database file (``DB_PATH``) and
the module-level ``db_values`` list of expected rows.  Several tests mutate
``db_values`` to track the rows they add or change, so the tests are order
dependent: ``test_to_copy_test_db`` must run first and
``test_to_remove_test_db`` last.
"""
import os
import shutil
import sqlite3

import pytest

from calculate.templates.format.base_format import FormatError
from calculate.templates.format.sqlite_format import SqliteFormat
from calculate.templates.template_engine import ParametersContainer


TEST_DIR = 'tests/templates/format/testfiles/sqlite'
DB_BACKUP_PATH = os.path.join(TEST_DIR, "test.db.backup")
DB_PATH = os.path.join(TEST_DIR, "test.db")

# Expected content of the ``bands`` table.  Tests update this list in place
# as they modify the database, so later tests see the accumulated state.
db_values = [
    (1, 'Amenra', 'Belgium'),
    (2, 'Windir', 'Norway'),
    (3, 'Pest Noire', 'France'),
    (4, 'Midnight Odyssey', 'Australia'),
    (5, 'Death in June', 'England'),
    (6, 'Bathory', 'Sweden'),
    (7, 'Primordial', 'Ireland'),
    (8, 'Mayhem', 'Norway'),
    (9, 'Sargeist', 'Finland'),
    (10, 'Lifelover', 'Sweden'),
    (11, 'Shining', 'Sweden'),
]


def _fetch_bands():
    """Return every row of the ``bands`` table in the working database.

    The connection is always closed before returning, so no handle to
    ``DB_PATH`` stays open between tests or when the file is removed.
    """
    connection = sqlite3.connect(DB_PATH)
    try:
        return connection.cursor().execute("SELECT * FROM bands;").fetchall()
    finally:
        # A sqlite3 connection is not closed implicitly; do it explicitly.
        connection.close()


@pytest.mark.formats
@pytest.mark.sqlite
def test_to_copy_test_db():
    """Create the working database from the pristine backup copy."""
    shutil.copy(DB_BACKUP_PATH, DB_PATH)


@pytest.mark.formats
@pytest.mark.sqlite
def test_first():
    """A template with a single valid INSERT adds all of its rows."""
    template_text = (
        'INSERT INTO bands(band_name, country) VALUES '
        '("Current 93", "England"), ("Nine Inch Nails", "USA"), '
        '("Enter Shikari", "England");'
    )
    sqlite_object = SqliteFormat(template_text, "/path/to/the/template",
                                 parameters=ParametersContainer())
    sqlite_object.execute_format(DB_PATH, TEST_DIR)

    result = _fetch_bands()

    added = [(12, 'Current 93', 'England'),
             (13, 'Nine Inch Nails', 'USA'),
             (14, 'Enter Shikari', 'England')]
    db_values.extend(added)
    assert result == db_values


@pytest.mark.formats
@pytest.mark.sqlite
def test_with_execsql_parameter_is_rollback():
    """With ``execsql=rollback`` the table is left unchanged.

    The second statement supplies a single value for the three-column
    table and is presumably the one that fails; the test expects none of
    the template's changes to be visible afterwards.
    """
    parameters = ParametersContainer({"execsql": "rollback"})
    template_text = (
        'UPDATE bands SET band_name = "Burzum" WHERE band_name = "Mayhem"; '
        'INSERT INTO bands VALUES ("Bring Me Horizon"); '
        'INSERT INTO bands(band_name, country) '
        'VALUES ("Joy Division", "England");'
    )
    sqlite_object = SqliteFormat(template_text, "/path/to/the/template",
                                 parameters=parameters)
    sqlite_object.execute_format(DB_PATH, TEST_DIR)

    result = _fetch_bands()

    assert result == db_values


@pytest.mark.formats
@pytest.mark.sqlite
def test_with_execsql_parameter_is_stop():
    """With ``execsql=stop`` only the leading UPDATE takes effect.

    The test expects the UPDATE to be kept and the trailing INSERT not to
    have added a row.
    """
    parameters = ParametersContainer({"execsql": "stop"})
    template_text = (
        'UPDATE bands SET band_name = "Burzum" WHERE band_name = "Mayhem"; '
        'INSERT INTO bands VALUES ("Bring Me Horizon"); '
        'INSERT INTO bands(band_name, country) '
        'VALUES ("Joy Division", "England");'
    )
    sqlite_object = SqliteFormat(template_text, "/path/to/the/template",
                                 parameters=parameters)
    sqlite_object.execute_format(DB_PATH, TEST_DIR)

    result = _fetch_bands()

    # Row with id 8 ('Mayhem') is the one renamed by the UPDATE statement.
    db_values[7] = (8, "Burzum", "Norway")
    assert result == db_values


@pytest.mark.formats
@pytest.mark.sqlite
def test_with_execsql_parameter_is_continue():
    """With ``execsql=continue`` the first and third INSERTs both apply.

    The test expects two new rows, i.e. the middle statement adds nothing
    and does not prevent the statement after it from running.
    """
    parameters = ParametersContainer({"execsql": "continue"})
    template_text = (
        'INSERT INTO bands(band_name, country) '
        'VALUES ("The Cure", "England"); '
        'INSERT INTO bands VALUES ("Bring Me Horizon"); '
        'INSERT INTO bands(band_name, country) '
        'VALUES ("Joy Division", "England");'
    )
    sqlite_object = SqliteFormat(template_text, "/path/to/the/template",
                                 parameters=parameters)
    sqlite_object.execute_format(DB_PATH, TEST_DIR)

    result = _fetch_bands()

    added = [(15, 'The Cure', 'England'),
             (16, 'Joy Division', 'England')]
    db_values.extend(added)
    assert result == db_values


@pytest.mark.formats
@pytest.mark.sqlite
def test_to_remove_test_db():
    """Delete the working database created by ``test_to_copy_test_db``."""
    os.unlink(DB_PATH)