From 91690f73e26413bec425b01e8d5f0ad2b2586439 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=94=D0=B5=D0=BD=D0=B8=D1=81=20=D0=98=D0=B2=D0=B0=D0=BD?= =?UTF-8?q?=D0=BE=D0=B2?= Date: Wed, 24 Feb 2021 15:28:00 +0300 Subject: [PATCH] Added authentification using access tokens. --- .gitignore | 1 + calculate/parameters/parameters.py | 4 +- calculate/server/config.py | 7 + calculate/server/models/database.py | 14 ++ calculate/server/models/users.py | 34 +++ calculate/server/models/workers.py | 0 calculate/server/routers/commands.py | 56 +++++ calculate/server/routers/users.py | 46 ++++ calculate/server/routers/workers.py | 0 calculate/server/schemas/config.py | 13 ++ calculate/server/schemas/tokens.py | 11 + calculate/server/schemas/users.py | 49 +++++ calculate/server/schemas/workers.py | 0 calculate/server/server.py | 203 +++++------------- calculate/server/server_data.py | 104 +++++++++ calculate/server/utils/auth.py | 77 +++++++ calculate/server/utils/dependencies.py | 50 +++++ calculate/server/utils/users.py | 38 ++++ .../server/{worker.py => utils/workers.py} | 4 +- calculate/templates/format/sqlite_format.py | 9 +- calculate/utils/ldap.py | 0 client.py | 2 + conftest.py | 3 + requirements.txt | 8 + tests/server/config.py | 11 + tests/server/test_server.py | 129 ++++++----- tests/server/testfiles/etc.backup/temp | 0 27 files changed, 659 insertions(+), 214 deletions(-) create mode 100644 calculate/server/config.py create mode 100644 calculate/server/models/database.py create mode 100644 calculate/server/models/users.py create mode 100644 calculate/server/models/workers.py create mode 100644 calculate/server/routers/commands.py create mode 100644 calculate/server/routers/users.py create mode 100644 calculate/server/routers/workers.py create mode 100644 calculate/server/schemas/config.py create mode 100644 calculate/server/schemas/tokens.py create mode 100644 calculate/server/schemas/users.py create mode 100644 calculate/server/schemas/workers.py create mode 100644 calculate/server/server_data.py create mode 100644 calculate/server/utils/auth.py create mode 100644 calculate/server/utils/dependencies.py create mode 100644 calculate/server/utils/users.py rename calculate/server/{worker.py => utils/workers.py} (98%) create mode 100644 calculate/utils/ldap.py create mode 100644 tests/server/config.py create mode 100644 tests/server/testfiles/etc.backup/temp diff --git a/.gitignore b/.gitignore index e45c141..348f0d0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.pyo *.swp *.sock +*.db build/ dist/ calculate_lib.egg-info/ diff --git a/calculate/parameters/parameters.py b/calculate/parameters/parameters.py index 132a79b..c2be990 100644 --- a/calculate/parameters/parameters.py +++ b/calculate/parameters/parameters.py @@ -8,11 +8,11 @@ from typing import Tuple, Union, Any class ParameterError(Exception): - ... + pass class ValidationError(ParameterError): - ... + pass class CyclicValidationError(ValidationError): diff --git a/calculate/server/config.py b/calculate/server/config.py new file mode 100644 index 0000000..59b4b6f --- /dev/null +++ b/calculate/server/config.py @@ -0,0 +1,7 @@ +from calculate.logging import dictLogConfig + + +config = {"socket_path": "./input.sock", + "variables_path": "calculate/variables", + "commands_path": "calculate/commands", + "logger_config": dictLogConfig} diff --git a/calculate/server/models/database.py b/calculate/server/models/database.py new file mode 100644 index 0000000..c99e493 --- /dev/null +++ b/calculate/server/models/database.py @@ -0,0 +1,14 @@ +from os import environ +from databases import Database + + +TESTING = bool(environ.get("TESTING", False)) + + +if TESTING: + DATABASE_URL = "sqlite:///tests/server/testfiles/test.db" +else: + # Временно. + DATABASE_URL = "sqlite:///calculate/server/tmp.db" + +database = Database(DATABASE_URL) diff --git a/calculate/server/models/users.py b/calculate/server/models/users.py new file mode 100644 index 0000000..0eb10cd --- /dev/null +++ b/calculate/server/models/users.py @@ -0,0 +1,34 @@ +from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey + + +metadata = MetaData() + + +users_table: Table = Table("users", + metadata, + Column("id", Integer, primary_key=True), + Column("login", + String(20), + unique=True, + nullable=False, + index=True), + Column("password", + String(77), + nullable=False) + ) + +rights_table: Table = Table("rights", + metadata, + Column("id", Integer, primary_key=True), + Column("name", + String(20), + unique=True, + nullable=False, + index=True), + Column("description", String(40))) + +users_rights: Table = Table("users_rights", + metadata, + Column("user_id", ForeignKey("users.id")), + Column("right_id", ForeignKey("rights.id")) + ) diff --git a/calculate/server/models/workers.py b/calculate/server/models/workers.py new file mode 100644 index 0000000..e69de29 diff --git a/calculate/server/routers/commands.py b/calculate/server/routers/commands.py new file mode 100644 index 0000000..7ad00f6 --- /dev/null +++ b/calculate/server/routers/commands.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, Depends + +from ..utils.dependencies import right_checkers + +from ..server_data import ServerData + + +data = ServerData() +router = APIRouter() + + +@router.get("/commands", tags=["Commands management"], + dependencies=[Depends(right_checkers["read"])]) +async def get_commands() -> dict: + '''Обработчик, отвечающий на запросы списка команд.''' + response = {} + for command_id, command_object in data.commands.items(): + response.update({command_id: {"title": command_object.title, + "category": command_object.category, + "icon": command_object.icon, + "command": command_object.command}}) + return response + + +@router.get("/commands/{cid}", tags=["Commands management"], + dependencies=[Depends(right_checkers["read"])]) +async def get_command(cid: int) -> dict: + '''Обработчик запросов списка команд.''' + if cid not in data.commands_instances: + # TODO добавить какую-то обработку ошибки. + pass + return {'id': cid, + 'name': f'command_{cid}'} + + +@router.get("/commands/{cid}/groups", tags=["Commands management"], + dependencies=[Depends(right_checkers["read"])]) +async def get_command_parameters_groups(cid: int) -> dict: + '''Обработчик запросов на получение групп параметров указанной команды.''' + pass + + +@router.get("/commands/{cid}/parameters", tags=["Commands management"], + dependencies=[Depends(right_checkers["read"])]) +async def get_command_parameters(cid: int) -> dict: + '''Обработчик запросов на получение параметров указанной команды.''' + pass + + +@router.post("/commands/{command_id}", tags=["Commands management"], + dependencies=[Depends(right_checkers["write"])]) +async def post_command(command_id: str) -> int: + if command_id not in data.commands: + # TODO добавить какую-то обработку ошибки. + pass + return diff --git a/calculate/server/routers/users.py b/calculate/server/routers/users.py new file mode 100644 index 0000000..01c98ff --- /dev/null +++ b/calculate/server/routers/users.py @@ -0,0 +1,46 @@ +from datetime import timedelta + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import OAuth2PasswordRequestForm + +from ..utils.auth import auth_user, create_access_token +from ..utils.dependencies import right_checkers + +from ..schemas.tokens import Token +from ..schemas.users import User, UserCreate + + +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +router = APIRouter() + + +@router.post("/auth", tags=["Authentication"], response_model=Token) +async def authenticate(form_data: OAuth2PasswordRequestForm = Depends()): + '''Метод обрабатывающий запросы на аутентификацию пользователей по данным + указанным в форме, возвращает access_token.''' + user = await auth_user(form_data.username, + form_data.password) + if not user: + raise HTTPException( + status_code=400, + detail=f"User '{form_data.username}' not found" + ) + + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token(data={"sub": user.login}, + expires_delta=access_token_expires) + # TODO добавить refresh_token + return {"access_token": access_token, + "token_type": "bearer"} + + +@router.post("/refresh", tags=["Authentication"], response_model=Token) +async def refresh(): + '''Метод для обработки запросов с refresh-токенами.''' + pass + + +@router.post("/users/create", tags=["Users management"], response_model=User) +async def create_user(user_info: UserCreate, + current_user=Depends(right_checkers["admin"])): + print("CREATE USER") diff --git a/calculate/server/routers/workers.py b/calculate/server/routers/workers.py new file mode 100644 index 0000000..e69de29 diff --git a/calculate/server/schemas/config.py b/calculate/server/schemas/config.py new file mode 100644 index 0000000..1ce298e --- /dev/null +++ b/calculate/server/schemas/config.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + + +class ConfigSchema(BaseModel): + socket_path: str + variables_path: str + commands_path: str + + logger_config: dict + + class Config: + min_any_str_length = 1 + anystr_strip_whitespace = True diff --git a/calculate/server/schemas/tokens.py b/calculate/server/schemas/tokens.py new file mode 100644 index 0000000..96f219b --- /dev/null +++ b/calculate/server/schemas/tokens.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + + +class Token(BaseModel): + access_token: str + token_type: str + + +class TokenData(BaseModel): + username: str + expire: int diff --git a/calculate/server/schemas/users.py b/calculate/server/schemas/users.py new file mode 100644 index 0000000..8a92e3f --- /dev/null +++ b/calculate/server/schemas/users.py @@ -0,0 +1,49 @@ +from fastapi import HTTPException, status +from pydantic import BaseModel, validator +from typing import List, Optional + + +class UserCreate(BaseModel): + login: str + password: str + rights: Optional[List[str]] = [] + + +class UserData(UserCreate): + id: int + + +class User(BaseModel): + id: int + login: str + rights: List[str] + + +class UserRead(User): + @validator("rights") + def check_permissions(cls, value): + check_rights("read", value) + return value + + +class UserWrite(User): + @validator("rights") + def check_permissions(cls, value): + check_rights("write", value) + return value + + +class UserAdmin(User): + @validator("rights") + def check_permissions(cls, value): + check_rights("admin", value) + return value + + +def check_rights(right: str, rights_list: List[str]): + if right not in rights_list: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/calculate/server/schemas/workers.py b/calculate/server/schemas/workers.py new file mode 100644 index 0000000..e69de29 diff --git a/calculate/server/server.py b/calculate/server/server.py index 11a6819..d47591a 100644 --- a/calculate/server/server.py +++ b/calculate/server/server.py @@ -1,19 +1,10 @@ -from ..variables.loader import Datavars -from ..commands.commands import Command -from ..logging import dictLogConfig -from logging.config import dictConfig -from typing import ( - Callable, - Optional, - NoReturn - ) -from logging import getLogger -from fastapi import FastAPI -from .worker import Worker -import importlib -import uvicorn -import asyncio -import os +from fastapi import FastAPI, Depends + +from .server_data import ServerData +from .utils.dependencies import right_checkers +from .models.database import database +from .routers.users import router as users_router +from .routers.commands import router as commands_router # TODO @@ -21,142 +12,44 @@ import os # 2. Разобраться с объектами воркеров. И способом их функционирования. -class Server: - def __init__(self, socket_path: str = './input.sock', - datavars_path: str = 'calculate/vars/', - commands_path: str = 'calculate/commands'): - self._app = FastAPI() - self._socket_path = socket_path - self._event_loop = asyncio.get_event_loop() - - # Конфигурируем логгирование. - dictConfig(dictLogConfig) - self._logger = getLogger("main") - self.log_msg = {'DEBUG': self._logger.debug, - 'INFO': self._logger.info, - 'WARNING': self._logger.warning, - 'ERROR': self._logger.error, - 'CRITICAL': self._logger.critical} - - self._datavars = Datavars(variables_path=datavars_path, - logger=self._logger) - - # Словарь описаний команд. - self._commands = self._get_commands_list(commands_path) - - # Словарь CID и экземпляров команд, передаваемых воркерам. - self._commands_instances = {} - - # Словарь WID и экземпляров процессов-воркеров, передаваемых воркерам. - self._workers = {} - - # Соответствие путей обработчикам запросов для HTTP-метода GET. - self._add_routes(self._app.get, - {"/": self._get_root, - "/commands": self._get_commands, - "/commands/{cid}": self._get_command, - "/workers/{wid}": self._get_worker}) - self._add_routes(self._app.post, - {"/commands/{command_id}": self._post_command}) - - def _get_commands_list(self, commands_path: str) -> list: - '''Метод для получения совокупности описаний команд.''' - output = {} - package = ".".join(commands_path.split("/")) - - for entry in os.scandir(commands_path): - if (not entry.name.endswith('.py') - or entry.name in {"commands.py", "__init__.py"}): - continue - module_name = entry.name[:-3] - try: - module = importlib.import_module("{}.{}".format(package, - module_name)) - for obj in dir(module): - if type(module.__getattribute__(obj)) == Command: - command_object = module.__getattribute__(obj) - output[command_object.id] = command_object - except Exception: - continue - return output - - # Обработчики запросов серверу. - async def _get_root(self) -> dict: - '''Обработчик корневых запросов.''' - return {'msg': 'root msg'} - - async def _get_commands(self) -> dict: - '''Обработчик, отвечающий на запросы списка команд.''' - response = {} - for command_id, command_object in self._commands.items(): - response.update({command_id: {"title": command_object.title, - "category": command_object.category, - "icon": command_object.icon, - "command": command_object.command}}) - return response - - async def _get_command(self, cid: int) -> dict: - '''Обработчик запросов списка команд.''' - if cid not in self._commands_instances: - # TODO добавить какую-то обработку ошибки. - pass - return {'id': cid, - 'name': f'command_{cid}'} - - async def _get_worker(self, wid: int): - '''Тестовый обработчик.''' - self._make_worker(wid=wid) - worker = self._workers[wid] - worker.run(None) - await worker.send({"text": "INFO"}) - data = await worker.get() - if data['type'] == 'log': - self.log_msg[data['level']](data['msg']) - return data - - async def _post_command(self, command_id: str) -> int: - if command_id not in self._commands: - # TODO добавить какую-то обработку ошибки. - pass - return - - # Обработчики сообщений воркеров. - - # Вспомогательные методы. - def _add_routes(self, method: Callable, routes: dict) -> NoReturn: - '''Метод для добавления методов.''' - for path, handler in routes.items(): - router = method(path) - router(handler) - - def _make_worker(self, wid: Optional[int] = None): - '''Метод для создания воркера для команды.''' - if wid is not None: - self._workers[wid] = Worker(wid, self._event_loop) - return wid - elif not self._workers: - self._workers[0] = Worker(0, self._event_loop) - return 0 - else: - wid = max(self._workers.keys()) + 1 - self._workers[wid] = Worker(wid, self._event_loop) - return wid - - def _make_command(self, command_id: str) -> int: - '''Метод для создания команды по ее описанию.''' - command_description = self._commands[command] - - @property - def app(self): - return self._app - - def run(self): - '''Метод для запуска сервера.''' - # Выгружаем список команд. - uvicorn.run(self._app, - uds=self._socket_path) - - -if __name__ == '__main__': - server = Server() - server.run() +data = ServerData() +app = FastAPI() + + +@app.on_event("startup") +async def startup(): + await database.connect() + + +@app.on_event("shutdown") +async def shutdown(): + await database.disconnect() + + +@app.get("/", tags=["Root"], + dependencies=[Depends(right_checkers["read"])]) +async def get_root() -> dict: + '''Обработчик корневых запросов.''' + return {'msg': 'root msg'} + + +@app.get("/workers/{wid}", tags=["Workers management"], + dependencies=[Depends(right_checkers["write"])]) +async def get_worker(wid: int): + '''Тестовый обработчик.''' + worker = data.get_worker_object(wid=wid) + worker.run() + print(f"worker: {type(worker)} object {worker}") + + await worker.send({"text": "INFO"}) + worker_data = await worker.get() + + if worker_data['type'] == 'log': + data.log_message[data['level']](data['msg']) + return worker_data + +# Authentification and users management. +app.include_router(users_router) + +# Commands creation and management. +app.include_router(commands_router) diff --git a/calculate/server/server_data.py b/calculate/server/server_data.py new file mode 100644 index 0000000..9448162 --- /dev/null +++ b/calculate/server/server_data.py @@ -0,0 +1,104 @@ +import os +import asyncio +import importlib +from typing import Dict, Optional + +from logging.config import dictConfig +from logging import getLogger, Logger + +from ..variables.loader import Datavars +from ..commands.commands import Command, CommandRunner +from .utils.workers import Worker + +from .schemas.config import ConfigSchema + + +# Получаем конфигурацию сервера. +TESTING = bool(os.environ.get("TESTING", False)) +if not TESTING: + from .config import config + server_config = ConfigSchema(**config) +else: + from tests.server.config import config + server_config = ConfigSchema(**config) + + +class ServerData: + def __init__(self, config: ConfigSchema = server_config): + self.event_loop = asyncio.get_event_loop() + + self._variables_path = config.variables_path + + # Конфигурируем логгирование. + dictConfig(config.logger_config) + self.logger: Logger = getLogger("main") + self.log_message = {'DEBUG': self.logger.debug, + 'INFO': self.logger.info, + 'WARNING': self.logger.warning, + 'ERROR': self.logger.error, + 'CRITICAL': self.logger.critical} + + self._datavars: Optional[Datavars] = None + + # Словарь описаний команд. + self.commands: Dict[str, Command] = self._get_commands_descriptions( + config.commands_path) + + # Словарь CID и экземпляров команд, передаваемых воркерам. + self.commands_runners: Dict[str, CommandRunner] = {} + + # Словарь WID и экземпляров процессов-воркеров, передаваемых воркерам. + self.workers: Dict[int, Worker] = {} + + @property + def datavars(self): + if self._datavars is not None: + return self._datavars + else: + self._datavars = self._load_datavars(self._variables_path, + self.logger) + return self._datavars + + def _load_datavars(vars_path: str, logger: Logger) -> Datavars: + return Datavars(variables_path=vars_path, + logger=logger) + + def _get_commands_descriptions(self, commands_path: str + ) -> Dict[str, Command]: + '''Метод для получения совокупности описаний команд.''' + output = {} + package = ".".join(commands_path.split("/")) + + for entry in os.scandir(commands_path): + if (not entry.name.endswith('.py') + or entry.name in {"commands.py", "__init__.py"}): + continue + module_name = entry.name[:-3] + try: + module = importlib.import_module("{}.{}".format(package, + module_name)) + for obj_name in dir(module): + obj = module.__getattribute__(obj_name) + if isinstance(obj, Command): + output[obj.id] = obj + except Exception: + continue + return output + + def make_command(self, command_id: str, ) -> int: + '''Метод для создания команды по ее описанию.''' + command_description = self.commands[command_id] + + def _get_worker_object(self, wid: Optional[int] = None) -> Worker: + '''Метод для получения воркера для команды.''' + if wid is not None: + worker = Worker(wid, self._event_loop, self._datavars) + self._workers[wid] = worker + elif not self._workers: + worker = Worker(0, self._event_loop, self._datavars) + self._workers[0] = worker + else: + wid = max(self._workers.keys()) + 1 + worker = Worker(wid, self._event_loop, self._datavars) + self._workers[wid] = worker + return worker diff --git a/calculate/server/utils/auth.py b/calculate/server/utils/auth.py new file mode 100644 index 0000000..a1b8ab9 --- /dev/null +++ b/calculate/server/utils/auth.py @@ -0,0 +1,77 @@ +from hashlib import pbkdf2_hmac +from random import choices +from string import ascii_letters +from datetime import datetime, timedelta + +from typing import Optional, Union +from jose import jwt + +from ..schemas.tokens import TokenData +from ..schemas.users import UserData + +from .users import get_user_by_username + +from calculate.utils.files import Process + + +def make_secret_key(): + openssl_process = Process("/usr/bin/openssl", "rand", "-hex", "32") + secret_key = openssl_process.read() + return secret_key.strip() + + +# SECRET_KEY =\ +# "efe90242c1c221b20fc718edc3aa5da4f78147eb2f4e81e809e945bbcdf0710e" +SECRET_KEY = make_secret_key() +ALGORITHM = "HS256" + + +async def auth_user(username: str, password: str) -> Union[UserData, bool]: + user_row = await get_user_by_username(username) + if not user_row: + return False + user = UserData(**user_row) + if not validate_password(password, user.password): + return False + return user + + +def decode_jwt(token: str) -> Union[TokenData, None]: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + expire = payload.get("exp") + if username is None: + return None + return TokenData(username=username, + expire=expire) + + +def create_access_token(data: dict, + expires_delta: Optional[timedelta] = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def get_salt(length=12): + """Метод для получения случайной строки символов используемой как соль + при кэшировании.""" + return "".join(choices(ascii_letters, k=length)) + + +def validate_password(password: str, hashed_password: str) -> bool: + salt, db_hash = hashed_password.split("$") + hashed = hash_password(password, salt) + return hashed == db_hash + + +def hash_password(password: str, salt: str) -> str: + if salt is None: + salt = get_salt() + enc = pbkdf2_hmac("sha256", password.encode(), salt.encode(), 100_000) + return enc.hex() diff --git a/calculate/server/utils/dependencies.py b/calculate/server/utils/dependencies.py new file mode 100644 index 0000000..3077d4a --- /dev/null +++ b/calculate/server/utils/dependencies.py @@ -0,0 +1,50 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer + +from jose import JWTError + +from pydantic import ValidationError + +from .auth import decode_jwt +from .users import get_user_by_username + +from ..schemas.users import UserRead, UserWrite, UserAdmin + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth") + + +async def get_current_user(token: str = Depends(oauth2_scheme)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"} + ) + try: + token_data = decode_jwt(token) + if token_data is None: + raise credentials_exception + + except (JWTError, ValidationError): + raise credentials_exception + + user = await get_user_by_username(token_data.username) + if user is None: + raise credentials_exception + + return user + + +def make_right_checkers(): + rights_schemas = {"read": UserRead, "write": UserWrite, "admin": UserAdmin} + dependencies = {} + for right, schema in rights_schemas.items(): + async def depend_function(token: str = Depends(oauth2_scheme)): + user = await get_current_user(token=token) + return schema(**user) + + dependencies[right] = depend_function + return dependencies + + +right_checkers = make_right_checkers() diff --git a/calculate/server/utils/users.py b/calculate/server/utils/users.py new file mode 100644 index 0000000..44eedc9 --- /dev/null +++ b/calculate/server/utils/users.py @@ -0,0 +1,38 @@ +from sqlalchemy import func, and_ +from typing import List + +from ..models.database import database +from ..models.users import users_table, users_rights, rights_table + +from ..schemas.users import UserCreate + + +async def get_user_by_username(username: str): + '''Метод для получения строки с данными пользователя из базы данных по + username.''' + query = users_table.select().where(users_table.c.login == username) + user_data = await database.fetch_one(query) + query = (users_table. + join(users_rights). + join(rights_table). + select(). + where(and_(users_table.c.id == users_rights.c.user_id, + users_table.c.login == username, + rights_table.c.id == users_rights.c.right_id)). + with_only_columns([users_table.c.id, + users_table.c.login, + users_table.c.password, + func.group_concat(rights_table.c.name, + ' ').label("rights")]). + group_by(users_table.c.id)) + response = await database.fetch_one(query) + + user_data = dict(response) + user_data['rights'] = user_data['rights'].split() + return user_data + + +async def create_user(username: str, hashed_password: str, rights: List[str]): + user = UserCreate(login=username, + password=hashed_password, + rights=rights) diff --git a/calculate/server/worker.py b/calculate/server/utils/workers.py similarity index 98% rename from calculate/server/worker.py rename to calculate/server/utils/workers.py index f45eb38..98c99e3 100644 --- a/calculate/server/worker.py +++ b/calculate/server/utils/workers.py @@ -3,9 +3,9 @@ import json import socket import logging from typing import Union -from ..variables.loader import Datavars +from calculate.variables.loader import Datavars from multiprocessing import Queue, Process -from ..commands.commands import CommandRunner, Command +from calculate.commands.commands import CommandRunner, Command # from time import sleep diff --git a/calculate/templates/format/sqlite_format.py b/calculate/templates/format/sqlite_format.py index 094d209..d0fc01d 100644 --- a/calculate/templates/format/sqlite_format.py +++ b/calculate/templates/format/sqlite_format.py @@ -9,6 +9,7 @@ import os class SqliteFormat(Format): FORMAT = 'sqlite' EXECUTABLE = True + FORMAT_PARAMETERS = {'execsql'} def __init__(self, template_text: str, template_path: str, @@ -33,11 +34,15 @@ class SqliteFormat(Format): '''Метод для запуска работы формата.''' if os.path.exists(target_path) and os.path.isdir(target_path): raise FormatError(f"directory on target path: {target_path}") + status = "N" + if os.path.exists(target_path): + status = "M" connection = connect(target_path) self._executor(connection, self._template_text) connection.close() + self.changed_files[target_path] = status return self.changed_files def _execute_continue(self, connection: Connection, template_text: str): @@ -63,10 +68,10 @@ class SqliteFormat(Format): try: for command in commands: cursor.execute(command) - cursor.execute("COMMIT;") + connection.commit() except (OperationalError, IntegrityError) as error: self._warnings.append(str(error)) - cursor.execute("ROLLBACK;") + connection.rollback() def _execute_stop(self, connection: Connection, template_text: str): """Метод для выполнения sql-команд при значения параметра diff --git a/calculate/utils/ldap.py b/calculate/utils/ldap.py new file mode 100644 index 0000000..e69de29 diff --git a/client.py b/client.py index a8ff745..d5a78c7 100644 --- a/client.py +++ b/client.py @@ -1,3 +1,4 @@ +import os import asyncio import aiohttp @@ -26,6 +27,7 @@ async def main(): if __name__ == '__main__': + print(f"TESTING = {os.environ.get('TESTING', None)}") loop = asyncio.new_event_loop() try: loop.run_until_complete(main()) diff --git a/conftest.py b/conftest.py index 4d0fc6e..245495c 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,8 @@ +import os import pytest + from collections import OrderedDict +os.environ["TESTING"] = "True" @pytest.fixture(scope='function') diff --git a/requirements.txt b/requirements.txt index b8a382c..b9f6232 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,13 @@ +python-ldap +requests +uvicorn +fastapi pytest jinja2 xattr lxml mock +python-jose +python-multipart +sqlalchemy +databases[sqlite] diff --git a/tests/server/config.py b/tests/server/config.py new file mode 100644 index 0000000..8ce2df6 --- /dev/null +++ b/tests/server/config.py @@ -0,0 +1,11 @@ +import os +from calculate.logging import dictLogConfig + + +testfiles_path = os.path.join(os.getcwd(), 'tests/server/testfiles') + + +config = {"socket_path": "./input.sock", + "variables_path": 'tests/server/testfiles/variables', + "commands_path": 'tests/server/testfiles/commands', + "logger_config": dictLogConfig} diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 75a9642..cb5c2e6 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -2,62 +2,85 @@ import os import pytest import shutil from fastapi.testclient import TestClient -from calculate.server.server import Server +from calculate.server.server import app TESTFILES_PATH = os.path.join(os.getcwd(), 'tests/server/testfiles') -VARS_PATH = os.path.join(TESTFILES_PATH, 'variables') -COMMANDS_PATH = 'tests/server/testfiles/commands' -server = Server(datavars_path=VARS_PATH, commands_path=COMMANDS_PATH) -test_client = TestClient(server.app) + +test_client = TestClient(app) + + +def authenticate(username: str, password: str): + request_headers = {"accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded"} + request_data = {"username": username, "password": password, + "grant_type": "password", "scope": None, "client_id": None, + "client_id": None} + token = test_client.post("/auth", data=request_data, json=request_headers) + token_header = token.json() + access_token = token_header["access_token"] + return {"accept": "application/json", + "Authorization": f"Bearer {access_token}"} + + +@pytest.mark.server +def test_to_make_testfiles(): + shutil.copytree(os.path.join(TESTFILES_PATH, 'var.backup'), + os.path.join(TESTFILES_PATH, 'var'), + symlinks=True) + shutil.copytree(os.path.join(TESTFILES_PATH, 'etc.backup'), + os.path.join(TESTFILES_PATH, 'etc'), + symlinks=True) + + +@pytest.mark.server +def test_get_root_message(): + authorization_headers = authenticate("denis", "secret") + response = test_client.get("/", headers=authorization_headers) + assert response.status_code == 200 + assert response.json() == {"msg": "root msg"} + + +@pytest.mark.server +def test_get_commands_list(): + authorization_headers = authenticate("denis", "secret") + response = test_client.get("/commands", headers=authorization_headers) + assert response.status_code == 200 + assert response.json() == {"test_1": + {"title": "Test 1", + "category": "Test Category", + "icon": "/path/to/icon_1.png", + "command": "test_1"}, + "test_2": + {"title": "Test 2", + "category": "Test Category", + "icon": "/path/to/icon_2.png", + "command": "cl_test_2"}} + + +# @pytest.mark.server +# def test_post_command(): +# response = test_client.get("/commands/") +# assert response.status_code == 200 + + +# @pytest.mark.server +# def test_get_command_by_cid(): +# response = test_client.get("/commands/0") +# assert response.status_code == 200 +# assert response.json() == {"id": 0, "name": "command_0"} + + +# @pytest.mark.server +# def test_get_worker_message_by_wid(): +# response = test_client.get("/workers/0") +# assert response.status_code == 200 +# data = response.json() +# assert data == {'type': 'log', 'level': 'INFO', +# 'msg': 'recieved message INFO'} @pytest.mark.server -class TestServer: - pass - # def test_to_make_testfiles(self): - # shutil.copytree(os.path.join(TESTFILES_PATH, 'var.backup'), - # os.path.join(TESTFILES_PATH, 'var'), - # symlinks=True) - # shutil.copytree(os.path.join(TESTFILES_PATH, 'etc.backup'), - # os.path.join(TESTFILES_PATH, 'etc'), - # symlinks=True) - - # def test_get_root_message(self): - # response = test_client.get("/") - # assert response.status_code == 200 - # assert response.json() == {"msg": "root msg"} - - # def test_get_commands_list(self): - # response = test_client.get("/commands") - # assert response.status_code == 200 - # assert response.json() == {"test_1": - # {"title": "Test 1", - # "category": "Test Category", - # "icon": "/path/to/icon_1.png", - # "command": "test_1"}, - # "test_2": - # {"title": "Test 2", - # "category": "Test Category", - # "icon": "/path/to/icon_2.png", - # "command": "cl_test_2"}} - - # def test_post_command(self): - # response = test_client.get("/commands/") - # assert response.status_code == 200 - - # def test_get_command_by_cid(self): - # response = test_client.get("/commands/0") - # assert response.status_code == 200 - # assert response.json() == {"id": 0, "name": "command_0"} - - # def test_get_worker_message_by_wid(self): - # response = test_client.get("/workers/0") - # assert response.status_code == 200 - # data = response.json() - # assert data == {'type': 'log', 'level': 'INFO', - # 'msg': 'recieved message INFO'} - - # def test_for_removing_testfiles(self): - # shutil.rmtree(os.path.join(TESTFILES_PATH, 'var')) - # shutil.rmtree(os.path.join(TESTFILES_PATH, 'etc')) +def test_for_removing_testfiles(): + shutil.rmtree(os.path.join(TESTFILES_PATH, 'var')) + shutil.rmtree(os.path.join(TESTFILES_PATH, 'etc')) diff --git a/tests/server/testfiles/etc.backup/temp b/tests/server/testfiles/etc.backup/temp new file mode 100644 index 0000000..e69de29