Browse Source

Added authentification using access tokens.

master
Денис Иванов 4 months ago
parent
commit
91690f73e2
27 changed files with 659 additions and 214 deletions
  1. +1
    -0
      .gitignore
  2. +2
    -2
      calculate/parameters/parameters.py
  3. +7
    -0
      calculate/server/config.py
  4. +14
    -0
      calculate/server/models/database.py
  5. +34
    -0
      calculate/server/models/users.py
  6. +0
    -0
      calculate/server/models/workers.py
  7. +56
    -0
      calculate/server/routers/commands.py
  8. +46
    -0
      calculate/server/routers/users.py
  9. +0
    -0
      calculate/server/routers/workers.py
  10. +13
    -0
      calculate/server/schemas/config.py
  11. +11
    -0
      calculate/server/schemas/tokens.py
  12. +49
    -0
      calculate/server/schemas/users.py
  13. +0
    -0
      calculate/server/schemas/workers.py
  14. +48
    -155
      calculate/server/server.py
  15. +104
    -0
      calculate/server/server_data.py
  16. +77
    -0
      calculate/server/utils/auth.py
  17. +50
    -0
      calculate/server/utils/dependencies.py
  18. +38
    -0
      calculate/server/utils/users.py
  19. +2
    -2
      calculate/server/utils/workers.py
  20. +7
    -2
      calculate/templates/format/sqlite_format.py
  21. +0
    -0
      calculate/utils/ldap.py
  22. +2
    -0
      client.py
  23. +3
    -0
      conftest.py
  24. +8
    -0
      requirements.txt
  25. +11
    -0
      tests/server/config.py
  26. +76
    -53
      tests/server/test_server.py
  27. +0
    -0
      tests/server/testfiles/etc.backup/temp

+ 1
- 0
.gitignore View File

@@ -2,6 +2,7 @@
*.pyo
*.swp
*.sock
*.db
build/
dist/
calculate_lib.egg-info/


+ 2
- 2
calculate/parameters/parameters.py View File

@@ -8,11 +8,11 @@ from typing import Tuple, Union, Any


class ParameterError(Exception):
...
pass


class ValidationError(ParameterError):
...
pass


class CyclicValidationError(ValidationError):


+ 7
- 0
calculate/server/config.py View File

@@ -0,0 +1,7 @@
from calculate.logging import dictLogConfig


config = {"socket_path": "./input.sock",
"variables_path": "calculate/variables",
"commands_path": "calculate/commands",
"logger_config": dictLogConfig}

+ 14
- 0
calculate/server/models/database.py View File

@@ -0,0 +1,14 @@
from os import environ
from databases import Database


TESTING = bool(environ.get("TESTING", False))


if TESTING:
DATABASE_URL = "sqlite:///tests/server/testfiles/test.db"
else:
# Временно.
DATABASE_URL = "sqlite:///calculate/server/tmp.db"

database = Database(DATABASE_URL)

+ 34
- 0
calculate/server/models/users.py View File

@@ -0,0 +1,34 @@
from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey


metadata = MetaData()


users_table: Table = Table("users",
metadata,
Column("id", Integer, primary_key=True),
Column("login",
String(20),
unique=True,
nullable=False,
index=True),
Column("password",
String(77),
nullable=False)
)

rights_table: Table = Table("rights",
metadata,
Column("id", Integer, primary_key=True),
Column("name",
String(20),
unique=True,
nullable=False,
index=True),
Column("description", String(40)))

users_rights: Table = Table("users_rights",
metadata,
Column("user_id", ForeignKey("users.id")),
Column("right_id", ForeignKey("rights.id"))
)

+ 0
- 0
calculate/server/models/workers.py View File


+ 56
- 0
calculate/server/routers/commands.py View File

@@ -0,0 +1,56 @@
from fastapi import APIRouter, Depends

from ..utils.dependencies import right_checkers

from ..server_data import ServerData


data = ServerData()
router = APIRouter()


