You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
calculate-utils-3-lib/pym/calculate/contrib/spyne/server/twisted/msgpack.py

478 lines
16 KiB

# -*- coding: utf-8 -*-
#
# spyne - Copyright (C) Spyne contributors.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
#
import logging
logger = logging.getLogger(__name__)
import io
import msgpack
from time import time
from hashlib import md5
from collections import deque, OrderedDict
from itertools import chain
from twisted.internet import reactor
from twisted.internet.task import deferLater
from twisted.internet.defer import Deferred, CancelledError
from twisted.internet.protocol import Protocol, Factory, connectionDone, \
ClientFactory
from twisted.python.failure import Failure
from spyne import EventManager, Address, ServerBase, Fault
from spyne.auxproc import process_contexts
from spyne.error import InternalError
from spyne.server.twisted import log_and_let_go
class TwistedMessagePackProtocolFactory(Factory):
    """Server-side factory building one ``TwistedMessagePackProtocol`` per
    accepted connection.

    :param tpt: Spyne transport shared by every connection; must be a
        ``ServerBase`` instance.
    :param max_buffer_size: Forwarded to each protocol as its maximum
        incoming msgpack buffer size. The default matches the protocol's own
        default.
    """

    # When not None, copied onto every protocol instance this factory builds.
    IDLE_TIMEOUT_SEC = None

    def __init__(self, tpt, max_buffer_size=2 * 1024 * 1024):
        assert isinstance(tpt, ServerBase), \
            "%r is not a ServerBase instance" % (tpt,)

        self.tpt = tpt
        self.max_buffer_size = max_buffer_size
        self.event_manager = EventManager(self)

    def buildProtocol(self, address):
        """Return a new protocol bound to this factory's transport."""
        retval = TwistedMessagePackProtocol(self.tpt,
                         max_buffer_size=self.max_buffer_size, factory=self)

        if self.IDLE_TIMEOUT_SEC is not None:
            retval.IDLE_TIMEOUT_SEC = self.IDLE_TIMEOUT_SEC

        return retval


# Alternative name for the server-side factory above.
TwistedMessagePackProtocolServerFactory = TwistedMessagePackProtocolFactory
class TwistedMessagePackProtocolClientFactory(ClientFactory):
    """Client-side factory building a ``TwistedMessagePackProtocol`` for each
    outgoing connection.

    :param tpt: Spyne transport; must be a ``ServerBase`` instance.
    :param max_buffer_size: Forwarded to each protocol as its maximum
        incoming msgpack buffer size.
    """

    # When not None, copied onto every protocol instance this factory builds
    # (same behavior as the server-side factory).
    IDLE_TIMEOUT_SEC = None

    def __init__(self, tpt, max_buffer_size=2 * 1024 * 1024):
        # Wrap tpt in a 1-tuple so a tuple argument cannot break the
        # %-formatting of the assertion message itself.
        assert isinstance(tpt, ServerBase), \
            "%r is not a ServerBase instance" % (tpt,)

        self.tpt = tpt
        self.max_buffer_size = max_buffer_size
        self.event_manager = EventManager(self)

    def buildProtocol(self, address):
        """Return a new protocol bound to this factory's transport."""
        retval = TwistedMessagePackProtocol(self.tpt,
                         max_buffer_size=self.max_buffer_size, factory=self)

        if self.IDLE_TIMEOUT_SEC is not None:
            retval.IDLE_TIMEOUT_SEC = self.IDLE_TIMEOUT_SEC

        return retval
