#-*- coding: utf-8 -*-

# Copyright 2012-2016 Mir Calculate. http://www.calculate-linux.org
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""TLS socket wrapper built on pyOpenSSL instead of the stdlib ``ssl``.

``PyOpenSSLSocket`` takes the same argument names and ``_ssl`` constants
as the standard library's legacy ``ssl.wrap_socket`` and translates them
into an ``OpenSSL.SSL.Context``.
"""

import errno
import io
from socket import SocketIO

from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import PROTOCOL_SSLv23, PROTOCOL_TLSv1, PROTOCOL_TLSv1_2

# the OpenSSL stuff
import OpenSSL

# Stdlib ``ssl`` verification modes -> pyOpenSSL verify flags.
_ssl_to_openssl_cert_op_remap = {
    CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
    CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
    CERT_REQUIRED: (OpenSSL.SSL.VERIFY_PEER
                    | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT),
}

# Stdlib ``ssl`` protocol constants -> pyOpenSSL context methods.
_ssl_to_openssl_version_remap = {
    PROTOCOL_SSLv23: OpenSSL.SSL.SSLv23_METHOD,
    PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
    PROTOCOL_TLSv1_2: OpenSSL.SSL.TLSv1_2_METHOD,
}

# Numeric value of OpenSSL's SSL_OP_NO_TICKET (disables RFC 5077 session
# tickets).  The original author flagged this option as essential
# ("THIS IS THE KEY TO SUCCESS OF DS"); presumably the "DS" peer
# mishandles session tickets -- NOTE(review): confirm.
_OP_NO_TICKET = 0x4000


def _is_connected(sock):
    """Return True if *sock* already has a peer, False if not connected.

    Any ``OSError`` other than ENOTCONN raised by ``getpeername()`` is
    propagated to the caller.
    """
    try:
        sock.getpeername()
    except OSError as err:
        if err.errno != errno.ENOTCONN:
            raise
        return False
    return True


class PyOpenSSLSocket(OpenSSL.SSL.Connection):
    """TLS connection over a plain socket, implemented with pyOpenSSL.

    In addition to the ``ssl.wrap_socket``-style arguments it accepts
    ``keyobj``/``certobj`` so an in-memory private key / certificate can be
    used instead of files.  ``makefile()`` and the reference-counted
    ``close()`` follow the protocol ``socket.SocketIO`` expects, so the
    streams returned by ``makefile()`` release their reference on close.
    """

    def __init__(self, sock, keyfile=None, certfile=None, server_side=False,
                 cert_reqs=CERT_NONE, ssl_version=PROTOCOL_TLSv1_2,
                 ca_certs=None, do_handshake_on_connect=True,
                 keyobj=None, certobj=None):
        """Wrap *sock* in a TLS connection.

        :param sock: socket to wrap; it may or may not be connected yet.
        :param keyfile: private key file (used only when *keyobj* is unset).
        :param certfile: certificate file (used only when *certobj* is unset).
        :param server_side: act as the server side of the handshake
            instead of the client side.
        :param cert_reqs: CERT_NONE, CERT_OPTIONAL or CERT_REQUIRED.
        :param ssl_version: PROTOCOL_SSLv23, PROTOCOL_TLSv1 or
            PROTOCOL_TLSv1_2.
        :param ca_certs: CA locations used to verify the peer certificate.
        :param do_handshake_on_connect: run the handshake immediately when
            *sock* is already connected, otherwise from :meth:`connect`.
        :param keyobj: private key object passed to
            ``Context.use_privatekey``.
        :param certobj: certificate object passed to
            ``Context.use_certificate``.
        :raises KeyError: for an unsupported *cert_reqs* / *ssl_version*.
        """
        context = PyOpenSSLSocket.make_context(
            keyfile=keyfile, certfile=certfile, cert_reqs=cert_reqs,
            ssl_version=ssl_version, ca_certs=ca_certs,
            keyobj=keyobj, certobj=certobj)
        super().__init__(context, sock)
        # Bookkeeping first, before any network I/O is attempted.
        self._io_refs = 0
        self.do_handshake_on_connect = do_handshake_on_connect
        # Blocking mode (equivalent to settimeout(None)), so the handshake
        # below completes synchronously.
        self.setblocking(True)
        if server_side:
            self.set_accept_state()
        else:
            self.set_connect_state()
        # An unconnected socket cannot handshake yet; connect() does it.
        if do_handshake_on_connect and _is_connected(sock):
            self.do_handshake()

    def connect(self, addr):
        """Connect to *addr*; handshake too if do_handshake_on_connect."""
        super().connect(addr)
        if self.do_handshake_on_connect:
            self.do_handshake()

    def close(self):
        """Release one reference; really close the socket on the last one.

        Every :meth:`makefile` call adds a reference, so the wrapped socket
        is only closed once the owner and all streams have released theirs.
        """
        if self._io_refs < 1:
            self._socket.close()
        else:
            self._io_refs -= 1

    def _decref_socketios(self):
        """Release the reference held by a closed :meth:`makefile` stream.

        ``socket.SocketIO.close()`` (also reached from its finalizer) calls
        this hook.  Without it the call would be forwarded by
        ``OpenSSL.SSL.Connection`` to the wrapped raw socket and the
        reference taken in :meth:`makefile` would never be released.
        """
        self.close()

    def makefile(self, mode="r", buffering=None, *,
                 encoding=None, errors=None, newline=None):
        """makefile(...) -> an I/O stream connected to the socket

        The arguments are as for io.open() after the filename, except the only
        supported mode values are 'r' (default), 'w' and 'b'.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = SocketIO(self, rawmode)
        # Released again via _decref_socketios() when the stream closes.
        self._io_refs += 1
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode
        return text

    @staticmethod
    def make_context(keyfile=None, certfile=None, cert_reqs=CERT_NONE,
                     ssl_version=PROTOCOL_TLSv1_2, ca_certs=None,
                     keyobj=None, certobj=None):
        """Build an ``OpenSSL.SSL.Context`` from ``ssl``-style arguments.

        In-memory *keyobj* / *certobj* take precedence over *keyfile* /
        *certfile*.  Session tickets are always disabled.

        :raises KeyError: for an unsupported *cert_reqs* / *ssl_version*.
        """
        ctx = OpenSSL.SSL.Context(_ssl_to_openssl_version_remap[ssl_version])
        if ca_certs:
            ctx.load_verify_locations(ca_certs)
        ctx.set_verify(_ssl_to_openssl_cert_op_remap[cert_reqs],
                       verify_connection)
        if keyobj:
            ctx.use_privatekey(keyobj)
        elif keyfile:
            ctx.use_privatekey_file(keyfile)
        if certobj:
            ctx.use_certificate(certobj)
        elif certfile:
            ctx.use_certificate_file(certfile)
        ctx.set_options(_OP_NO_TICKET)  # THIS IS THE KEY TO SUCCESS OF DS
        return ctx


def verify_connection(conn, x509, error_code, depth, ret_code):
    """Certificate verification callback installed by ``make_context``.

    Adds no extra validation: it returns OpenSSL's own preliminary
    verdict (*ret_code*) as a bool.
    """
    return bool(ret_code)