@router.get("/commands", tags=["Commands management"],
dependencies=[Depends(right_checkers["read"])])
async def get_commands() -> dict:
'''Обработчик, отвечающий на запросы списка команд.'''
response = {}
for command_id, command_object in data.commands.items():
response.update({command_id: {"title": command_object.title,
"category": command_object.category,
"icon": command_object.icon,
"command": command_object.command}})
return response


@router.get("/commands/{cid}", tags=["Commands management"],
dependencies=[Depends(right_checkers["read"])])
async def get_command(cid: int) -> dict:
'''Обработчик запросов списка команд.'''
if cid not in data.commands_instances:
# TODO добавить какую-то обработку ошибки.
pass
return {'id': cid,
'name': f'command_{cid}'}


@router.get("/commands/{cid}/groups", tags=["Commands management"],
dependencies=[Depends(right_checkers["read"])])
async def get_command_parameters_groups(cid: int) -> dict:
'''Обработчик запросов на получение групп параметров указанной команды.'''
pass


@router.get("/commands/{cid}/parameters", tags=["Commands management"],
dependencies=[Depends(right_checkers["read"])])
async def get_command_parameters(cid: int) -> dict:
'''Обработчик запросов на получение параметров указанной команды.'''
pass


@router.post("/commands/{command_id}", tags=["Commands management"],
dependencies=[Depends(right_checkers["write"])])
async def post_command(command_id: str) -> int:
if command_id not in data.commands:
# TODO добавить какую-то обработку ошибки.
pass
return

+ 46
- 0
calculate/server/routers/users.py View File

@@ -0,0 +1,46 @@
from datetime import timedelta

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm

from ..utils.auth import auth_user, create_access_token
from ..utils.dependencies import right_checkers

from ..schemas.tokens import Token
from ..schemas.users import User, UserCreate


ACCESS_TOKEN_EXPIRE_MINUTES = 30
router = APIRouter()


@router.post("/auth", tags=["Authentication"], response_model=Token)
async def authenticate(form_data: OAuth2PasswordRequestForm = Depends()):
'''Метод обрабатывающий запросы на аутентификацию пользователей по данным
указанным в форме, возвращает access_token.'''
user = await auth_user(form_data.username,
form_data.password)
if not user:
raise HTTPException(
status_code=400,
detail=f"User '{form_data.username}' not found"
)

access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(data={"sub": user.login},
expires_delta=access_token_expires)
# TODO добавить refresh_token
return {"access_token": access_token,
"token_type": "bearer"}


@router.post("/refresh", tags=["Authentication"], response_model=Token)
async def refresh():
'''Метод для обработки запросов с refresh-токенами.'''
pass


@router.post("/users/create", tags=["Users management"], response_model=User)
async def create_user(user_info: UserCreate,
current_user=Depends(right_checkers["admin"])):
print("CREATE USER")

+ 0
- 0
calculate/server/routers/workers.py View File


+ 13
- 0
calculate/server/schemas/config.py View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel


class ConfigSchema(BaseModel):
socket_path: str
variables_path: str
commands_path: str

logger_config: dict

class Config:
min_any_str_length = 1
anystr_strip_whitespace = True

+ 11
- 0
calculate/server/schemas/tokens.py View File

@@ -0,0 +1,11 @@
from pydantic import BaseModel


class Token(BaseModel):
access_token: str
token_type: str


class TokenData(BaseModel):
username: str
expire: int

+ 49
- 0
calculate/server/schemas/users.py View File

@@ -0,0 +1,49 @@
from fastapi import HTTPException, status
from pydantic import BaseModel, validator
from typing import List, Optional


class UserCreate(BaseModel):
login: str
password: str
rights: Optional[List[str]] = []


class UserData(UserCreate):
id: int


class User(BaseModel):
id: int
login: str
rights: List[str]


class UserRead(User):
@validator("rights")
def check_permissions(cls, value):
check_rights("read", value)
return value


class UserWrite(User):
@validator("rights")
def check_permissions(cls, value):
check_rights("write", value)
return value


class UserAdmin(User):
@validator("rights")
def check_permissions(cls, value):
check_rights("admin", value)
return value


def check_rights(right: str, rights_list: List[str]):
if right not in rights_list:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not enough permissions",
headers={"WWW-Authenticate": "Bearer"},
)

