diff --git a/calculate/templates/format/sqlite_format.py b/calculate/templates/format/sqlite_format.py new file mode 100644 index 0000000..094d209 --- /dev/null +++ b/calculate/templates/format/sqlite_format.py @@ -0,0 +1,85 @@ +# vim: fileencoding=utf-8 +# +from ..template_engine import ParametersContainer +from .base_format import Format, FormatError +from sqlite3 import connect, Connection, OperationalError, IntegrityError +import os + + +class SqliteFormat(Format): + FORMAT = 'sqlite' + EXECUTABLE = True + + def __init__(self, template_text: str, + template_path: str, + parameters: ParametersContainer = ParametersContainer(), + **kwargs): + self._template_text = template_text + if parameters.execsql: + self._mode = parameters.execsql + else: + self._mode = "rollback" + + self._executor = self.__getattribute__(f"_execute_{self._mode}") + + # Измененные файлы. + self.changed_files = dict() + + # Предупреждения. + self._warnings: list = [] + + def execute_format(self, target_path: str, + chroot_path: str = '/') -> dict: + '''Метод для запуска работы формата.''' + if os.path.exists(target_path) and os.path.isdir(target_path): + raise FormatError(f"directory on target path: {target_path}") + + connection = connect(target_path) + self._executor(connection, self._template_text) + connection.close() + + return self.changed_files + + def _execute_continue(self, connection: Connection, template_text: str): + """Метод для выполнения sql-команд при значения параметра + execsql = continue. В этом случае после обнаружения ошибки в скрипте + выполнение команд будет продолжено.""" + cursor = connection.cursor() + commands = template_text.split(";\n") + for command in commands: + try: + cursor.execute(command.strip()) + connection.commit() + except (OperationalError, IntegrityError) as error: + self._warnings.append(str(error)) + + def _execute_rollback(self, connection: Connection, template_text: str): + """Метод для выполнения sql-команд при значения параметра + execsql = rollback. В этом случае после обнаружения ошибки в скрипте + результат выполнения команд будет сброшен.""" + commands = template_text.split(";\n") + cursor = connection.cursor() + cursor.execute("BEGIN TRANSACTION;") + try: + for command in commands: + cursor.execute(command) + cursor.execute("COMMIT;") + except (OperationalError, IntegrityError) as error: + self._warnings.append(str(error)) + cursor.execute("ROLLBACK;") + + def _execute_stop(self, connection: Connection, template_text: str): + """Метод для выполнения sql-команд при значения параметра + execsql = stop. В этом случае после обнаружения ошибки в скрипте + выполнение команд будет остановлено.""" + cursor = connection.cursor() + try: + cursor.executescript(template_text) + except (OperationalError, IntegrityError) as error: + self._warnings.append(str(error)) + finally: + connection.commit() + + @property + def warnings(self): + return self._warnings diff --git a/calculate/templates/template_engine.py b/calculate/templates/template_engine.py index d675fd3..f3edb93 100644 --- a/calculate/templates/template_engine.py +++ b/calculate/templates/template_engine.py @@ -146,7 +146,7 @@ class ParametersProcessor: 'env', 'package', 'merge', 'postmerge', 'action', 'rebuild', 'restart', 'stop', 'start', 'handler', 'notify', 'group', - 'convert', 'stretch'} + 'convert', 'stretch', 'execsql'} inheritable_parameters: set = {'chmod', 'chown', 'autoupdate', 'env', 'package', 'action', 'handler', 'group'} @@ -229,6 +229,7 @@ class ParametersProcessor: 'notify': self.check_notify_parameter, 'convert': self.check_convert_parameter, 'stretch': self.check_stretch_parameter, + 'execsql': self.check_execsql_parameter, }) # Если добавляемый параметр должен быть проверен после того, как @@ -244,6 +245,7 @@ class ParametersProcessor: 'handler': self.check_postparse_handler, 'convert': self.check_postparse_convert, 'stretch': self.check_postparse_stretch, + 'execsql': self.check_postparse_execsql, }) # Если параметр является наследуемым только при некоторых условиях -- @@ -622,6 +624,19 @@ class ParametersProcessor: f" value not '{type(parameter_value)}'") return parameter_value + def check_execsql_parameter(self, parameter_value: Any) -> bool: + if not isinstance(parameter_value, str): + raise IncorrectParameter("'execsql' parameter value should be str" + f" value not '{type(parameter_value)}'") + available_values = {"rollback", "stop", "continue"} + parameter_value = parameter_value.lower().strip() + if parameter_value not in available_values: + raise IncorrectParameter("'exesql' parameter value" + f" '{parameter_value}' is not available." + " Available values: " + f"{', '.join(available_values)}") + return parameter_value + # Методы для проверки параметров после разбора всего шаблона. def check_postparse_append(self, parameter_value: str) -> NoReturn: if parameter_value == 'link': @@ -757,6 +772,12 @@ class ParametersProcessor: raise IncorrectParameter("'stretch' parameter available for" " 'backgrounds' format only.") + def check_postparse_execsql(self, parameter_value: str) -> NoReturn: + template_format = self._parameters_container.format + if not template_format or template_format != "sqlite": + raise IncorrectParameter("'execsql' parameter available for" + " 'sqlite' format only.") + # Методы для проверки того, являются ли параметры наследуемыми. def is_chmod_inheritable(self, parameter_value: str) -> bool: diff --git a/pytest.ini b/pytest.ini index cabf2b5..2b4e4aa 100644 --- a/pytest.ini +++ b/pytest.ini @@ -17,6 +17,7 @@ markers = openrc: marker for running test for openrc format. raw: marker for running test fot raw format. regex: marker for running test fot regex format. + sqlite: marker for running test fot sqlite format. postfix: marker for running test for postfix format. procmail: marker for running test for procmail format. proftpd: marker for running tests for proftpd format. diff --git a/tests/templates/format/test_sqlite.py b/tests/templates/format/test_sqlite.py new file mode 100644 index 0000000..9ead794 --- /dev/null +++ b/tests/templates/format/test_sqlite.py @@ -0,0 +1,126 @@ +import os +import shutil +import pytest +import sqlite3 +from calculate.templates.format.base_format import FormatError +from calculate.templates.format.sqlite_format import SqliteFormat +from calculate.templates.template_engine import ParametersContainer + + +TEST_DIR = 'tests/templates/format/testfiles/sqlite' +DB_BACKUP_PATH = os.path.join(TEST_DIR, "test.db.backup") +DB_PATH = os.path.join(TEST_DIR, "test.db") + +db_values = [(1, 'Amenra', 'Belgium'), + (2, 'Windir', 'Norway'), + (3, 'Pest Noire', 'France'), + (4, 'Midnight Odyssey', 'Australia'), + (5, 'Death in June', 'England'), + (6, 'Bathory', 'Sweden'), + (7, 'Primordial', 'Ireland'), + (8, 'Mayhem', 'Norway'), + (9, 'Sargeist', 'Finland'), + (10, 'Lifelover', 'Sweden'), + (11, 'Shining', 'Sweden')] + + +@pytest.mark.formats +@pytest.mark.sqlite +def test_to_copy_test_db(): + shutil.copy(DB_BACKUP_PATH, DB_PATH) + + +@pytest.mark.formats +@pytest.mark.sqlite +def test_first(): + template_text = """INSERT INTO bands(band_name, country) VALUES + ("Current 93", "England"), + ("Nine Inch Nails", "USA"), + ("Enter Shikari", "England");""" + + sqlite_object = SqliteFormat(template_text, "/path/to/the/template", + parameters=ParametersContainer()) + sqlite_object.execute_format(DB_PATH, TEST_DIR) + connection = sqlite3.connect(DB_PATH) + cursor = connection.cursor() + result = cursor.execute("SELECT * FROM bands;").fetchall() + added = [(12, 'Current 93', 'England'), + (13, 'Nine Inch Nails', 'USA'), + (14, 'Enter Shikari', 'England')] + db_values.extend(added) + + assert result == db_values + + +@pytest.mark.formats +@pytest.mark.sqlite +def test_with_execsql_parameter_is_rollback(): + parameters = ParametersContainer({"execsql": "rollback"}) + template_text = """UPDATE bands SET + band_name = "Burzum" WHERE band_name = "Mayhem"; + + INSERT INTO bands VALUES ("Bring Me Horizon"); + + INSERT INTO bands(band_name, country) + VALUES ("Joy Division", "England");""" + + sqlite_object = SqliteFormat(template_text, "/path/to/the/template", + parameters=parameters) + sqlite_object.execute_format(DB_PATH, TEST_DIR) + connection = sqlite3.connect(DB_PATH) + cursor = connection.cursor() + result = cursor.execute("SELECT * FROM bands;").fetchall() + assert result == db_values + + +@pytest.mark.formats +@pytest.mark.sqlite +def test_with_execsql_parameter_is_stop(): + parameters = ParametersContainer({"execsql": "stop"}) + template_text = """UPDATE bands SET + band_name = "Burzum" WHERE band_name = "Mayhem"; + + INSERT INTO bands VALUES ("Bring Me Horizon"); + + INSERT INTO bands(band_name, country) + VALUES ("Joy Division", "England");""" + + sqlite_object = SqliteFormat(template_text, "/path/to/the/template", + parameters=parameters) + sqlite_object.execute_format(DB_PATH, TEST_DIR) + connection = sqlite3.connect(DB_PATH) + cursor = connection.cursor() + result = cursor.execute("SELECT * FROM bands;").fetchall() + db_values[7] = (8, "Burzum", "Norway") + assert result == db_values + + +@pytest.mark.formats +@pytest.mark.sqlite +def test_with_execsql_parameter_is_continue(): + parameters = ParametersContainer({"execsql": "continue"}) + template_text = """INSERT INTO bands(band_name, country) VALUES + ("The Cure", "England"); + + INSERT INTO bands VALUES ("Bring Me Horizon"); + + INSERT INTO bands(band_name, country) + VALUES ("Joy Division", "England");""" + + sqlite_object = SqliteFormat(template_text, "/path/to/the/template", + parameters=parameters) + sqlite_object.execute_format(DB_PATH, TEST_DIR) + connection = sqlite3.connect(DB_PATH) + cursor = connection.cursor() + result = cursor.execute("SELECT * FROM bands;").fetchall() + added = [(15, 'The Cure', 'England'), + (16, 'Joy Division', 'England')] + db_values.extend(added) + + assert result == db_values + + +@pytest.mark.formats +@pytest.mark.sqlite +def test_to_remove_test_db(): + os.unlink(DB_PATH) diff --git a/tests/templates/format/testfiles/sqlite/test.db.backup b/tests/templates/format/testfiles/sqlite/test.db.backup new file mode 100644 index 0000000..2f69903 Binary files /dev/null and b/tests/templates/format/testfiles/sqlite/test.db.backup differ