Added authentification using access tokens.
This commit is contained in:
parent
a2e6088216
commit
91690f73e2
27 changed files with 651 additions and 206 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,6 +2,7 @@
|
|||
*.pyo
|
||||
*.swp
|
||||
*.sock
|
||||
*.db
|
||||
build/
|
||||
dist/
|
||||
calculate_lib.egg-info/
|
||||
|
|
|
@ -8,11 +8,11 @@ from typing import Tuple, Union, Any
|
|||
|
||||
|
||||
class ParameterError(Exception):
|
||||
...
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(ParameterError):
|
||||
...
|
||||
pass
|
||||
|
||||
|
||||
class CyclicValidationError(ValidationError):
|
||||
|
|
7
calculate/server/config.py
Normal file
7
calculate/server/config.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
from calculate.logging import dictLogConfig
|
||||
|
||||
|
||||
config = {"socket_path": "./input.sock",
|
||||
"variables_path": "calculate/variables",
|
||||
"commands_path": "calculate/commands",
|
||||
"logger_config": dictLogConfig}
|
14
calculate/server/models/database.py
Normal file
14
calculate/server/models/database.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from os import environ
|
||||
from databases import Database
|
||||
|
||||
|
||||
TESTING = bool(environ.get("TESTING", False))
|
||||
|
||||
|
||||
if TESTING:
|
||||
DATABASE_URL = "sqlite:///tests/server/testfiles/test.db"
|
||||
else:
|
||||
# Временно.
|
||||
DATABASE_URL = "sqlite:///calculate/server/tmp.db"
|
||||
|
||||
database = Database(DATABASE_URL)
|
34
calculate/server/models/users.py
Normal file
34
calculate/server/models/users.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey
|
||||
|
||||
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
users_table: Table = Table("users",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("login",
|
||||
String(20),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True),
|
||||
Column("password",
|
||||
String(77),
|
||||
nullable=False)
|
||||
)
|
||||
|
||||
rights_table: Table = Table("rights",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("name",
|
||||
String(20),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True),
|
||||
Column("description", String(40)))
|
||||
|
||||
users_rights: Table = Table("users_rights",
|
||||
metadata,
|
||||
Column("user_id", ForeignKey("users.id")),
|
||||
Column("right_id", ForeignKey("rights.id"))
|
||||
)
|
0
calculate/server/models/workers.py
Normal file
0
calculate/server/models/workers.py
Normal file
56
calculate/server/routers/commands.py
Normal file
56
calculate/server/routers/commands.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
|
||||
from ..utils.dependencies import right_checkers
|
||||
|
||||
from ..server_data import ServerData
|
||||
|
||||
|
||||
data = ServerData()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/commands", tags=["Commands management"],
|
||||
dependencies=[Depends(right_checkers["read"])])
|
||||
async def get_commands() -> dict:
|
||||
'''Обработчик, отвечающий на запросы списка команд.'''
|
||||
response = {}
|
||||
for command_id, command_object in data.commands.items():
|
||||
response.update({command_id: {"title": command_object.title,
|
||||
"category": command_object.category,
|
||||
"icon": command_object.icon,
|
||||
"command": command_object.command}})
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/commands/{cid}", tags=["Commands management"],
|
||||
dependencies=[Depends(right_checkers["read"])])
|
||||
async def get_command(cid: int) -> dict:
|
||||
'''Обработчик запросов списка команд.'''
|
||||
if cid not in data.commands_instances:
|
||||
# TODO добавить какую-то обработку ошибки.
|
||||
pass
|
||||
return {'id': cid,
|
||||
'name': f'command_{cid}'}
|
||||
|
||||
|
||||
@router.get("/commands/{cid}/groups", tags=["Commands management"],
|
||||
dependencies=[Depends(right_checkers["read"])])
|
||||
async def get_command_parameters_groups(cid: int) -> dict:
|
||||
'''Обработчик запросов на получение групп параметров указанной команды.'''
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/commands/{cid}/parameters", tags=["Commands management"],
|
||||
dependencies=[Depends(right_checkers["read"])])
|
||||
async def get_command_parameters(cid: int) -> dict:
|
||||
'''Обработчик запросов на получение параметров указанной команды.'''
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/commands/{command_id}", tags=["Commands management"],
|
||||
dependencies=[Depends(right_checkers["write"])])
|
||||
async def post_command(command_id: str) -> int:
|
||||
if command_id not in data.commands:
|
||||
# TODO добавить какую-то обработку ошибки.
|
||||
pass
|
||||
return
|
46
calculate/server/routers/users.py
Normal file
46
calculate/server/routers/users.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from ..utils.auth import auth_user, create_access_token
|
||||
from ..utils.dependencies import right_checkers
|
||||
|
||||
from ..schemas.tokens import Token
|
||||
from ..schemas.users import User, UserCreate
|
||||
|
||||
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/auth", tags=["Authentication"], response_model=Token)
|
||||
async def authenticate(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
'''Метод обрабатывающий запросы на аутентификацию пользователей по данным
|
||||
указанным в форме, возвращает access_token.'''
|
||||
user = await auth_user(form_data.username,
|
||||
form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"User '{form_data.username}' not found"
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(data={"sub": user.login},
|
||||
expires_delta=access_token_expires)
|
||||
# TODO добавить refresh_token
|
||||
return {"access_token": access_token,
|
||||
"token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/refresh", tags=["Authentication"], response_model=Token)
|
||||
async def refresh():
|
||||
'''Метод для обработки запросов с refresh-токенами.'''
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/users/create", tags=["Users management"], response_model=User)
|
||||
async def create_user(user_info: UserCreate,
|
||||
current_user=Depends(right_checkers["admin"])):
|
||||
print("CREATE USER")
|
0
calculate/server/routers/workers.py
Normal file
0
calculate/server/routers/workers.py
Normal file
13
calculate/server/schemas/config.py
Normal file
13
calculate/server/schemas/config.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
socket_path: str
|
||||
variables_path: str
|
||||
commands_path: str
|
||||
|
||||
logger_config: dict
|
||||
|
||||
class Config:
|
||||
min_any_str_length = 1
|
||||
anystr_strip_whitespace = True
|
11
calculate/server/schemas/tokens.py
Normal file
11
calculate/server/schemas/tokens.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str
|
||||
expire: int
|
49
calculate/server/schemas/users.py
Normal file
49
calculate/server/schemas/users.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel, validator
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
login: str
|
||||
password: str
|
||||
rights: Optional[List[str]] = []
|
||||
|
||||
|
||||
class UserData(UserCreate):
|
||||
id: int
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: int
|
||||
login: str
|
||||
rights: List[str]
|
||||
|
||||
|
||||
class UserRead(User):
|
||||
@validator("rights")
|
||||
def check_permissions(cls, value):
|
||||
check_rights("read", value)
|
||||
return value
|
||||
|
||||
|
||||
class UserWrite(User):
|
||||
@validator("rights")
|
||||
def check_permissions(cls, value):
|
||||
check_rights("write", value)
|
||||
return value
|
||||
|
||||
|
||||
class UserAdmin(User):
|
||||
@validator("rights")
|
||||
def check_permissions(cls, value):
|
||||
check_rights("admin", value)
|
||||
return value
|
||||
|
||||
|
||||
def check_rights(right: str, rights_list: List[str]):
|
||||
if right not in rights_list:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not enough permissions",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
0
calculate/server/schemas/workers.py
Normal file
0
calculate/server/schemas/workers.py
Normal file
|
@ -1,19 +1,10 @@
|
|||
from ..variables.loader import Datavars
|
||||
from ..commands.commands import Command
|
||||
from ..logging import dictLogConfig
|
||||
from logging.config import dictConfig
|
||||
from typing import (
|
||||
Callable,
|
||||
Optional,
|
||||
NoReturn
|
||||
)
|
||||
from logging import getLogger
|
||||
from fastapi import FastAPI
|
||||
from .worker import Worker
|
||||
import importlib
|
||||
import uvicorn
|
||||
import asyncio
|
||||
import os
|
||||
from fastapi import FastAPI, Depends
|
||||
|
||||
from .server_data import ServerData
|
||||
from .utils.dependencies import right_checkers
|
||||
from .models.database import database
|
||||
from .routers.users import router as users_router
|
||||
from .routers.commands import router as commands_router
|
||||
|
||||
|
||||
# TODO
|
||||
|
@ -21,142 +12,44 @@ import os
|
|||
# 2. Разобраться с объектами воркеров. И способом их функционирования.
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, socket_path: str = './input.sock',
|
||||
datavars_path: str = 'calculate/vars/',
|
||||
commands_path: str = 'calculate/commands'):
|
||||
self._app = FastAPI()
|
||||
self._socket_path = socket_path
|
||||
self._event_loop = asyncio.get_event_loop()
|
||||
|
||||
# Конфигурируем логгирование.
|
||||
dictConfig(dictLogConfig)
|
||||
self._logger = getLogger("main")
|
||||
self.log_msg = {'DEBUG': self._logger.debug,
|
||||
'INFO': self._logger.info,
|
||||
'WARNING': self._logger.warning,
|
||||
'ERROR': self._logger.error,
|
||||
'CRITICAL': self._logger.critical}
|
||||
|
||||
self._datavars = Datavars(variables_path=datavars_path,
|
||||
logger=self._logger)
|
||||
|
||||
# Словарь описаний команд.
|
||||
self._commands = self._get_commands_list(commands_path)
|
||||
|
||||
# Словарь CID и экземпляров команд, передаваемых воркерам.
|
||||
self._commands_instances = {}
|
||||
|
||||
# Словарь WID и экземпляров процессов-воркеров, передаваемых воркерам.
|
||||
self._workers = {}
|
||||
|
||||
# Соответствие путей обработчикам запросов для HTTP-метода GET.
|
||||
self._add_routes(self._app.get,
|
||||
{"/": self._get_root,
|
||||
"/commands": self._get_commands,
|
||||
"/commands/{cid}": self._get_command,
|
||||
"/workers/{wid}": self._get_worker})
|
||||
self._add_routes(self._app.post,
|
||||
{"/commands/{command_id}": self._post_command})
|
||||
|
||||
def _get_commands_list(self, commands_path: str) -> list:
|
||||
'''Метод для получения совокупности описаний команд.'''
|
||||
output = {}
|
||||
package = ".".join(commands_path.split("/"))
|
||||
|
||||
for entry in os.scandir(commands_path):
|
||||
if (not entry.name.endswith('.py')
|
||||
or entry.name in {"commands.py", "__init__.py"}):
|
||||
continue
|
||||
module_name = entry.name[:-3]
|
||||
try:
|
||||
module = importlib.import_module("{}.{}".format(package,
|
||||
module_name))
|
||||
for obj in dir(module):
|
||||
if type(module.__getattribute__(obj)) == Command:
|
||||
command_object = module.__getattribute__(obj)
|
||||
output[command_object.id] = command_object
|
||||
except Exception:
|
||||
continue
|
||||
return output
|
||||
|
||||
# Обработчики запросов серверу.
|
||||
async def _get_root(self) -> dict:
|
||||
'''Обработчик корневых запросов.'''
|
||||
return {'msg': 'root msg'}
|
||||
|
||||
async def _get_commands(self) -> dict:
|
||||
'''Обработчик, отвечающий на запросы списка команд.'''
|
||||
response = {}
|
||||
for command_id, command_object in self._commands.items():
|
||||
response.update({command_id: {"title": command_object.title,
|
||||
"category": command_object.category,
|
||||
"icon": command_object.icon,
|
||||
"command": command_object.command}})
|
||||
return response
|
||||
|
||||
async def _get_command(self, cid: int) -> dict:
|
||||
'''Обработчик запросов списка команд.'''
|
||||
if cid not in self._commands_instances:
|
||||
# TODO добавить какую-то обработку ошибки.
|
||||
pass
|
||||
return {'id': cid,
|
||||
'name': f'command_{cid}'}
|
||||
|
||||
async def _get_worker(self, wid: int):
|
||||
'''Тестовый обработчик.'''
|
||||
self._make_worker(wid=wid)
|
||||
worker = self._workers[wid]
|
||||
worker.run(None)
|
||||
await worker.send({"text": "INFO"})
|
||||
data = await worker.get()
|
||||
if data['type'] == 'log':
|
||||
self.log_msg[data['level']](data['msg'])
|
||||
return data
|
||||
|
||||
async def _post_command(self, command_id: str) -> int:
|
||||
if command_id not in self._commands:
|
||||
# TODO добавить какую-то обработку ошибки.
|
||||
pass
|
||||
return
|
||||
|
||||
# Обработчики сообщений воркеров.
|
||||
|
||||
# Вспомогательные методы.
|
||||
def _add_routes(self, method: Callable, routes: dict) -> NoReturn:
|
||||
'''Метод для добавления методов.'''
|
||||
for path, handler in routes.items():
|
||||
router = method(path)
|
||||
router(handler)
|
||||
|
||||
def _make_worker(self, wid: Optional[int] = None):
|
||||
'''Метод для создания воркера для команды.'''
|
||||
if wid is not None:
|
||||
self._workers[wid] = Worker(wid, self._event_loop)
|
||||
return wid
|
||||
elif not self._workers:
|
||||
self._workers[0] = Worker(0, self._event_loop)
|
||||
return 0
|
||||
else:
|
||||
wid = max(self._workers.keys()) + 1
|
||||
self._workers[wid] = Worker(wid, self._event_loop)
|
||||
return wid
|
||||
|
||||
def _make_command(self, command_id: str) -> int:
|
||||
'''Метод для создания команды по ее описанию.'''
|
||||
command_description = self._commands[command]
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
return self._app
|
||||
|
||||
def run(self):
|
||||
'''Метод для запуска сервера.'''
|
||||
# Выгружаем список команд.
|
||||
uvicorn.run(self._app,
|
||||
uds=self._socket_path)
|
||||
data = ServerData()
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
server = Server()
|
||||
server.run()
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
await database.connect()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
await database.disconnect()
|
||||
|
||||
|
||||
@app.get("/", tags=["Root"],
|
||||
dependencies=[Depends(right_checkers["read"])])
|
||||
async def get_root() -> dict:
|
||||
'''Обработчик корневых запросов.'''
|
||||
return {'msg': 'root msg'}
|
||||
|
||||
|
||||
@app.get("/workers/{wid}", tags=["Workers management"],
|
||||
dependencies=[Depends(right_checkers["write"])])
|
||||
async def get_worker(wid: int):
|
||||
'''Тестовый обработчик.'''
|
||||
worker = data.get_worker_object(wid=wid)
|
||||
worker.run()
|
||||
print(f"worker: {type(worker)} object {worker}")
|
||||
|
||||
await worker.send({"text": "INFO"})
|
||||
worker_data = await worker.get()
|
||||
|
||||
if worker_data['type'] == 'log':
|
||||
data.log_message[data['level']](data['msg'])
|
||||
return worker_data
|
||||
|
||||
# Authentification and users management.
|
||||
app.include_router(users_router)
|
||||
|
||||
# Commands creation and management.
|
||||
app.include_router(commands_router)
|
||||
|
|
104
calculate/server/server_data.py
Normal file
104
calculate/server/server_data.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import os
|
||||
import asyncio
|
||||
import importlib
|
||||
from typing import Dict, Optional
|
||||
|
||||
from logging.config import dictConfig
|
||||
from logging import getLogger, Logger
|
||||
|
||||
from ..variables.loader import Datavars
|
||||
from ..commands.commands import Command, CommandRunner
|
||||
from .utils.workers import Worker
|
||||
|
||||
from .schemas.config import ConfigSchema
|
||||
|
||||
|
||||
# Получаем конфигурацию сервера.
|
||||
TESTING = bool(os.environ.get("TESTING", False))
|
||||
if not TESTING:
|
||||
from .config import config
|
||||
server_config = ConfigSchema(**config)
|
||||
else:
|
||||
from tests.server.config import config
|
||||
server_config = ConfigSchema(**config)
|
||||
|
||||
|
||||
class ServerData:
|
||||
def __init__(self, config: ConfigSchema = server_config):
|
||||
self.event_loop = asyncio.get_event_loop()
|
||||
|
||||
self._variables_path = config.variables_path
|
||||
|
||||
# Конфигурируем логгирование.
|
||||
dictConfig(config.logger_config)
|
||||
self.logger: Logger = getLogger("main")
|
||||
self.log_message = {'DEBUG': self.logger.debug,
|
||||
'INFO': self.logger.info,
|
||||
'WARNING': self.logger.warning,
|
||||
'ERROR': self.logger.error,
|
||||
'CRITICAL': self.logger.critical}
|
||||
|
||||
self._datavars: Optional[Datavars] = None
|
||||
|
||||
# Словарь описаний команд.
|
||||
self.commands: Dict[str, Command] = self._get_commands_descriptions(
|
||||
config.commands_path)
|
||||
|
||||
# Словарь CID и экземпляров команд, передаваемых воркерам.
|
||||
self.commands_runners: Dict[str, CommandRunner] = {}
|
||||
|
||||
# Словарь WID и экземпляров процессов-воркеров, передаваемых воркерам.
|
||||
self.workers: Dict[int, Worker] = {}
|
||||
|
||||
@property
|
||||
def datavars(self):
|
||||
if self._datavars is not None:
|
||||
return self._datavars
|
||||
else:
|
||||
self._datavars = self._load_datavars(self._variables_path,
|
||||
self.logger)
|
||||
return self._datavars
|
||||
|
||||
def _load_datavars(vars_path: str, logger: Logger) -> Datavars:
|
||||
return Datavars(variables_path=vars_path,
|
||||
logger=logger)
|
||||
|
||||
def _get_commands_descriptions(self, commands_path: str
|
||||
) -> Dict[str, Command]:
|
||||
'''Метод для получения совокупности описаний команд.'''
|
||||
output = {}
|
||||
package = ".".join(commands_path.split("/"))
|
||||
|
||||
for entry in os.scandir(commands_path):
|
||||
if (not entry.name.endswith('.py')
|
||||
or entry.name in {"commands.py", "__init__.py"}):
|
||||
continue
|
||||
module_name = entry.name[:-3]
|
||||
try:
|
||||
module = importlib.import_module("{}.{}".format(package,
|
||||
module_name))
|
||||
for obj_name in dir(module):
|
||||
obj = module.__getattribute__(obj_name)
|
||||
if isinstance(obj, Command):
|
||||
output[obj.id] = obj
|
||||
except Exception:
|
||||
continue
|
||||
return output
|
||||
|
||||
def make_command(self, command_id: str, ) -> int:
|
||||
'''Метод для создания команды по ее описанию.'''
|
||||
command_description = self.commands[command_id]
|
||||
|
||||
def _get_worker_object(self, wid: Optional[int] = None) -> Worker:
|
||||
'''Метод для получения воркера для команды.'''
|
||||
if wid is not None:
|
||||
worker = Worker(wid, self._event_loop, self._datavars)
|
||||
self._workers[wid] = worker
|
||||
elif not self._workers:
|
||||
worker = Worker(0, self._event_loop, self._datavars)
|
||||
self._workers[0] = worker
|
||||
else:
|
||||
wid = max(self._workers.keys()) + 1
|
||||
worker = Worker(wid, self._event_loop, self._datavars)
|
||||
self._workers[wid] = worker
|
||||
return worker
|
77
calculate/server/utils/auth.py
Normal file
77
calculate/server/utils/auth.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
from hashlib import pbkdf2_hmac
|
||||
from random import choices
|
||||
from string import ascii_letters
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from typing import Optional, Union
|
||||
from jose import jwt
|
||||
|
||||
from ..schemas.tokens import TokenData
|
||||
from ..schemas.users import UserData
|
||||
|
||||
from .users import get_user_by_username
|
||||
|
||||
from calculate.utils.files import Process
|
||||
|
||||
|
||||
def make_secret_key():
|
||||
openssl_process = Process("/usr/bin/openssl", "rand", "-hex", "32")
|
||||
secret_key = openssl_process.read()
|
||||
return secret_key.strip()
|
||||
|
||||
|
||||
# SECRET_KEY =\
|
||||
# "efe90242c1c221b20fc718edc3aa5da4f78147eb2f4e81e809e945bbcdf0710e"
|
||||
SECRET_KEY = make_secret_key()
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
async def auth_user(username: str, password: str) -> Union[UserData, bool]:
|
||||
user_row = await get_user_by_username(username)
|
||||
if not user_row:
|
||||
return False
|
||||
user = UserData(**user_row)
|
||||
if not validate_password(password, user.password):
|
||||
return False
|
||||
return user
|
||||
|
||||
|
||||
def decode_jwt(token: str) -> Union[TokenData, None]:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
expire = payload.get("exp")
|
||||
if username is None:
|
||||
return None
|
||||
return TokenData(username=username,
|
||||
expire=expire)
|
||||
|
||||
|
||||
def create_access_token(data: dict,
|
||||
expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def get_salt(length=12):
|
||||
"""Метод для получения случайной строки символов используемой как соль
|
||||
при кэшировании."""
|
||||
return "".join(choices(ascii_letters, k=length))
|
||||
|
||||
|
||||
def validate_password(password: str, hashed_password: str) -> bool:
|
||||
salt, db_hash = hashed_password.split("$")
|
||||
hashed = hash_password(password, salt)
|
||||
return hashed == db_hash
|
||||
|
||||
|
||||
def hash_password(password: str, salt: str) -> str:
|
||||
if salt is None:
|
||||
salt = get_salt()
|
||||
enc = pbkdf2_hmac("sha256", password.encode(), salt.encode(), 100_000)
|
||||
return enc.hex()
|
50
calculate/server/utils/dependencies.py
Normal file
50
calculate/server/utils/dependencies.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from jose import JWTError
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .auth import decode_jwt
|
||||
from .users import get_user_by_username
|
||||
|
||||
from ..schemas.users import UserRead, UserWrite, UserAdmin
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth")
|
||||
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
try:
|
||||
token_data = decode_jwt(token)
|
||||
if token_data is None:
|
||||
raise credentials_exception
|
||||
|
||||
except (JWTError, ValidationError):
|
||||
raise credentials_exception
|
||||
|
||||
user = await get_user_by_username(token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def make_right_checkers():
|
||||
rights_schemas = {"read": UserRead, "write": UserWrite, "admin": UserAdmin}
|
||||
dependencies = {}
|
||||
for right, schema in rights_schemas.items():
|
||||
async def depend_function(token: str = Depends(oauth2_scheme)):
|
||||
user = await get_current_user(token=token)
|
||||
return schema(**user)
|
||||
|
||||
dependencies[right] = depend_function
|
||||
return dependencies
|
||||
|
||||
|
||||
right_checkers = make_right_checkers()
|
38
calculate/server/utils/users.py
Normal file
38
calculate/server/utils/users.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
from sqlalchemy import func, and_
|
||||
from typing import List
|
||||
|
||||
from ..models.database import database
|
||||
from ..models.users import users_table, users_rights, rights_table
|
||||
|
||||
from ..schemas.users import UserCreate
|
||||
|
||||
|
||||
async def get_user_by_username(username: str):
|
||||
'''Метод для получения строки с данными пользователя из базы данных по
|
||||
username.'''
|
||||
query = users_table.select().where(users_table.c.login == username)
|
||||
user_data = await database.fetch_one(query)
|
||||
query = (users_table.
|
||||
join(users_rights).
|
||||
join(rights_table).
|
||||
select().
|
||||
where(and_(users_table.c.id == users_rights.c.user_id,
|
||||
users_table.c.login == username,
|
||||
rights_table.c.id == users_rights.c.right_id)).
|
||||
with_only_columns([users_table.c.id,
|
||||
users_table.c.login,
|
||||
users_table.c.password,
|
||||
func.group_concat(rights_table.c.name,
|
||||
' ').label("rights")]).
|
||||
group_by(users_table.c.id))
|
||||
response = await database.fetch_one(query)
|
||||
|
||||
user_data = dict(response)
|
||||
user_data['rights'] = user_data['rights'].split()
|
||||
return user_data
|
||||
|
||||
|
||||
async def create_user(username: str, hashed_password: str, rights: List[str]):
|
||||
user = UserCreate(login=username,
|
||||
password=hashed_password,
|
||||
rights=rights)
|
|
@ -3,9 +3,9 @@ import json
|
|||
import socket
|
||||
import logging
|
||||
from typing import Union
|
||||
from ..variables.loader import Datavars
|
||||
from calculate.variables.loader import Datavars
|
||||
from multiprocessing import Queue, Process
|
||||
from ..commands.commands import CommandRunner, Command
|
||||
from calculate.commands.commands import CommandRunner, Command
|
||||
# from time import sleep
|
||||
|
||||
|
|
@ -9,6 +9,7 @@ import os
|
|||
class SqliteFormat(Format):
|
||||
FORMAT = 'sqlite'
|
||||
EXECUTABLE = True
|
||||
FORMAT_PARAMETERS = {'execsql'}
|
||||
|
||||
def __init__(self, template_text: str,
|
||||
template_path: str,
|
||||
|
@ -33,11 +34,15 @@ class SqliteFormat(Format):
|
|||
'''Метод для запуска работы формата.'''
|
||||
if os.path.exists(target_path) and os.path.isdir(target_path):
|
||||
raise FormatError(f"directory on target path: {target_path}")
|
||||
status = "N"
|
||||
if os.path.exists(target_path):
|
||||
status = "M"
|
||||
|
||||
connection = connect(target_path)
|
||||
self._executor(connection, self._template_text)
|
||||
connection.close()
|
||||
|
||||
self.changed_files[target_path] = status
|
||||
return self.changed_files
|
||||
|
||||
def _execute_continue(self, connection: Connection, template_text: str):
|
||||
|
@ -63,10 +68,10 @@ class SqliteFormat(Format):
|
|||
try:
|
||||
for command in commands:
|
||||
cursor.execute(command)
|
||||
cursor.execute("COMMIT;")
|
||||
connection.commit()
|
||||
except (OperationalError, IntegrityError) as error:
|
||||
self._warnings.append(str(error))
|
||||
cursor.execute("ROLLBACK;")
|
||||
connection.rollback()
|
||||
|
||||
def _execute_stop(self, connection: Connection, template_text: str):
|
||||
"""Метод для выполнения sql-команд при значения параметра
|
||||
|
|
0
calculate/utils/ldap.py
Normal file
0
calculate/utils/ldap.py
Normal file
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
|
@ -26,6 +27,7 @@ async def main():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(f"TESTING = {os.environ.get('TESTING', None)}")
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import os
|
||||
import pytest
|
||||
|
||||
from collections import OrderedDict
|
||||
os.environ["TESTING"] = "True"
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
python-ldap
|
||||
requests
|
||||
uvicorn
|
||||
fastapi
|
||||
pytest
|
||||
jinja2
|
||||
xattr
|
||||
lxml
|
||||
mock
|
||||
python-jose
|
||||
python-multipart
|
||||
sqlalchemy
|
||||
databases[sqlite]
|
||||
|
|
11
tests/server/config.py
Normal file
11
tests/server/config.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
import os
|
||||
from calculate.logging import dictLogConfig
|
||||
|
||||
|
||||
testfiles_path = os.path.join(os.getcwd(), 'tests/server/testfiles')
|
||||
|
||||
|
||||
config = {"socket_path": "./input.sock",
|
||||
"variables_path": 'tests/server/testfiles/variables',
|
||||
"commands_path": 'tests/server/testfiles/commands',
|
||||
"logger_config": dictLogConfig}
|
|
@ -2,62 +2,85 @@ import os
|
|||
import pytest
|
||||
import shutil
|
||||
from fastapi.testclient import TestClient
|
||||
from calculate.server.server import Server
|
||||
from calculate.server.server import app
|
||||
|
||||
|
||||
TESTFILES_PATH = os.path.join(os.getcwd(), 'tests/server/testfiles')
|
||||
VARS_PATH = os.path.join(TESTFILES_PATH, 'variables')
|
||||
COMMANDS_PATH = 'tests/server/testfiles/commands'
|
||||
server = Server(datavars_path=VARS_PATH, commands_path=COMMANDS_PATH)
|
||||
test_client = TestClient(server.app)
|
||||
|
||||
test_client = TestClient(app)
|
||||
|
||||
|
||||
def authenticate(username: str, password: str):
|
||||
request_headers = {"accept": "application/json",
|
||||
"Content-Type": "application/x-www-form-urlencoded"}
|
||||
request_data = {"username": username, "password": password,
|
||||
"grant_type": "password", "scope": None, "client_id": None,
|
||||
"client_id": None}
|
||||
token = test_client.post("/auth", data=request_data, json=request_headers)
|
||||
token_header = token.json()
|
||||
access_token = token_header["access_token"]
|
||||
return {"accept": "application/json",
|
||||
"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
|
||||
@pytest.mark.server
|
||||
class TestServer:
|
||||
pass
|
||||
# def test_to_make_testfiles(self):
|
||||
# shutil.copytree(os.path.join(TESTFILES_PATH, 'var.backup'),
|
||||
# os.path.join(TESTFILES_PATH, 'var'),
|
||||
# symlinks=True)
|
||||
# shutil.copytree(os.path.join(TESTFILES_PATH, 'etc.backup'),
|
||||
# os.path.join(TESTFILES_PATH, 'etc'),
|
||||
# symlinks=True)
|
||||
def test_to_make_testfiles():
|
||||
shutil.copytree(os.path.join(TESTFILES_PATH, 'var.backup'),
|
||||
os.path.join(TESTFILES_PATH, 'var'),
|
||||
symlinks=True)
|
||||
shutil.copytree(os.path.join(TESTFILES_PATH, 'etc.backup'),
|
||||
os.path.join(TESTFILES_PATH, 'etc'),
|
||||
symlinks=True)
|
||||
|
||||
# def test_get_root_message(self):
|
||||
# response = test_client.get("/")
|
||||
# assert response.status_code == 200
|
||||
# assert response.json() == {"msg": "root msg"}
|
||||
|
||||
# def test_get_commands_list(self):
|
||||
# response = test_client.get("/commands")
|
||||
# assert response.status_code == 200
|
||||
# assert response.json() == {"test_1":
|
||||
# {"title": "Test 1",
|
||||
# "category": "Test Category",
|
||||
# "icon": "/path/to/icon_1.png",
|
||||
# "command": "test_1"},
|
||||
# "test_2":
|
||||
# {"title": "Test 2",
|
||||
# "category": "Test Category",
|
||||
# "icon": "/path/to/icon_2.png",
|
||||
# "command": "cl_test_2"}}
|
||||
@pytest.mark.server
|
||||
def test_get_root_message():
|
||||
authorization_headers = authenticate("denis", "secret")
|
||||
response = test_client.get("/", headers=authorization_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"msg": "root msg"}
|
||||
|
||||
# def test_post_command(self):
|
||||
# response = test_client.get("/commands/")
|
||||
# assert response.status_code == 200
|
||||
|
||||
# def test_get_command_by_cid(self):
|
||||
# response = test_client.get("/commands/0")
|
||||
# assert response.status_code == 200
|
||||
# assert response.json() == {"id": 0, "name": "command_0"}
|
||||
@pytest.mark.server
|
||||
def test_get_commands_list():
|
||||
authorization_headers = authenticate("denis", "secret")
|
||||
response = test_client.get("/commands", headers=authorization_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"test_1":
|
||||
{"title": "Test 1",
|
||||
"category": "Test Category",
|
||||
"icon": "/path/to/icon_1.png",
|
||||
"command": "test_1"},
|
||||
"test_2":
|
||||
{"title": "Test 2",
|
||||
"category": "Test Category",
|
||||
"icon": "/path/to/icon_2.png",
|
||||
"command": "cl_test_2"}}
|
||||
|
||||
# def test_get_worker_message_by_wid(self):
|
||||
# response = test_client.get("/workers/0")
|
||||
# assert response.status_code == 200
|
||||
# data = response.json()
|
||||
# assert data == {'type': 'log', 'level': 'INFO',
|
||||
# 'msg': 'recieved message INFO'}
|
||||
|
||||
# def test_for_removing_testfiles(self):
|
||||
# shutil.rmtree(os.path.join(TESTFILES_PATH, 'var'))
|
||||
# shutil.rmtree(os.path.join(TESTFILES_PATH, 'etc'))
|
||||
# @pytest.mark.server
|
||||
# def test_post_command():
|
||||
# response = test_client.get("/commands/")
|
||||
# assert response.status_code == 200
|
||||
|
||||
|
||||
# @pytest.mark.server
|
||||
# def test_get_command_by_cid():
|
||||
# response = test_client.get("/commands/0")
|
||||
# assert response.status_code == 200
|
||||
# assert response.json() == {"id": 0, "name": "command_0"}
|
||||
|
||||
|
||||
# @pytest.mark.server
|
||||
# def test_get_worker_message_by_wid():
|
||||
# response = test_client.get("/workers/0")
|
||||
# assert response.status_code == 200
|
||||
# data = response.json()
|
||||
# assert data == {'type': 'log', 'level': 'INFO',
|
||||
# 'msg': 'recieved message INFO'}
|
||||
|
||||
|
||||
@pytest.mark.server
|
||||
def test_for_removing_testfiles():
|
||||
shutil.rmtree(os.path.join(TESTFILES_PATH, 'var'))
|
||||
shutil.rmtree(os.path.join(TESTFILES_PATH, 'etc'))
|
||||
|
|
0
tests/server/testfiles/etc.backup/temp
Normal file
0
tests/server/testfiles/etc.backup/temp
Normal file
Loading…
Add table
Reference in a new issue