+ 0
- 0
calculate/server/schemas/workers.py View File


+ 48
- 155
calculate/server/server.py View File

@@ -1,19 +1,10 @@
from ..variables.loader import Datavars
from ..commands.commands import Command
from ..logging import dictLogConfig
from logging.config import dictConfig
from typing import (
Callable,
Optional,
NoReturn
)
from logging import getLogger
from fastapi import FastAPI
from .worker import Worker
import importlib
import uvicorn
import asyncio
import os
from fastapi import FastAPI, Depends

from .server_data import ServerData
from .utils.dependencies import right_checkers
from .models.database import database
from .routers.users import router as users_router
from .routers.commands import router as commands_router


# TODO
@@ -21,142 +12,44 @@ import os
# 2. Разобраться с объектами воркеров. И способом их функционирования.


class Server:
def __init__(self, socket_path: str = './input.sock',
datavars_path: str = 'calculate/vars/',
commands_path: str = 'calculate/commands'):
self._app = FastAPI()
self._socket_path = socket_path
self._event_loop = asyncio.get_event_loop()

# Конфигурируем логгирование.
dictConfig(dictLogConfig)
self._logger = getLogger("main")
self.log_msg = {'DEBUG': self._logger.debug,
'INFO': self._logger.info,
'WARNING': self._logger.warning,
'ERROR': self._logger.error,
'CRITICAL': self._logger.critical}

self._datavars = Datavars(variables_path=datavars_path,
logger=self._logger)

# Словарь описаний команд.
self._commands = self._get_commands_list(commands_path)

# Словарь CID и экземпляров команд, передаваемых воркерам.
self._commands_instances = {}

# Словарь WID и экземпляров процессов-воркеров, передаваемых воркерам.
self._workers = {}

# Соответствие путей обработчикам запросов для HTTP-метода GET.
self._add_routes(self._app.get,
{"/": self._get_root,
"/commands": self._get_commands,
"/commands/{cid}": self._get_command,
"/workers/{wid}": self._get_worker})
self._add_routes(self._app.post,
{"/commands/{command_id}": self._post_command})

def _get_commands_list(self, commands_path: str) -> list:
'''Метод для получения совокупности описаний команд.'''
output = {}
package = ".".join(commands_path.split("/"))

for entry in os.scandir(commands_path):
if (not entry.name.endswith('.py')
or entry.name in {"commands.py", "__init__.py"}):
continue
module_name = entry.name[:-3]
try:
module = importlib.import_module("{}.{}".format(package,
module_name))
for obj in dir(module):
if type(module.__getattribute__(obj)) == Command:
command_object = module.__getattribute__(obj)
output[command_object.id] = command_object
except Exception:
continue
return output

# Обработчики запросов серверу.
async def _get_root(self) -> dict:
'''Обработчик корневых запросов.'''
return {'msg': 'root msg'}

async def _get_commands(self) -> dict:
'''Обработчик, отвечающий на запросы списка команд.'''
response = {}
for command_id, command_object in self._commands.items():
response.update({command_id: {"title": command_object.title,
"category": command_object.category,
"icon": command_object.icon,
"command": command_object.command}})
return response

async def _get_command(self, cid: int) -> dict:
'''Обработчик запросов списка команд.'''
if cid not in self._commands_instances:
# TODO добавить какую-то обработку ошибки.
pass
return {'id': cid,
'name': f'command_{cid}'}

async def _get_worker(self, wid: int):
'''Тестовый обработчик.'''
self._make_worker(wid=wid)
worker = self._workers[wid]
worker.run(None)
await worker.send({"text": "INFO"})
data = await worker.get()
if data['type'] == 'log':
self.log_msg[data['level']](data['msg'])
return data

async def _post_command(self, command_id: str) -> int:
if command_id not in self._commands:
# TODO добавить какую-то обработку ошибки.
pass
return

# Обработчики сообщений воркеров.

# Вспомогательные методы.
def _add_routes(self, method: Callable, routes: dict) -> NoReturn:
'''Метод для добавления методов.'''
for path, handler in routes.items():
router = method(path)
router(handler)

