You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

144 lines
4.0 KiB

# Copyright 2017 Dan Krause
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socketserver
import socket
import struct
import json
class IPCError(Exception):
    """Root of this module's exception hierarchy.

    The other exception classes defined in this module derive from it, so
    catching ``IPCError`` catches all of them.
    """
class UnknownMessageClass(IPCError):
    """Raised by ``Message.deserialize`` when building a message hits a KeyError.

    This happens when the serialized ``'class'`` name is not a known
    ``Message`` subclass, and also when one of the expected keys is missing
    from the serialized dict.
    """
class InvalidSerialization(IPCError):
    """Raised by ``Message.deserialize`` when building a message hits a TypeError.

    Typically the serialized arguments do not match the constructor of the
    message class they are supposed to instantiate.
    """
class ConnectionClosed(IPCError):
    """Raised while reading a frame when the socket yields no data.

    The server's request handler treats it as the signal to stop serving the
    connection.
    """
def _recv_exact(sock, count):
    """Read exactly ``count`` bytes from ``sock``.

    ``recv`` on a stream socket may return fewer bytes than requested, so
    keep reading until the whole amount has arrived.

    Raises:
        ConnectionClosed: if ``recv`` returns no data before ``count`` bytes
            were received.
    """
    chunks = []
    remaining = count
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise ConnectionClosed()
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def _read_objects(sock):
    """Read one length-prefixed frame from ``sock`` and decode its messages.

    A frame is a 4-byte big-endian signed integer holding the total frame
    size (header included), followed by a UTF-8 JSON document.

    Args:
        sock: object with a socket-like ``recv(n)`` method.

    Returns:
        The list produced by ``Message.deserialize`` for the decoded JSON.

    Raises:
        ConnectionClosed: if the stream ends before a complete frame was
            read, or the frame announces an empty payload.
    """
    header = _recv_exact(sock, 4)
    size = struct.unpack('!i', header)[0]
    data = _recv_exact(sock, size - 4)
    # A frame without payload can never hold valid JSON; keep reporting it
    # as a closed connection, which is what an empty read always meant here.
    if not data:
        raise ConnectionClosed()
    return Message.deserialize(json.loads(data.decode('utf-8')))
def _write_objects(sock, objects):
    """Serialize ``objects`` and send them over ``sock`` as one frame.

    The frame is a 4-byte big-endian signed integer holding the total frame
    size (header included), followed by the UTF-8 encoded JSON list of each
    object's ``serialize()`` result.

    Args:
        sock: object with a socket-like ``sendall(bytes)`` method.
        objects: iterable of objects exposing ``serialize()``.
    """
    # Sockets only accept bytes, and the size in the header has to count
    # encoded bytes rather than characters, so encode before measuring.
    payload = json.dumps([o.serialize() for o in objects]).encode('utf-8')
    sock.sendall(struct.pack('!i', len(payload) + 4) + payload)
def _recursive_subclasses(cls):
    """Map class name to class for every direct and indirect subclass of ``cls``.

    ``cls`` itself is not part of the result. When two subclasses share a
    name, the one visited last wins.
    """
    found = {}
    # Explicit stack instead of recursion; pushing children in reverse keeps
    # the depth-first, declaration-order traversal.
    pending = list(reversed(cls.__subclasses__()))
    while pending:
        current = pending.pop()
        found[current.__name__] = current
        pending.extend(reversed(current.__subclasses__()))
    return found