def _cha(*args):
return args
class TwistedMessagePackProtocol(Protocol):
    """Twisted protocol implementing Spyne's MessagePack transport.

    Incoming bytes are fed to a streaming ``msgpack.Unpacker``. Every decoded
    message becomes a Spyne method context that is processed by the Spyne
    transport. Responses are written back in the order their contexts were
    activated, optionally split into delayed chunks.
    """

    # Seconds without incoming data before the connection is aborted.
    # 0 or None disables the idle timer.
    IDLE_TIMEOUT_SEC = 0
    # Reason passed to loseConnection() when the idle timer fires.
    IDLE_TIMEOUT_MSG = 'idle timeout'
    # The connection is aborted once more than this many contexts are waiting
    # in the inactive queue.
    MAX_INACTIVE_CONTEXTS = float('inf')

    def __init__(self, tpt, max_buffer_size=2 * 1024 * 1024, out_chunk_size=0,
                       out_chunk_delay_sec=1, max_in_queue_size=0, factory=None):
        """Twisted protocol implementation for Spyne's MessagePack transport.

        :param tpt: Spyne transport. It's an app-wide instance and must be a
            ``MessagePackTransportBase``.
        :param max_buffer_size: Max. size of the incoming msgpack buffer,
            i.e. the largest encoded message that can be received.
        :param out_chunk_size: When non-zero, every response is split into
            chunks of this many bytes, written one at a time.
        :param out_chunk_delay_sec: Seconds to wait between two chunk writes
            when ``out_chunk_size`` is non-zero.
        :param max_in_queue_size: Max. number of contexts processed at the
            same time; 0 means no limit. Excess contexts wait in
            ``inactive_queue``.
        :param factory: Twisted protocol factory. When not None, its
            ``event_manager`` receives ``connection_made`` and
            ``connection_lost`` events.

        Supported events (fired on ``tpt.event_manager``):

        * ``outresp_flushed(ctx, ctxid, data)``
            Called right after response data is flushed to the socket.

            * ctx: Always None
            * ctxid: Integer equal to ``id(ctx)``
            * data: Flushed bytes object
        """

        from spyne.server.msgpack import MessagePackTransportBase
        assert isinstance(tpt, MessagePackTransportBase), \
              "Expected {!r} got {!r}".format(MessagePackTransportBase, type(tpt))

        self.spyne_tpt = tpt
        self._buffer = msgpack.Unpacker(raw=True,
                                        max_buffer_size=max_buffer_size)
        self.out_chunk_size = out_chunk_size
        self.out_chunk_delay_sec = out_chunk_delay_sec
        self.max_in_queue_size = max_in_queue_size
        self.factory = factory

        # Also creates active_queue, so the protocol is usable even before
        # connectionMade() runs.
        self._reset_state()

    def _reset_state(self):
        """(Re)initialize all per-connection bookkeeping."""
        self.sessid = ''
        # Deferred of the currently scheduled chunk write, or None when no
        # chunk-writer loop is running.
        self._delaying = None
        self.sent_bytes = 0
        self.recv_bytes = 0
        self.idle_timer = None
        # Generators yielding the chunks of responses still to be written.
        self.out_chunks = deque()
        # id(ctx) -> encoded response, or None while it is still being
        # produced. Insertion order is the order responses are flushed in.
        self.inreq_queue = OrderedDict()
        # (p_ctx, others) tuples waiting to be activated, oldest first.
        self.inactive_queue = list()
        # id(ctx) -> ctx for every activated context not yet flushed.
        self.active_queue = dict()
        self.disconnecting = False  # FIXME: should we use this to raise an
                                    # invalid connection state exception ?

    @staticmethod
    def gen_chunks(l, n):
        """Yield successive n-sized chunks from l.

        ``l`` is either a byte sequence or an ``io.BufferedIOBase``. A stream
        is read from its current position and closed once exhausted.
        """
        if isinstance(l, io.BufferedIOBase):
            while True:
                data = l.read(n)
                if not data:
                    break
                yield data
            l.close()

        else:
            for i in range(0, len(l), n):
                yield l[i:i+n]

    def gen_sessid(self, *args):
        """It's up to you to use this in a subclass.

        Returns the hex MD5 digest of the peer address, the current time and
        ``args``.
        """
        retval = _cha(
            Address.from_twisted_address(self.transport.getPeer()),
            time(),
            *args
        )

        return md5(repr(retval).encode('utf8')).hexdigest()

    def connectionMade(self):
        logger.debug("%08x connection made", id(self))

        self._reset_state()
        self._reset_idle_timer()

        if self.factory is not None:
            self.factory.event_manager.fire_event("connection_made", self)

    def connectionLost(self, reason=connectionDone):
        if reason is connectionDone:
            logger.debug("%08x connection done", id(self))
        else:
            logger.debug("%08x connection lost: %s", id(self), reason)

        self.disconnecting = False

        if self.factory is not None:
            self.factory.event_manager.fire_event("connection_lost", self)

        self._cancel_idle_timer()

        # Nothing can be delivered any more. Drop pending chunks so that an
        # already scheduled _write_single_chunk() finds nothing and stops.
        self.out_chunks.clear()

    def _cancel_idle_timer(self):
        """Cancel the pending idle timer, if any, and forget it."""
        if self.idle_timer is not None:
            if not self.idle_timer.called:
                # FIXME: Workaround for a bug in Twisted 18.9.0 when
                # DelayedCall.debug == True
                try:
                    self.idle_timer.cancel()
                except AttributeError:
                    del self.idle_timer.func
                    del self.idle_timer.args
                    del self.idle_timer.kw

            self.idle_timer = None

    def dataReceived(self, data):
        """Feed ``data`` to the unpacker and process every complete message."""
        self._buffer.feed(data)
        self.recv_bytes += len(data)

        self._reset_idle_timer()

        for msg in self._buffer:
            self.process_incoming_message(msg)

            # Stop consuming input once the connection is being aborted.
            if self.disconnecting:
                return

    def _reset_idle_timer(self):
        """Restart the idle timer; it aborts the connection when it fires."""
        if self.idle_timer is not None:
            t = self.idle_timer
            self.idle_timer = None
            if not t.called:
                t.cancel()

        if self.IDLE_TIMEOUT_SEC is not None and self.IDLE_TIMEOUT_SEC > 0:
            self.idle_timer = deferLater(reactor, self.IDLE_TIMEOUT_SEC,
                                   self.loseConnection, self.IDLE_TIMEOUT_MSG) \
                .addErrback(self._err_idle_cancelled) \
                .addErrback(self._err_idle_cancelled_unknown_error)

    def _err_idle_cancelled(self, err):
        # Cancellation is the expected outcome of a timer reset; any other
        # failure is re-raised by trap() and reaches the next errback.
        err.trap(CancelledError)

        # do nothing.

    def _err_idle_cancelled_unknown_error(self, err):
        logger.error("Sessid %s error cancelling idle timer: %s",
                                                 self.sessid, err.getTraceback())
        self.idle_timer = None

    def loseConnection(self, reason=None):
        """Abort the connection immediately, logging ``reason``."""
        self.disconnecting = True

        # When called from the idle timer itself, the timer has already fired
        # and is only forgotten. Otherwise cancel it so it does not fire
        # later and abort the connection a second time.
        self._cancel_idle_timer()

        logger.debug("Aborting connection because %s", reason)
        self.transport.abortConnection()

    def process_incoming_message(self, msg, oob=None):
        """Build Spyne contexts for ``msg`` and queue them for processing."""
        p_ctx, others = self.spyne_tpt.produce_contexts(msg)
        p_ctx.oob_ctx = oob
        p_ctx.transport.remote_addr = Address.from_twisted_address(
                                                       self.transport.getPeer())
        p_ctx.transport.protocol = self
        p_ctx.transport.sessid = self.sessid

        self.inactive_queue.append((p_ctx, others))
        self.process_inactive()

    @property
    def num_active_contexts(self):
        return len(self.inreq_queue)

    @property
    def num_inactive_contexts(self):
        return len(self.inactive_queue)

    def process_inactive(self):
        """Activate queued contexts and start processing them.

        Contexts are activated oldest first. Responses are flushed in
        activation order by ``enqueue_outresp_data``, so they leave in the
        order the requests arrived. With a non-zero ``max_in_queue_size``, at
        most that many contexts are active at once. The connection is aborted
        once more than ``MAX_INACTIVE_CONTEXTS`` contexts are left waiting.
        """
        peer = self.transport.getPeer()
        addr_str = Address.from_twisted_address(peer)

        while self.num_inactive_contexts > 0 and (
                        self.max_in_queue_size == 0 or
                        self.num_active_contexts < self.max_in_queue_size):
            # pop(0) rather than pop(): the oldest pending request goes first.
            p_ctx, others = self.inactive_queue.pop(0)
            self.active_queue[id(p_ctx)] = p_ctx

            # Reserve the response slot before processing, which may finish
            # synchronously.
            self.inreq_queue[id(p_ctx)] = None
            self.process_contexts(p_ctx, others)

        if self.num_inactive_contexts > self.MAX_INACTIVE_CONTEXTS:
            logger.error("%s Too many inactive contexts. "
                                                 "Closing connection.", addr_str)
            self.loseConnection("Too many inactive contexts")

        logger.debug("%s active %d inactive %d", addr_str,
                              self.num_active_contexts, self.num_inactive_contexts)

    def enqueue_outresp_data(self, ctxid, data):
        """Store the encoded response of ``ctxid`` and flush what is ready.

        The flush stops at the first context whose response is still pending
        (None), so responses go out strictly in activation order. Flushed
        contexts are closed and forgotten. ``process_inactive`` is called
        afterwards because a slot may have been freed.
        """
        assert self.inreq_queue[ctxid] is None
        self.inreq_queue[ctxid] = data

        # Iterate over a snapshot because entries are deleted along the way.
        for k, v in list(self.inreq_queue.items()):
            if v is None:
                break

            self.out_write(v)
            self.spyne_tpt.event_manager.fire_event('outresp_flushed',
                                                                      None, k, v)
            del self.inreq_queue[k]
            self.active_queue[k].close()
            del self.active_queue[k]

        self.process_inactive()

    def out_write(self, reqdata):
        """Send one encoded response to the peer.

        :param reqdata: ``bytes`` or an ``io.BufferedIOBase`` holding the
            encoded response. A stream is rewound and sent from the start.
        """
        if self.out_chunk_size == 0:
            if isinstance(reqdata, io.BufferedIOBase):
                reqdata.seek(0)
                reqdata = reqdata.read()

            self.transport.write(reqdata)
            # Count what was actually written, not the stream position.
            self.sent_bytes += len(reqdata)
            return

        if isinstance(reqdata, io.BufferedIOBase):
            reqdata.seek(0)

        chunks = self.gen_chunks(reqdata, self.out_chunk_size)
        self.out_chunks.append(chunks)

        # Only one chunk-writer loop may run at a time. Starting one per
        # response would multiply the write rate and defeat
        # out_chunk_delay_sec.
        if self._delaying is None:
            self._delaying = deferLater(reactor, 0, self._write_single_chunk)

    def _wait_for_next_chunk(self):
        return deferLater(reactor, self.out_chunk_delay_sec,
                                                        self._write_single_chunk)

    def _write_single_chunk(self):
        """Write the next pending chunk and schedule the following one.

        Stops when every queued chunk generator is exhausted, or when the
        connection is going away.
        """
        # Mark the loop idle up front; it is re-armed below only when another
        # write is scheduled. If anything below raises, out_write() can
        # restart the loop.
        self._delaying = None

        try:
            chunk = next(chain.from_iterable(self.out_chunks))
        except StopIteration:
            chunk = None
            self.out_chunks.clear()

        if chunk is None:
            logger.debug("%s no more chunks...", self.sessid)
            return

        self.transport.write(chunk)
        self.sent_bytes += len(chunk)

        if self.connected and not self.disconnecting:
            self._delaying = self._wait_for_next_chunk()

            logger.debug("%s One chunk of %d bytes written. Delaying "
                              "before next chunk write...", self.sessid, len(chunk))

        else:
            logger.debug("%s Disconnection detected, discarding "
                                                "remaining chunks", self.sessid)
            self.out_chunks.clear()

    def handle_error(self, p_ctx, others, exc):
        """Serialize the error of ``p_ctx`` and send it as the response.

        An ``InternalError`` is reported with the server-error response code;
        anything else uses the client-error code. Out-of-band contexts get
        the error through their Deferred instead of the socket.
        """
        self.spyne_tpt.get_out_string(p_ctx)

        if isinstance(exc, InternalError):
            error = self.spyne_tpt.OUT_RESPONSE_SERVER_ERROR
        else:
            error = self.spyne_tpt.OUT_RESPONSE_CLIENT_ERROR

        data = p_ctx.out_document[0]
        if isinstance(data, dict):
            data = list(data.values())

        out_object = (error, msgpack.packb(data),)
        if p_ctx.oob_ctx is not None:
            p_ctx.oob_ctx.d.callback(out_object)
            return

        out_string = msgpack.packb(out_object)
        p_ctx.transport.resp_length = len(out_string)
        self.enqueue_outresp_data(id(p_ctx), out_string)

        try:
            # The auxiliary processors expect the Spyne transport as their
            # server, the same object _cb_deferred passes.
            process_contexts(self.spyne_tpt, others, p_ctx, error=error)

        except Exception as e:
            # Report but ignore any exceptions from auxiliary methods.
            logger.error("Exception ignored from auxiliary method: %r", e)
            logger.exception(e)

    def _register_callbacks(self, d, p_ctx, others):
        return d \
            .addCallback(self._cb_deferred, p_ctx, others) \
            .addErrback(self._eb_deferred, p_ctx, others) \
            .addErrback(log_and_let_go, logger)

    def process_contexts(self, p_ctx, others):
        """Run the Spyne pipeline for ``p_ctx``.

        Deserializes the request, calls the user code, then either chains on
        a returned Deferred or serializes the result right away.
        """
        if p_ctx.in_error:
            self.handle_error(p_ctx, others, p_ctx.in_error)
            return

        self.spyne_tpt.get_in_object(p_ctx)
        if p_ctx.in_error:
            logger.error(p_ctx.in_error)
            self.handle_error(p_ctx, others, p_ctx.in_error)
            return

        self.spyne_tpt.get_out_object(p_ctx)
        if p_ctx.out_error:
            self.handle_error(p_ctx, others, p_ctx.out_error)
            return

        ret = p_ctx.out_object
        if isinstance(ret, Deferred):
            self._register_callbacks(ret, p_ctx, others)

        else:
            ret = p_ctx.out_object[0]
            if isinstance(ret, Deferred):
                self._register_callbacks(ret, p_ctx, others)

            else:
                self._cb_deferred(p_ctx.out_object, p_ctx, others, nowrap=True)

    def _eb_deferred(self, fail, p_ctx, others):
        """Errback: turn a failed user Deferred into an error response."""
        assert isinstance(fail, Failure)

        if isinstance(fail.value, Fault):
            p_ctx.out_error = fail.value

        else:
            p_ctx.out_error = InternalError(fail.value)
            if not getattr(fail, 'logged', False):
                logger.error(fail.getTraceback())

        try:
            self.handle_error(p_ctx, others, p_ctx.out_error)

        except Exception as e:
            logger.exception(e)
            raise

    def _cb_deferred(self, ret, p_ctx, others, nowrap=False):
        """Serialize the result of ``p_ctx`` and queue it for sending.

        :param nowrap: True when called directly by ``process_contexts``
            rather than from a Deferred chain. ``p_ctx.out_object`` is then
            already in its final form and ``ret`` is not re-wrapped.
        """
        if not nowrap:
            # If there is one return value or the output is bare (which means
            # there can't be anything other than 1 return value case) use the
            # enclosing list. Otherwise the return value is presumably a
            # tuple of outputs already, so leave it be.
            if p_ctx.descriptor.is_out_bare():
                p_ctx.out_object = [ret]

            elif len(p_ctx.descriptor.out_message._type_info) > 1:
                p_ctx.out_object = ret

            else:
                p_ctx.out_object = [ret]

        if p_ctx.oob_ctx is not None:
            assert isinstance(p_ctx.oob_ctx.d, Deferred)

            p_ctx.oob_ctx.d.callback(p_ctx.out_object)
            return

        failed = False
        try:
            self.spyne_tpt.get_out_string(p_ctx)
            self.spyne_tpt.pack(p_ctx)

            out_string = b''.join(p_ctx.out_string)
            p_ctx.transport.resp_length = len(out_string)

            self.enqueue_outresp_data(id(p_ctx), out_string)

        except Exception as e:
            failed = True
            logger.exception(e)
            logger.error("%r", p_ctx)
            # handle_error already runs the auxiliary contexts (with error=).
            self.handle_error(p_ctx, others, InternalError(e))

        finally:
            p_ctx.close()

        # Run the auxiliary contexts only on success. On failure they were
        # already handled by handle_error; running them again would execute
        # them twice and ignore their process_exceptions setting.
        if not failed:
            process_contexts(self.spyne_tpt, others, p_ctx)