def _make_worker(self, wid: Optional[int] = None):
'''Метод для создания воркера для команды.'''
if wid is not None:
self._workers[wid] = Worker(wid, self._event_loop)
return wid
elif not self._workers:
self._workers[0] = Worker(0, self._event_loop)
return 0
else:
wid = max(self._workers.keys()) + 1
self._workers[wid] = Worker(wid, self._event_loop)
return wid

def _make_command(self, command_id: str) -> int:
'''Метод для создания команды по ее описанию.'''
command_description = self._commands[command]

@property
def app(self):
return self._app

def run(self):
'''Метод для запуска сервера.'''
# Выгружаем список команд.
uvicorn.run(self._app,
uds=self._socket_path)


if __name__ == '__main__':
server = Server()
server.run()
data = ServerData()
app = FastAPI()


@app.on_event("startup")
async def startup():
await database.connect()


@app.on_event("shutdown")
async def shutdown():
await database.disconnect()


@app.get("/", tags=["Root"],
dependencies=[Depends(right_checkers["read"])])
async def get_root() -> dict:
'''Обработчик корневых запросов.'''
return {'msg': 'root msg'}


@app.get("/workers/{wid}", tags=["Workers management"],
dependencies=[Depends(right_checkers["write"])])
async def get_worker(wid: int):
'''Тестовый обработчик.'''
worker = data.get_worker_object(wid=wid)
worker.run()
print(f"worker: {type(worker)} object {worker}")

await worker.send({"text": "INFO"})
worker_data = await worker.get()

if worker_data['type'] == 'log':
data.log_message[data['level']](data['msg'])
return worker_data

# Authentification and users management.
app.include_router(users_router)

# Commands creation and management.
app.include_router(commands_router)

+ 104
- 0
calculate/server/server_data.py View File

@@ -0,0 +1,104 @@
import os
import asyncio
import importlib
from typing import Dict, Optional

from logging.config import dictConfig
from logging import getLogger, Logger

from ..variables.loader import Datavars
from ..commands.commands import Command, CommandRunner
from .utils.workers import Worker

from .schemas.config import ConfigSchema


# Получаем конфигурацию сервера.
TESTING = bool(os.environ.get("TESTING", False))
if not TESTING:
from .config import config
server_config = ConfigSchema(**config)
else:
from tests.server.config import config
server_config = ConfigSchema(**config)


class ServerData:
def __init__(self, config: ConfigSchema = server_config):
self.event_loop = asyncio.get_event_loop()

self._variables_path = config.variables_path

# Конфигурируем логгирование.
dictConfig(config.logger_config)
self.logger: Logger = getLogger("main")
self.log_message = {'DEBUG': self.logger.debug,
'INFO': self.logger.info,
'WARNING': self.logger.warning,
'ERROR': self.logger.error,
'CRITICAL': self.logger.critical}

self._datavars: Optional[Datavars] = None

# Словарь описаний команд.
self.commands: Dict[str, Command] = self._get_commands_descriptions(
config.commands_path)

# Словарь CID и экземпляров команд, передаваемых воркерам.
self.commands_runners: Dict[str, CommandRunner] = {}

# Словарь WID и экземпляров процессов-воркеров, передаваемых воркерам.
self.workers: Dict[int, Worker] = {}

@property
def datavars(self):
if self._datavars is not None:
return self._datavars
else:
self._datavars = self._load_datavars(self._variables_path,
self.logger)
return self._datavars

def _load_datavars(vars_path: str, logger: Logger) -> Datavars:
return Datavars(variables_path=vars_path,
logger=logger)

def _get_commands_descriptions(self, commands_path: str
) -> Dict[str, Command]:
'''Метод для получения совокупности описаний команд.'''
output = {}
package = ".".join(commands_path.split("/"))

for entry in os.scandir(commands_path):
if (not entry.name.endswith('.py')
or entry.name in {"commands.py", "__init__.py"}):
continue
module_name = entry.name[:-3]
try:
module = importlib.import_module("{}.{}".format(package,
module_name))
for obj_name in dir(module):
obj = module.__getattribute__(obj_name)
if isinstance(obj, Command):
output[obj.id] = obj
except Exception:
continue
return output

