You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
415 lines
16 KiB
415 lines
16 KiB
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2012-2016 Mir Calculate. http://www.calculate-linux.org
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import print_function
|
|
from __future__ import absolute_import
|
|
from spyne.server.wsgi import WsgiApplication
|
|
|
|
import re
|
|
import logging
|
|
import os
|
|
#cStringIO was moved to io in python3
|
|
import cStringIO as io
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
import datetime
|
|
import pickle
|
|
from .loaded_methods import LoadedMethods
|
|
|
|
# for OpenSSLAdapter
|
|
import calculate.contrib
|
|
from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
|
|
|
|
# HTTP status lines.
HTTP_200 = '200 OK'
HTTP_403 = '403 Forbidden'
HTTP_405 = '405 Method Not Allowed'
HTTP_500 = '500 Internal server error'

# Method names (without their 5-character prefix) whose permitted calls are
# not written to the debug log by ClApplication.handle_rpc.
not_log_list = [
    'post_server_request',
    'post_client_request',
    'del_sid',
    'get_server_cert',
    'get_client_cert',
    'get_entire_frame',
    'get_crl',
    'get_server_host_name',
    'get_ca',
    'get_table',
    'post_cert',
    'post_sid',
    'active_client',
    'list_pid',
    'get_methods',
    'get_frame',
    'get_progress',
    'pid_info',
]
|
|
|
|
|
|
class ClApplication(WsgiApplication):
    """WSGI application that authorizes each RPC call before dispatching it.

    ``handle_rpc`` extracts the session id and the method name from the
    request, checks the client certificate stored on the current thread
    against the rights files, and only then hands the request over to
    ``WsgiApplication.handle_rpc``.
    """

    def __init__(self, app, log=None):
        """Wrap *app*.

        ``log`` is accepted for interface compatibility but is not used:
        the module-level ``logger`` is always installed as ``self.log``.
        """
        super(ClApplication, self).__init__(app)
        # add object logging
        self.log = logger

    # verification of compliance certificate and session (sid)
    def check_cert_sid(self, sid, server):
        """Check that session *sid* is recorded for the current certificate.

        The certificate is read from ``client_cert`` on the current thread
        and resolved to an id with ``find_cert_id``.  The sids file is then
        scanned record by record.

        :param sid: session id, convertible with ``int()``
        :param server: object providing ``data_path``, ``certbase``,
            ``sids`` and ``sids_file``
        :return: 1 when a record pairs *sid* with the certificate id,
            0 otherwise (including a certificate id of 0)
        """
        import threading
        from .cert_cmd import find_cert_id

        cert = threading.current_thread().client_cert
        cert_id = int(find_cert_id(cert, server.data_path, server.certbase))
        if cert_id == 0:
            return 0

        # session file
        if not os.path.exists(server.sids):
            # os.makedirs instead of a shell "mkdir" string: no quoting
            # problems with unusual paths, and parents are created too.
            try:
                os.makedirs(server.sids)
            except OSError:
                # Tolerate the directory appearing between the check and
                # the call; re-raise any real failure.
                if not os.path.isdir(server.sids):
                    raise

        if not os.path.isfile(server.sids_file):
            open(server.sids_file, 'w').close()

        # NOTE(review): pickle executes arbitrary code on load; this is
        # only acceptable while the sids file is writable by the server
        # alone -- confirm file ownership/permissions.
        with open(server.sids_file, 'rb') as fd:
            while True:
                try:
                    # read all on one record
                    list_sid = pickle.load(fd)
                except (IOError, EOFError, KeyError):
                    break
                # find session id in sids file
                if (cert_id == int(list_sid[1]) and
                        int(sid) == int(list_sid[0])):
                    return 1
        return 0

    # input parameters - certificate and name method
    def check_rights(self, method_name, req_env, sid):
        """Check whether the client certificate may call *method_name*.

        :param method_name: name of the requested method
        :param req_env: WSGI environment (not used by the check itself)
        :param sid: session id or a false value when the request has none
        :return: 1 when the call is permitted, 0 when it is denied
        """
        import OpenSSL
        import threading

        curthread = threading.current_thread()

        cert = curthread.client_cert
        server = curthread.server
        server_cert = server.ssl_certificate
        server_key = server.ssl_private_key
        certbase = server.certbase
        rights = server.rights
        group_rights = server.group_rights
        data_path = server.data_path
        permitted_methods = ['post_server_request', 'post_client_request',
                             'get_server_cert', 'get_client_cert',
                             'get_crl', 'get_server_host_name', 'get_ca']

        # These methods are open to everybody, with or without certificate.
        if method_name in permitted_methods:
            return 1
        # Anything else requires a client certificate.
        if cert is None:
            return 0

        needs_sid_check = (
            method_name in LoadedMethods.rightsMethods or
            (method_name.endswith('_view') and
             method_name[:-5] in LoadedMethods.rightsMethods))
        if sid and needs_sid_check:
            if not self.check_cert_sid(sid, server):
                return 0

        with open(server_cert, 'r') as f:
            data_server_cert = f.read()
        certobj = OpenSSL.crypto.load_certificate(
            OpenSSL.SSL.FILETYPE_PEM, data_server_cert)

        with open(server_key, 'r') as f:
            data_server_key = f.read()
        Pkey = OpenSSL.crypto.load_privatekey(OpenSSL.SSL.FILETYPE_PEM,
                                              data_server_key, 'qqqq')
        # Sign the client certificate text with the server key and verify
        # the signature with the server certificate.
        signature = OpenSSL.crypto.sign(Pkey, cert, 'SHA1')
        try:
            OpenSSL.crypto.verify(certobj, signature, cert, 'SHA1')
        except Exception as e:
            print(e)
            return 0
        if method_name == 'cert_add':
            return 0

        certobj_cl = OpenSSL.crypto.load_certificate(
            OpenSSL.SSL.FILETYPE_PEM, cert)
        try:
            # The groups are taken from the last certificate extension,
            # after the first ':' of its data.
            com = certobj_cl.get_extension(
                certobj_cl.get_extension_count() - 1)
            groups = com.get_data().split(':')[1]
        except IndexError:
            groups = ""
        except Exception:
            return 0
        groups_list = groups.split(',')

        # open certificates database
        if not os.path.exists(certbase):
            open(certbase, "w").close()
        from .cert_cmd import find_cert_id

        cert_id = int(find_cert_id(cert, data_path, certbase))
        count = 0
        find_flag = False
        # if certificate found
        if cert_id > 0:
            if method_name not in LoadedMethods.rightsMethods:
                return 1

            # if group = all and not redefined group all
            if 'all' in groups_list:
                with open(group_rights, 'r') as fd:
                    t = fd.read()
                # find all in group_rights file
                for line in t.splitlines():
                    fields = line.split()
                    # Skip empty and whitespace-only lines instead of
                    # failing on fields[0].
                    if fields and fields[0] == 'all':
                        find_flag = True
                        break

            for right_param in LoadedMethods.rightsMethods[method_name]:
                flag = 0
                try:
                    # check rights
                    if not os.path.exists(rights):
                        open(rights, 'w').close()
                    with open(rights) as fr:
                        t = fr.read()
                    for line in t.splitlines():
                        words = line.split()
                        # A blank line holds no rule.  Skip it like the
                        # group-rights loop below does; otherwise the
                        # IndexError would turn it into a blanket deny.
                        if not words:
                            continue
                        # first word in line equal name input method
                        if words[0] != right_param:
                            continue
                        for word in words:
                            try:
                                number = int(word)
                            except ValueError:
                                continue
                            # compare with certificate number
                            if cert_id == number:
                                # if has right
                                count += 1
                                flag = 1
                                break
                            # a negated id is an explicit deny
                            if cert_id == -number:
                                return 0
                        if flag:
                            break

                    if flag:
                        break
                    # open file with groups rights
                    if not os.path.exists(group_rights):
                        open(group_rights, 'w').close()
                    with open(group_rights) as fd:
                        t = fd.read()
                    for line in t.splitlines():
                        if not line:
                            continue
                        words = line.split(' ', 1)
                        # first word in line is a group of the certificate
                        if words[0] in groups_list:
                            methods = words[1].split(',')
                            for word in methods:
                                # compare with the required right
                                if right_param == word.strip():
                                    # if has right
                                    count += 1
                                    flag = 1
                                    break
                        if flag:
                            break
                except Exception:
                    # Fail closed: any problem while reading or parsing
                    # the rights files denies the call.
                    return 0
            if count == len(LoadedMethods.rightsMethods[method_name]):
                return 1
            if not find_flag and 'all' in groups_list:
                return 1
        elif method_name in ['post_cert', 'init_session']:
            return 1
        return 0

    def create_path(self):
        """ create paths for server files """
        import threading

        server = threading.current_thread().server
        data_path = server.data_path
        sids = server.sids
        pids = server.pids
        cert_path = server.cert_path

        # For these directories data_path is created first when missing.
        for directory in (sids, pids, data_path + '/conf'):
            if not os.path.exists(directory):
                if not os.path.exists(data_path):
                    os.makedirs(data_path)
                os.makedirs(directory)

        # Empty rights files, so later readers always find them.
        for conf_file in ('/conf/right.conf', '/conf/group_right.conf'):
            if not os.path.exists(data_path + conf_file):
                open(data_path + conf_file, 'w').close()

        for directory in (data_path + '/client_certs',
                          data_path + '/server_certs',
                          cert_path):
            if not os.path.exists(directory):
                os.makedirs(directory)

    def get_method_name_from_http(self, http_req_env):
        """Return the method name taken from the ``HTTP_SOAPACTION`` key.

        Surrounding double quotes are removed.  When a ``/`` is found past
        the first character, the value is rebuilt as ``{first}second`` from
        its first two ``/``-separated parts.

        :return: the method name, or ``None`` when the key is absent (an
            empty name is returned as is after a critical log record)
        """
        # check HTTP_SOAPACTION
        retval = http_req_env.get("HTTP_SOAPACTION")

        if retval is not None:
            if retval.startswith('"') and retval.endswith('"'):
                retval = retval[1:-1]

            if retval.find('/') > 0:
                retvals = retval.split('/')
                retval = '{%s}%s' % (retvals[0], retvals[1])

            logger.debug("\033[92m"
                         "Method name from HTTP_SOAPACTION: %r"
                         "\033[0m", retval)
        if not retval:
            logger.critical("Couldn't get method name from HTTP_SOAPACTION")
        return retval

    def get_sid_from_soap(self, http_req_env):
        """
        rips sid param from soap request (if there is one)

        The request body is read from ``wsgi.input`` and that entry is
        replaced by an in-memory stream holding the same data, so it can be
        read again later.

        :return: the sid as ``int``, or ``None`` when there is no input
            stream, no usable ``CONTENT_LENGTH`` or no sid element
        :raises ValueError: when the sid element is not an integer
        """
        if "wsgi.input" not in http_req_env:
            return None
        try:
            length = int(http_req_env.get("CONTENT_LENGTH"))
        except (TypeError, ValueError):
            # CONTENT_LENGTH is absent, empty or not a number: the body
            # size is unknown, so leave the stream untouched.
            return None
        body = http_req_env["wsgi.input"].read(length)
        res = re.search(r"<ns.:sid>(.*?)<\/ns.:sid>", body)
        # horrible hack:
        # cherrypy provides rfile in req_env which is consumed upon .read()
        # without workarounds, and both we and spyne need the data on it
        # so we pass a dummy with the data and read() method on to spyne
        http_req_env["wsgi.input"] = io.StringIO(body)

        if res:
            return int(res.group(1))
        return None

    def _client_fingerprint(self, cert_pem):
        """Return the SHA1 digest of the PEM certificate *cert_pem*."""
        import OpenSSL

        certobj = OpenSSL.crypto.load_certificate(
            OpenSSL.SSL.FILETYPE_PEM, cert_pem)
        return certobj.digest('SHA1')

    def handle_rpc(self, req_env, start_response):
        """
        Overriding spyne.wsgiApplication method

        Stores the remote address and port on the current thread, makes
        sure the server paths exist and checks the rights for the requested
        method.  Answers with HTTP 500 when no method name was supplied and
        with HTTP 403 when the call is denied; otherwise delegates to the
        parent ``handle_rpc``.
        """
        import threading

        http_resp_headers = {
            'Content-Type': 'text/xml',
            'Content-Length': '0',
        }
        curthread = threading.current_thread()
        curthread.REMOTE_ADDR = req_env.get('REMOTE_ADDR')
        curthread.REMOTE_PORT = req_env.get('REMOTE_PORT')
        ip = req_env.get('REMOTE_ADDR')
        self.create_path()
        sid = self.get_sid_from_soap(req_env)
        method_name = self.get_method_name_from_http(req_env)
        if method_name is None:
            resp = "Could not extract method name from the request!"
            http_resp_headers['Content-Length'] = str(len(resp))
            start_response(HTTP_500, list(http_resp_headers.items()))
            return [resp]

        service = self.app.services[0]

        # check if client certificate exists
        if not hasattr(curthread, 'client_cert'):
            curthread.client_cert = None
        # check rights client certificate for the method
        check = self.check_rights(method_name, req_env, sid)
        if not check:
            if curthread.client_cert:
                finger = self._client_fingerprint(curthread.client_cert)
                if self.log:
                    self.log.debug('%s %s %s forbidden %s',
                                   datetime.datetime.now(),
                                   finger, ip,
                                   method_name[5:])
            resp = "Permission denied: " + method_name
            http_resp_headers['Content-Length'] = str(len(resp))
            start_response(HTTP_403, list(http_resp_headers.items()))
            return [resp]

        if sid:
            curthread.lang = service.get_lang(service, sid, method_name)
        if curthread.client_cert:
            finger = self._client_fingerprint(curthread.client_cert)
            short_name = method_name[5:]
            if (short_name not in not_log_list and
                    not short_name.endswith('_view')):
                if self.log:
                    self.log.debug('%s %s %s allowed %s',
                                   datetime.datetime.now(),
                                   finger, ip,
                                   short_name)

        return super(ClApplication, self).handle_rpc(req_env, start_response)