class Message(object):
    """Base class for objects exchanged over an IPC connection.

    A message is transported as a dict with three keys: ``'class'`` (the
    class name), ``'args'`` (positional constructor arguments) and
    ``'kwargs'`` (keyword constructor arguments). Subclasses override
    ``_get_args`` to report the arguments needed to rebuild an instance.
    """

    @classmethod
    def deserialize(cls, objects):
        """Turn serialized dicts back into message instances.

        Args:
            objects: iterable whose items are either ``Message`` instances
                (kept as they are) or dicts with ``'class'``, ``'args'`` and
                ``'kwargs'`` keys.

        Returns:
            A list of message instances, in input order.

        Raises:
            UnknownMessageClass: if building a message raises ``KeyError``,
                e.g. the class name is not a subclass of ``cls`` (``cls``
                itself is not looked up) or an expected key is missing.
            InvalidSerialization: if building a message raises ``TypeError``,
                e.g. the arguments do not fit the constructor.
        """
        classmap = _recursive_subclasses(cls)
        messages = []
        for obj in objects:
            if isinstance(obj, Message):
                messages.append(obj)
                continue
            try:
                message = classmap[obj['class']](*obj['args'],
                                                 **obj['kwargs'])
            except KeyError as e:
                raise UnknownMessageClass(e) from e
            except TypeError as e:
                raise InvalidSerialization(e) from e
            messages.append(message)
        return messages

    def serialize(self):
        """Return the dict form of this message (class name, args, kwargs)."""
        args, kwargs = self._get_args()
        return {'class': type(self).__name__, 'args': args, 'kwargs': kwargs}

    def _get_args(self):
        """Return ``(args, kwargs)`` needed to rebuild this message.

        The base implementation reports no arguments.
        """
        return [], {}

    def __repr__(self):
        """Return a constructor-like representation, e.g. ``Name(1, k='v')``."""
        r = self.serialize()
        # Join positional and keyword parts with a single separator so that a
        # message with only keyword arguments does not render as "Name(, k=v)".
        parts = [repr(arg) for arg in r['args']]
        parts.extend('{}={}'.format(k, repr(v))
                     for k, v in r['kwargs'].items())
        return '{}({})'.format(r['class'], ', '.join(parts))
class Client(object):
    """Client side of an IPC connection.

    A string ``server_address`` selects a Unix domain socket; any other
    value selects an IPv4 socket. The socket is created on construction and
    connected by ``connect()`` or by entering the instance as a context
    manager.
    """

    def __init__(self, server_address):
        """Remember ``server_address`` and create a matching stream socket."""
        self.addr = server_address
        family = (socket.AF_UNIX if isinstance(self.addr, str)
                  else socket.AF_INET)
        self.sock = socket.socket(family, socket.SOCK_STREAM)

    def connect(self):
        """Connect the socket to the stored server address."""
        self.sock.connect(self.addr)

    def close(self):
        """Close the underlying socket."""
        self.sock.close()

    def __enter__(self):
        """Connect and return this client."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the socket; exceptions from the ``with`` body propagate."""
        self.close()

    def send(self, objects):
        """Send ``objects`` as one frame and return the decoded reply frame."""
        _write_objects(self.sock, objects)
        reply = _read_objects(self.sock)
        return reply
class Server(socketserver.ThreadingUnixStreamServer):
    """Threaded IPC server that answers each received frame via a callback.

    A string ``server_address`` selects a Unix domain socket; any other
    value selects an IPv4 socket.
    """

    def __init__(self, server_address, callback, bind_and_activate=True):
        """Set up the server.

        Args:
            server_address: Unix socket path (str) or an address for an
                IPv4 socket.
            callback: called with the list of messages decoded from each
                incoming frame; its return value is serialized and sent back
                as the reply. A non-callable value is replaced by a handler
                that always replies with an empty list.
            bind_and_activate: forwarded to ``socketserver.TCPServer``.
        """
        if callable(callback):
            respond = callback
        else:
            def respond(_objects):
                return []

        class IPCHandler(socketserver.BaseRequestHandler):
            def handle(self):
                # Serve request/reply frames until the peer hangs up.
                connection = self.request
                while True:
                    try:
                        received = _read_objects(connection)
                    except ConnectionClosed:
                        break
                    _write_objects(connection, respond(received))

        # The address family must be set before TCPServer creates the socket.
        self.address_family = (socket.AF_UNIX
                               if isinstance(server_address, str)
                               else socket.AF_INET)
        socketserver.TCPServer.__init__(self, server_address, IPCHandler,
                                        bind_and_activate)