def make_command(self, command_id: str, ) -> int:
'''Метод для создания команды по ее описанию.'''
command_description = self.commands[command_id]

def _get_worker_object(self, wid: Optional[int] = None) -> Worker:
'''Метод для получения воркера для команды.'''
if wid is not None:
worker = Worker(wid, self._event_loop, self._datavars)
self._workers[wid] = worker
elif not self._workers:
worker = Worker(0, self._event_loop, self._datavars)
self._workers[0] = worker
else:
wid = max(self._workers.keys()) + 1
worker = Worker(wid, self._event_loop, self._datavars)
self._workers[wid] = worker
return worker

+ 77
- 0
calculate/server/utils/auth.py View File

@@ -0,0 +1,77 @@
from hashlib import pbkdf2_hmac
from random import choices
from string import ascii_letters
from datetime import datetime, timedelta

from typing import Optional, Union
from jose import jwt

from ..schemas.tokens import TokenData
from ..schemas.users import UserData

from .users import get_user_by_username

from calculate.utils.files import Process


def make_secret_key():
openssl_process = Process("/usr/bin/openssl", "rand", "-hex", "32")
secret_key = openssl_process.read()
return secret_key.strip()


# SECRET_KEY =\
# "efe90242c1c221b20fc718edc3aa5da4f78147eb2f4e81e809e945bbcdf0710e"
SECRET_KEY = make_secret_key()
ALGORITHM = "HS256"


async def auth_user(username: str, password: str) -> Union[UserData, bool]:
user_row = await get_user_by_username(username)
if not user_row:
return False
user = UserData(**user_row)
if not validate_password(password, user.password):
return False
return user


def decode_jwt(token: str) -> Union[TokenData, None]:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
expire = payload.get("exp")
if username is None:
return None
return TokenData(username=username,
expire=expire)


def create_access_token(data: dict,
expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def get_salt(length=12):
"""Метод для получения случайной строки символов используемой как соль
при кэшировании."""
return "".join(choices(ascii_letters, k=length))


def validate_password(password: str, hashed_password: str) -> bool:
salt, db_hash = hashed_password.split("$")
hashed = hash_password(password, salt)
return hashed == db_hash


def hash_password(password: str, salt: str) -> str:
if salt is None:
salt = get_salt()
enc = pbkdf2_hmac("sha256", password.encode(), salt.encode(), 100_000)
return enc.hex()

+ 50
- 0
calculate/server/utils/dependencies.py View File

@@ -0,0 +1,50 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

from jose import JWTError

from pydantic import ValidationError

from .auth import decode_jwt
from .users import get_user_by_username

from ..schemas.users import UserRead, UserWrite, UserAdmin


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth")


async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"}
)
try:
token_data = decode_jwt(token)
if token_data is None:
raise credentials_exception

except (JWTError, ValidationError):
raise credentials_exception

user = await get_user_by_username(token_data.username)
if user is None:
raise credentials_exception

return user


def make_right_checkers():
rights_schemas = {"read": UserRead, "write": UserWrite, "admin": UserAdmin}
dependencies = {}
for right, schema in rights_schemas.items():
async def depend_function(token: str = Depends(oauth2_scheme)):
user = await get_current_user(token=token)
return schema(**user)

dependencies[right] = depend_function
return dependencies


right_checkers = make_right_checkers()

+ 38
- 0
calculate/server/utils/users.py View File

@@ -0,0 +1,38 @@
from sqlalchemy import func, and_
from typing import List

from ..models.database import database
from ..models.users import users_table, users_rights, rights_table

from ..schemas.users import UserCreate


async def get_user_by_username(username: str):
'''Метод для получения строки с данными пользователя из базы данных по
username.'''
query = users_table.select().where(users_table.c.login == username)
user_data = await database.fetch_one(query)
query = (users_table.
join(users_rights).
join(rights_table).
select().
where(and_(users_table.c.id == users_rights.c.user_id,
users_table.c.login == username,
rights_table.c.id == users_rights.c.right_id)).
with_only_columns([users_table.c.id,
users_table.c.login,
users_table.c.password,
func.group_concat(rights_table.c.name,
' ').label("rights")]).
group_by(users_table.c.id))
response = await database.fetch_one(query)