|
|
|
|
class OpenSSLAdapter(pyOpenSSLAdapter):
    """pyOpenSSL adapter that records the peer certificate on the thread."""

    def verify_func(self, connection, x509, errnum, errdepth, ok):
        """Verification callback: remember the client certificate.

        At depth 0 the certificate is dumped as PEM and stored as
        ``client_cert`` on the current thread; at any other depth the
        attribute is reset to ``None``.  The incoming verdict ``ok`` is
        returned unchanged.
        """
        import OpenSSL
        import threading

        pem = None
        if errdepth == 0:
            pem = OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, x509)
        threading.currentThread().client_cert = pem
        return ok

    def get_context(self):
        """Build the SSL context from this adapter's key and certificates.

        The context uses TLSv1_2_METHOD with SSLv2/SSLv3 disabled, requests
        peer verification through ``verify_func`` and, when a certificate
        chain is configured, loads it as the verify locations.
        """
        # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473
        import OpenSSL

        ssl = OpenSSL.SSL
        context = ssl.Context(ssl.TLSv1_2_METHOD)
        legacy_off = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
        context.set_options(legacy_off)

        # No passphrase callback is installed for the private key.
        context.use_privatekey_file(self.private_key)
        context.set_verify(ssl.VERIFY_PEER, self.verify_func)

        chain = self.certificate_chain
        if chain:
            context.load_verify_locations(chain)

        context.use_certificate_file(self.certificate)
        return context
|