user_data = dict(response)
user_data['rights'] = user_data['rights'].split()
return user_data


async def create_user(username: str, hashed_password: str, rights: List[str]):
user = UserCreate(login=username,
password=hashed_password,
rights=rights)

calculate/server/worker.py → calculate/server/utils/workers.py View File

@@ -3,9 +3,9 @@ import json
import socket
import logging
from typing import Union
from ..variables.loader import Datavars
from calculate.variables.loader import Datavars
from multiprocessing import Queue, Process
from ..commands.commands import CommandRunner, Command
from calculate.commands.commands import CommandRunner, Command
# from time import sleep



+ 7
- 2
calculate/templates/format/sqlite_format.py View File

@@ -9,6 +9,7 @@ import os
class SqliteFormat(Format):
FORMAT = 'sqlite'
EXECUTABLE = True
FORMAT_PARAMETERS = {'execsql'}

def __init__(self, template_text: str,
template_path: str,
@@ -33,11 +34,15 @@ class SqliteFormat(Format):
'''Метод для запуска работы формата.'''
if os.path.exists(target_path) and os.path.isdir(target_path):
raise FormatError(f"directory on target path: {target_path}")
status = "N"
if os.path.exists(target_path):
status = "M"

connection = connect(target_path)
self._executor(connection, self._template_text)
connection.close()

self.changed_files[target_path] = status
return self.changed_files

def _execute_continue(self, connection: Connection, template_text: str):
@@ -63,10 +68,10 @@ class SqliteFormat(Format):
try:
for command in commands:
cursor.execute(command)
cursor.execute("COMMIT;")
connection.commit()
except (OperationalError, IntegrityError) as error:
self._warnings.append(str(error))
cursor.execute("ROLLBACK;")
connection.rollback()

def _execute_stop(self, connection: Connection, template_text: str):
"""Метод для выполнения sql-команд при значения параметра


+ 0
- 0
calculate/utils/ldap.py View File


+ 2
- 0
client.py View File

@@ -1,3 +1,4 @@
import os
import asyncio
import aiohttp

@@ -26,6 +27,7 @@ async def main():


if __name__ == '__main__':
print(f"TESTING = {os.environ.get('TESTING', None)}")
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(main())


+ 3
- 0
conftest.py View File

@@ -1,5 +1,8 @@
import os
import pytest

from collections import OrderedDict
os.environ["TESTING"] = "True"


@pytest.fixture(scope='function')


+ 8
- 0
requirements.txt View File

@@ -1,5 +1,13 @@
python-ldap
requests
uvicorn
fastapi
pytest
jinja2
xattr
lxml
mock
python-jose
python-multipart
sqlalchemy
databases[sqlite]

+ 11
- 0
tests/server/config.py View File

@@ -0,0 +1,11 @@
import os
from calculate.logging import dictLogConfig


testfiles_path = os.path.join(os.getcwd(), 'tests/server/testfiles')


config = {"socket_path": "./input.sock",
"variables_path": 'tests/server/testfiles/variables',
"commands_path": 'tests/server/testfiles/commands',
"logger_config": dictLogConfig}

+ 76
- 53
tests/server/test_server.py View File

@@ -2,62 +2,85 @@ import os
import pytest
import shutil
from fastapi.testclient import TestClient
from calculate.server.server import Server
from calculate.server.server import app


TESTFILES_PATH = os.path.join(os.getcwd(), 'tests/server/testfiles')
VARS_PATH = os.path.join(TESTFILES_PATH, 'variables')
COMMANDS_PATH = 'tests/server/testfiles/commands'
server = Server(datavars_path=VARS_PATH, commands_path=COMMANDS_PATH)
test_client = TestClient(server.app)

test_client = TestClient(app)


def authenticate(username: str, password: str):
request_headers = {"accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded"}
request_data = {"username": username, "password": password,
"grant_type": "password", "scope": None, "client_id": None,
"client_id": None}
token = test_client.post("/auth", data=request_data, json=request_headers)
token_header = token.json()
access_token = token_header["access_token"]
return {"accept": "application/json",
"Authorization": f"Bearer {access_token}"}


@pytest.mark.server
def test_to_make_testfiles():
shutil.copytree(os.path.join(TESTFILES_PATH, 'var.backup'),
os.path.join(TESTFILES_PATH, 'var'),
symlinks=True)
shutil.copytree(os.path.join(TESTFILES_PATH, 'etc.backup'),
os.path.join(TESTFILES_PATH, 'etc'),
symlinks=True)


@pytest.mark.server
def test_get_root_message():
authorization_headers = authenticate("denis", "secret")
response = test_client.get("/", headers=authorization_headers)
assert response.status_code == 200
assert response.json() == {"msg": "root msg"}


@pytest.mark.server
def test_get_commands_list():
authorization_headers = authenticate("denis", "secret")
response = test_client.get("/commands", headers=authorization_headers)
assert response.status_code == 200
assert response.json() == {"test_1":
{"title": "Test 1",
"category": "Test Category",
"icon": "/path/to/icon_1.png",
"command": "test_1"},
"test_2":
{"title": "Test 2",
"category": "Test Category",
"icon": "/path/to/icon_2.png",
"command": "cl_test_2"}}


# @pytest.mark.server
# def test_post_command():
# response = test_client.get("/commands/")
# assert response.status_code == 200


# @pytest.mark.server
# def test_get_command_by_cid():
# response = test_client.get("/commands/0")
# assert response.status_code == 200
# assert response.json() == {"id": 0, "name": "command_0"}


# @pytest.mark.server
# def test_get_worker_message_by_wid():
# response = test_client.get("/workers/0")
# assert response.status_code == 200
# data = response.json()
# assert data == {'type': 'log', 'level': 'INFO',
# 'msg': 'recieved message INFO'}


@pytest.mark.server
class TestServer:
pass
# def test_to_make_testfiles(self):
# shutil.copytree(os.path.join(TESTFILES_PATH, 'var.backup'),
# os.path.join(TESTFILES_PATH, 'var'),
# symlinks=True)
# shutil.copytree(os.path.join(TESTFILES_PATH, 'etc.backup'),
# os.path.join(TESTFILES_PATH, 'etc'),
# symlinks=True)

# def test_get_root_message(self):
# response = test_client.get("/")
# assert response.status_code == 200
# assert response.json() == {"msg": "root msg"}

# def test_get_commands_list(self):
# response = test_client.get("/commands")
# assert response.status_code == 200
# assert response.json() == {"test_1":
# {"title": "Test 1",
# "category": "Test Category",
# "icon": "/path/to/icon_1.png",
# "command": "test_1"},
# "test_2":
# {"title": "Test 2",
# "category": "Test Category",
# "icon": "/path/to/icon_2.png",
# "command": "cl_test_2"}}

# def test_post_command(self):
# response = test_client.get("/commands/")
# assert response.status_code == 200

# def test_get_command_by_cid(self):
# response = test_client.get("/commands/0")
# assert response.status_code == 200
# assert response.json() == {"id": 0, "name": "command_0"}

# def test_get_worker_message_by_wid(self):
# response = test_client.get("/workers/0")
# assert response.status_code == 200
# data = response.json()
# assert data == {'type': 'log', 'level': 'INFO',
# 'msg': 'recieved message INFO'}

# def test_for_removing_testfiles(self):
# shutil.rmtree(os.path.join(TESTFILES_PATH, 'var'))
# shutil.rmtree(os.path.join(TESTFILES_PATH, 'etc'))
def test_for_removing_testfiles():
shutil.rmtree(os.path.join(TESTFILES_PATH, 'var'))
shutil.rmtree(os.path.join(TESTFILES_PATH, 'etc'))

+ 0
- 0
tests/server/testfiles/etc.backup/temp View File


Loading…
Cancel
Save