tornado.websocket 源代码

#!/usr/bin/env python
# coding: utf-8
"""WebSocket 协议的实现

`WebSockets <http://dev.w3.org/html5/websockets/>`_ 允许浏览器和服务器之间进行
双向通信

所有主流浏览器的现代版本都支持WebSockets(支持情况详见:http://caniuse.com/websockets)

该模块依照最新 WebSocket 协议 `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_ 实现.

.. versionchanged:: 4.0
   删除了对76 草案协议版本的支持.
"""

from __future__ import absolute_import, division, print_function, with_statement
# Author: Jacob Kristhammar, 2010

import base64
import collections
import hashlib
import os
import struct
import tornado.escape
import tornado.web
import zlib

from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str, to_unicode
from tornado import httpclient, httputil
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado import simple_httpclient
from tornado.tcpclient import TCPClient
from tornado.util import _websocket_mask

try:
    from urllib.parse import urlparse  # py3
except ImportError:
    from urlparse import urlparse  # py2

try:
    xrange  # py2
except NameError:
    xrange = range  # py3


class WebSocketError(Exception):
    """Base class for the WebSocket errors defined in this module."""


class WebSocketClosedError(WebSocketError):
    """Raised by operations on a closed connection.

    .. versionadded:: 3.2
    """
class WebSocketHandler(tornado.web.RequestHandler):
    """Subclass this class to create a basic WebSocket handler.

    Override `on_message` to handle incoming messages, and use
    `write_message` to send messages to the client. You can also
    override `open` and `on_close` to handle opened and closed
    connections.

    See http://dev.w3.org/html5/websockets/ for details on the
    JavaScript interface.  The protocol is specified at
    http://tools.ietf.org/html/rfc6455.

    Here is an example WebSocket handler that echoes all received
    messages back to the client:

    .. testcode::

      class EchoWebSocket(tornado.websocket.WebSocketHandler):
          def open(self):
              print("WebSocket opened")

          def on_message(self, message):
              self.write_message(u"You said: " + message)

          def on_close(self):
              print("WebSocket closed")

    .. testoutput::
       :hide:

    WebSockets are not standard HTTP connections. The "handshake" is
    HTTP, but after the handshake, the protocol is
    message-based. Consequently, most of the Tornado HTTP facilities
    are not available in handlers of this type. The only communication
    methods available to you are `write_message()`, `ping()`, and
    `close()`. Likewise, your request handler class should implement
    `open()` rather than ``get()`` or ``post()``.

    If you map the handler above to ``/websocket`` in your application, you can
    invoke it in JavaScript with::

      var ws = new WebSocket("ws://localhost:8888/websocket");
      ws.onopen = function() {
         ws.send("Hello, world");
      };
      ws.onmessage = function (evt) {
         alert(evt.data);
      };

    This script pops up an alert box that says "You said: Hello, world".

    Web browsers allow any site to open a websocket connection to any other,
    instead of using the same-origin policy that governs other network
    access from javascript.  This can be surprising and is a potential
    security hole, so since Tornado 4.0 `WebSocketHandler` requires
    applications that wish to receive cross-origin websockets to opt in
    by overriding the `~WebSocketHandler.check_origin` method (see that
    method's docs for details).  Failure to do so is the most likely
    cause of 403 errors when making a websocket connection.

    When using a secure websocket connection (``wss://``), the connection
    from a browser may fail because the websocket has nowhere to show the
    certificate-acceptance dialog.  You must first visit a regular HTML
    page using the same certificate before the websocket connection will
    succeed.
    """
    def __init__(self, application, request, **kwargs):
        tornado.web.RequestHandler.__init__(self, application, request,
                                            **kwargs)
        self.ws_connection = None
        self.close_code = None
        self.close_reason = None
        # ``stream`` stays None until ``get`` detaches the connection; the
        # wrapped HTTP methods and ``send_error`` use it to tell whether the
        # socket has been taken over.
        self.stream = None
        self._on_close_called = False

    @tornado.web.asynchronous
    def get(self, *args, **kwargs):
        """Validate the upgrade request and start the WebSocket handshake.

        Responds with 400 if the ``Upgrade``/``Connection`` headers are not
        suitable, 403 if `check_origin` rejects the origin, and 426 if the
        requested protocol version is not supported.
        """
        self.open_args = args
        self.open_kwargs = kwargs

        # Upgrade header should be present and should be equal to WebSocket
        if self.request.headers.get("Upgrade", "").lower() != 'websocket':
            self.set_status(400)
            log_msg = "Can \"Upgrade\" only to \"WebSocket\"."
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        # Connection header should be upgrade.
        # Some proxy servers/load balancers
        # might mess with it.
        headers = self.request.headers
        connection = [token.strip().lower()
                      for token in headers.get("Connection", "").split(",")]
        if 'upgrade' not in connection:
            self.set_status(400)
            log_msg = "\"Connection\" must be \"Upgrade\"."
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        # Handle WebSocket Origin naming convention differences
        # The difference between version 8 and 13 is that in 8 the
        # client sends a "Sec-Websocket-Origin" header and in 13 it's
        # simply "Origin".
        if "Origin" in self.request.headers:
            origin = self.request.headers.get("Origin")
        else:
            origin = self.request.headers.get("Sec-Websocket-Origin", None)

        # If there was an origin header, check to make sure it matches
        # according to check_origin. When the origin is None, we assume it
        # did not come from a browser and that it can be passed on.
        if origin is not None and not self.check_origin(origin):
            self.set_status(403)
            log_msg = "Cross origin websockets not allowed"
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        self.stream = self.request.connection.detach()
        self.stream.set_close_callback(self.on_connection_close)

        self.ws_connection = self.get_websocket_protocol()
        if self.ws_connection:
            self.ws_connection.accept_connection()
        else:
            # Unsupported protocol version: the HTTP machinery is detached
            # already, so the 426 response is written to the raw stream.
            if not self.stream.closed():
                self.stream.write(tornado.escape.utf8(
                    "HTTP/1.1 426 Upgrade Required\r\n"
                    "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n"))
                self.stream.close()

    def write_message(self, message, binary=False):
        """Sends the given message to the client of this Web Socket.

        The message may be either a string or a dict (which will be
        encoded as json).  If the ``binary`` argument is false, the
        message will be sent as utf8; in binary mode any byte string
        is allowed.

        If the connection is already closed, raises `WebSocketClosedError`.

        .. versionchanged:: 3.2
           `WebSocketClosedError` was added (previously a closed connection
           would raise an `AttributeError`)

        .. versionchanged:: 4.3
           Returns a `.Future` which can be used for flow control.
        """
        if self.ws_connection is None:
            raise WebSocketClosedError()
        if isinstance(message, dict):
            message = tornado.escape.json_encode(message)
        return self.ws_connection.write_message(message, binary=binary)

    def select_subprotocol(self, subprotocols):
        """Invoked when a new WebSocket requests specific subprotocols.

        ``subprotocols`` is a list of strings identifying the
        subprotocols proposed by the client.  This method may be
        overridden to return one of those strings to select it, or
        ``None`` to not select a subprotocol.  Failure to select a
        subprotocol does not automatically abort the connection,
        although clients may close the connection if none of their
        proposed subprotocols was selected.
        """
        return None

    def get_compression_options(self):
        """Override to return compression options for the connection.

        If this method returns None (the default), compression will
        be disabled.  If it returns a dict (even an empty one), it
        will be enabled.  The contents of the dict may be used to
        control the memory and CPU usage of the compression,
        but no such options are currently implemented.

        .. versionadded:: 4.1
        """
        return None

    def open(self, *args, **kwargs):
        """Invoked when a new WebSocket is opened.

        The arguments to `open` are extracted from the `tornado.web.URLSpec`
        regular expression, just like the arguments to
        `tornado.web.RequestHandler.get`.
        """
        pass

    def on_message(self, message):
        """Handle incoming messages on the WebSocket

        This method must be overridden.
        """
        raise NotImplementedError

    def ping(self, data):
        """Send ping frame to the remote end.

        Raises `WebSocketClosedError` if the connection is already closed.
        """
        if self.ws_connection is None:
            raise WebSocketClosedError()
        self.ws_connection.write_ping(data)

    def on_pong(self, data):
        """Invoked when the response to a ping frame is received."""
        pass

    def on_close(self):
        """Invoked when the WebSocket is closed.

        If the connection was closed cleanly and a status code or reason
        phrase was supplied, these values will be available as the attributes
        ``self.close_code`` and ``self.close_reason``.

        .. versionchanged:: 4.0

           Added ``close_code`` and ``close_reason`` attributes.
        """
        pass

    def close(self, code=None, reason=None):
        """Closes this Web Socket.

        Once the close handshake is successful the socket will be closed.

        ``code`` may be a numeric status code, taken from the values
        defined in `RFC 6455 section 7.4.1
        <https://tools.ietf.org/html/rfc6455#section-7.4.1>`_.
        ``reason`` may be a textual message about why the connection is
        closing.  These values are made available to the client, but are
        not otherwise interpreted by the websocket protocol.

        .. versionchanged:: 4.0

           Added the ``code`` and ``reason`` arguments.
        """
        if self.ws_connection:
            self.ws_connection.close(code, reason)
            self.ws_connection = None

    def check_origin(self, origin):
        """Override to enable support for allowing alternate origins.

        The ``origin`` argument is the value of the ``Origin`` HTTP
        header, the url responsible for initiating this request.  This
        method is not called for clients that do not send this header;
        such requests are always allowed (because all browsers that
        implement WebSockets support this header, and non-browser
        clients do not have the same cross-site security concerns).

        Should return True to accept the request or False to reject it.
        By default, rejects all requests with an origin on a host other
        than this one.  The comparison includes the port and ignores
        letter case, since host names are case-insensitive.

        This is a security protection against cross site scripting attacks on
        browsers, since WebSockets are allowed to bypass the usual same-origin
        policies and don't use CORS headers.

        To accept all cross-origin traffic (which was the default prior to
        Tornado 4.0), simply override this method to always return true::

            def check_origin(self, origin):
                return True

        To allow connections from any subdomain of your site, you might
        do something like::

            def check_origin(self, origin):
                parsed_origin = urllib.parse.urlparse(origin)
                return parsed_origin.netloc.endswith(".mydomain.com")

        .. versionadded:: 4.0
        """
        origin_host = urlparse(origin).netloc.lower()

        host = self.request.headers.get("Host")
        if host is None:
            # Without a Host header there is nothing to match against.
            return False

        # Check to see that origin matches host directly, including ports.
        # Both sides are lowercased so that "Example.com" in the Host header
        # still matches an origin of "http://example.com".
        return origin_host == host.lower()

    def set_nodelay(self, value):
        """Set the no-delay flag for this stream.

        By default, small messages may be delayed and/or combined to minimize
        the number of packets sent.  This can sometimes cause 200-500ms delays
        due to the interaction between Nagle's algorithm and TCP delayed
        ACKs.  To reduce this delay (at the expense of possibly increasing
        bandwidth usage), call ``self.set_nodelay(True)`` once the websocket
        connection is established.

        See `.BaseIOStream.set_nodelay` for additional details.

        .. versionadded:: 3.1
        """
        self.stream.set_nodelay(value)

    def on_connection_close(self):
        """Stream close callback: tear down the protocol and fire `on_close`.

        `on_close` is invoked at most once, guarded by ``_on_close_called``.
        """
        if self.ws_connection:
            self.ws_connection.on_connection_close()
            self.ws_connection = None
        if not self._on_close_called:
            self._on_close_called = True
            self.on_close()

    def send_error(self, *args, **kwargs):
        """Send an HTTP error before the upgrade, or close the stream after."""
        if self.stream is None:
            super(WebSocketHandler, self).send_error(*args, **kwargs)
        else:
            # If we get an uncaught exception during the handshake,
            # we have no choice but to abruptly close the connection.
            # TODO: for uncaught exceptions after the handshake,
            # we can close the connection more gracefully.
            self.stream.close()

    def get_websocket_protocol(self):
        """Return a protocol object for the requested version, or None.

        None is returned when ``Sec-WebSocket-Version`` is missing or is not
        one of the supported versions (7, 8 and 13).
        """
        websocket_version = self.request.headers.get("Sec-WebSocket-Version")
        if websocket_version in ("7", "8", "13"):
            return WebSocketProtocol13(
                self, compression_options=self.get_compression_options())
        return None
def _wrap_method(method):
    """Wrap a handler method so it raises once the stream is taken over.

    The wrapper calls ``method`` only while ``self.stream`` is None and
    raises ``RuntimeError`` otherwise.  The wrapped method's return value
    is discarded.
    """
    def _disallow_for_websocket(self, *args, **kwargs):
        if self.stream is None:
            method(self, *args, **kwargs)
        else:
            raise RuntimeError("Method not supported for Web Sockets")
    return _disallow_for_websocket


for method in ["write", "redirect", "set_header", "set_cookie",
               "set_status", "flush", "finish"]:
    setattr(WebSocketHandler, method,
            _wrap_method(getattr(WebSocketHandler, method)))


class WebSocketProtocol(object):
    """Base class for WebSocket protocol versions.
    """
    def __init__(self, handler):
        self.handler = handler
        self.request = handler.request
        self.stream = handler.stream
        self.client_terminated = False
        self.server_terminated = False

    def _run_callback(self, callback, *args, **kwargs):
        """Runs the given callback with exception handling.

        On error, logs the exception and aborts the websocket connection.
        Nothing is returned in either case.
        """
        try:
            callback(*args, **kwargs)
        except Exception:
            app_log.error("Uncaught exception in %s",
                          self.request.path, exc_info=True)
            self._abort()

    def on_connection_close(self):
        self._abort()

    def _abort(self):
        """Instantly aborts the WebSocket connection by closing the socket"""
        self.client_terminated = True
        self.server_terminated = True
        self.stream.close()  # forcibly tear down the connection
        self.close()  # let the subclass cleanup


def _resolve_max_wbits(max_wbits):
    """Return a validated zlib window size, defaulting to ``zlib.MAX_WBITS``.

    Raises ``ValueError`` if ``max_wbits`` is outside 8..``zlib.MAX_WBITS``.
    """
    if max_wbits is None:
        return zlib.MAX_WBITS
    # There is no symbolic constant for the minimum wbits value.
    if not (8 <= max_wbits <= zlib.MAX_WBITS):
        raise ValueError("Invalid max_wbits value %r; allowed range 8-%d" %
                         (max_wbits, zlib.MAX_WBITS))
    return max_wbits


class _PerMessageDeflateCompressor(object):
    """Raw-deflate compressor used for outgoing permessage-deflate frames.

    When ``persistent`` is true a single zlib object is reused for every
    message; otherwise a fresh one is created per message.
    """
    def __init__(self, persistent, max_wbits):
        self._max_wbits = _resolve_max_wbits(max_wbits)
        if persistent:
            self._compressor = self._create_compressor()
        else:
            self._compressor = None

    def _create_compressor(self):
        # A negative wbits value is passed to zlib.compressobj here.
        return zlib.compressobj(tornado.web.GZipContentEncoding.GZIP_LEVEL,
                                zlib.DEFLATED, -self._max_wbits)

    def compress(self, data):
        """Compress ``data`` and strip the trailing sync-flush marker."""
        compressor = self._compressor or self._create_compressor()
        data = (compressor.compress(data) +
                compressor.flush(zlib.Z_SYNC_FLUSH))
        # The 4-byte tail produced by Z_SYNC_FLUSH is removed here and
        # re-appended by the decompressor below.
        assert data.endswith(b'\x00\x00\xff\xff')
        return data[:-4]


class _PerMessageDeflateDecompressor(object):
    """Raw-deflate decompressor used for incoming permessage-deflate frames.

    When ``persistent`` is true a single zlib object is reused for every
    message; otherwise a fresh one is created per message.
    """
    def __init__(self, persistent, max_wbits):
        self._max_wbits = _resolve_max_wbits(max_wbits)
        if persistent:
            self._decompressor = self._create_decompressor()
        else:
            self._decompressor = None

    def _create_decompressor(self):
        return zlib.decompressobj(-self._max_wbits)

    def decompress(self, data):
        """Decompress ``data`` after restoring the stripped 4-byte tail."""
        decompressor = self._decompressor or self._create_decompressor()
        return decompressor.decompress(data + b'\x00\x00\xff\xff')


class WebSocketProtocol13(WebSocketProtocol):
    """Implementation of the WebSocket protocol from RFC 6455.

    This class supports versions 7 and 8 of the protocol in addition to the
    final version 13.
    """
    # Bit masks for the first byte of a frame.
    FIN = 0x80
    RSV1 = 0x40
    RSV2 = 0x20
    RSV3 = 0x10
    RSV_MASK = RSV1 | RSV2 | RSV3
    OPCODE_MASK = 0x0f

    def __init__(self, handler, mask_outgoing=False,
                 compression_options=None):
        WebSocketProtocol.__init__(self, handler)
        self.mask_outgoing = mask_outgoing
        self._final_frame = False
        self._frame_opcode = None
        self._masked_frame = None
        self._frame_mask = None
        self._frame_length = None
        self._fragmented_message_buffer = None
        self._fragmented_message_opcode = None
        self._waiting = None
        self._compression_options = compression_options
        self._decompressor = None
        self._compressor = None
        self._frame_compressed = None
        # The total uncompressed size of all messages received or sent.
        # Unicode messages are encoded to utf8.
        # Only for testing; subject to change.
        self._message_bytes_in = 0
        self._message_bytes_out = 0
        # The total size of all packets received or sent.  Includes
        # the effect of compression, frame overhead, and control frames.
        self._wire_bytes_in = 0
        self._wire_bytes_out = 0

    def accept_connection(self):
        """Run the server side of the handshake; abort on malformed input."""
        try:
            self._handle_websocket_headers()
            self._accept_connection()
        except ValueError:
            gen_log.debug("Malformed WebSocket request received",
                          exc_info=True)
            self._abort()
            return

    def _handle_websocket_headers(self):
        """Verifies all invariant- and required headers

        If a header is missing or have an incorrect value ValueError will be
        raised
        """
        fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
        if not all(self.request.headers.get(f) for f in fields):
            raise ValueError("Missing/Invalid WebSocket headers")

    @staticmethod
    def compute_accept_value(key):
        """Computes the value for the Sec-WebSocket-Accept header,
        given the value for Sec-WebSocket-Key.
        """
        sha1 = hashlib.sha1()
        sha1.update(utf8(key))
        sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11")  # Magic value
        return native_str(base64.b64encode(sha1.digest()))

    def _challenge_response(self):
        return WebSocketProtocol13.compute_accept_value(
            self.request.headers.get("Sec-Websocket-Key"))

    def _accept_connection(self):
        """Negotiate subprotocol/extensions, send the 101 response and
        start reading frames.
        """
        subprotocol_header = ''
        requested = self.request.headers.get("Sec-WebSocket-Protocol", '')
        # Drop empty entries: ''.split(',') yields [''], which would
        # otherwise make the guard below always true and hand [''] to
        # select_subprotocol when the client proposed nothing.
        subprotocols = [s.strip() for s in requested.split(',') if s.strip()]
        if subprotocols:
            selected = self.handler.select_subprotocol(subprotocols)
            if selected:
                assert selected in subprotocols
                subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n"
                                      % selected)

        extension_header = ''
        extensions = self._parse_extensions_header(self.request.headers)
        for ext in extensions:
            if (ext[0] == 'permessage-deflate' and
                    self._compression_options is not None):
                # TODO: negotiate parameters if compression_options
                # specifies limits.
                self._create_compressors('server', ext[1])
                if ('client_max_window_bits' in ext[1] and
                        ext[1]['client_max_window_bits'] is None):
                    # Don't echo an offered client_max_window_bits
                    # parameter with no value.
                    del ext[1]['client_max_window_bits']
                extension_header = ('Sec-WebSocket-Extensions: %s\r\n' %
                                    httputil._encode_header(
                                        'permessage-deflate', ext[1]))
                break

        if self.stream.closed():
            self._abort()
            return
        self.stream.write(tornado.escape.utf8(
            "HTTP/1.1 101 Switching Protocols\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            "Sec-WebSocket-Accept: %s\r\n"
            "%s%s"
            "\r\n" % (self._challenge_response(),
                      subprotocol_header, extension_header)))

        self._run_callback(self.handler.open, *self.handler.open_args,
                           **self.handler.open_kwargs)
        self._receive_frame()

    def _parse_extensions_header(self, headers):
        """Return the parsed ``Sec-WebSocket-Extensions`` entries (or [])."""
        extensions = headers.get("Sec-WebSocket-Extensions", '')
        if extensions:
            return [httputil._parse_header(e.strip())
                    for e in extensions.split(',')]
        return []

    def _process_server_headers(self, key, headers):
        """Process the headers sent by the server to this client connection.

        'key' is the websocket handshake challenge/response key.

        Raises ``ValueError`` for an extension that was not negotiated.
        """
        assert headers['Upgrade'].lower() == 'websocket'
        assert headers['Connection'].lower() == 'upgrade'
        accept = self.compute_accept_value(key)
        assert headers['Sec-Websocket-Accept'] == accept

        extensions = self._parse_extensions_header(headers)
        for ext in extensions:
            if (ext[0] == 'permessage-deflate' and
                    self._compression_options is not None):
                self._create_compressors('client', ext[1])
            else:
                # ``ext`` is itself indexed like a sequence above, so it is
                # wrapped in a 1-tuple to be formatted as a single value.
                raise ValueError("unsupported extension %r" % (ext,))

    def _get_compressor_options(self, side, agreed_parameters):
        """Converts a websocket agreed_parameters set to keyword arguments
        for our compressor objects.
        """
        options = dict(
            persistent=(side + '_no_context_takeover') not in agreed_parameters)
        wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
        if wbits_header is None:
            options['max_wbits'] = zlib.MAX_WBITS
        else:
            options['max_wbits'] = int(wbits_header)
        return options

    def _create_compressors(self, side, agreed_parameters):
        """Create the compressor for ``side`` and the decompressor for the
        opposite side from the agreed extension parameters.

        Raises ``ValueError`` on an unknown parameter name.
        """
        # TODO: handle invalid parameters gracefully
        allowed_keys = set(['server_no_context_takeover',
                            'client_no_context_takeover',
                            'server_max_window_bits',
                            'client_max_window_bits'])
        for key in agreed_parameters:
            if key not in allowed_keys:
                raise ValueError("unsupported compression parameter %r" % key)
        other_side = 'client' if (side == 'server') else 'server'
        self._compressor = _PerMessageDeflateCompressor(
            **self._get_compressor_options(side, agreed_parameters))
        self._decompressor = _PerMessageDeflateDecompressor(
            **self._get_compressor_options(other_side, agreed_parameters))

    def _write_frame(self, fin, opcode, data, flags=0):
        """Serialize one frame and write it to the stream.

        Returns whatever ``stream.write`` returns; if the stream is already
        closed the connection is aborted and None is returned.
        """
        finbit = self.FIN if fin else 0
        frame = struct.pack("B", finbit | opcode | flags)
        length = len(data)
        mask_bit = 0x80 if self.mask_outgoing else 0
        # Payload length uses 1, 1+2 or 1+8 bytes depending on its size.
        if length < 126:
            frame += struct.pack("B", length | mask_bit)
        elif length <= 0xFFFF:
            frame += struct.pack("!BH", 126 | mask_bit, length)
        else:
            frame += struct.pack("!BQ", 127 | mask_bit, length)
        if self.mask_outgoing:
            mask = os.urandom(4)
            data = mask + _websocket_mask(mask, data)
        frame += data
        self._wire_bytes_out += len(frame)
        try:
            return self.stream.write(frame)
        except StreamClosedError:
            self._abort()

    def write_message(self, message, binary=False):
        """Sends the given message to the client of this Web Socket."""
        opcode = 0x2 if binary else 0x1
        message = tornado.escape.utf8(message)
        assert isinstance(message, bytes)
        self._message_bytes_out += len(message)
        flags = 0
        if self._compressor:
            message = self._compressor.compress(message)
            flags |= self.RSV1
        return self._write_frame(True, opcode, message, flags=flags)

    def write_ping(self, data):
        """Send ping frame."""
        assert isinstance(data, bytes)
        self._write_frame(True, 0x9, data)

    def _receive_frame(self):
        """Start reading the 2-byte header of the next frame."""
        try:
            self.stream.read_bytes(2, self._on_frame_start)
        except StreamClosedError:
            self._abort()

    def _read_frame_body(self):
        """Read the masking key if the frame is masked, else the payload.

        May raise StreamClosedError; callers handle it.
        """
        if self._masked_frame:
            self.stream.read_bytes(4, self._on_masking_key)
        else:
            self.stream.read_bytes(self._frame_length, self._on_frame_data)

    def _on_frame_start(self, data):
        self._wire_bytes_in += len(data)
        header, payloadlen = struct.unpack("BB", data)
        self._final_frame = header & self.FIN
        reserved_bits = header & self.RSV_MASK
        self._frame_opcode = header & self.OPCODE_MASK
        self._frame_opcode_is_control = self._frame_opcode & 0x8
        if self._decompressor is not None and self._frame_opcode != 0:
            # RSV1 is consumed as the "compressed" flag only when a
            # decompressor was negotiated and this is not a continuation.
            self._frame_compressed = bool(reserved_bits & self.RSV1)
            reserved_bits &= ~self.RSV1
        if reserved_bits:
            # client is using as-yet-undefined extensions; abort
            self._abort()
            return
        self._masked_frame = bool(payloadlen & 0x80)
        payloadlen = payloadlen & 0x7f
        if self._frame_opcode_is_control and payloadlen >= 126:
            # control frames must have payload < 126
            self._abort()
            return
        try:
            if payloadlen < 126:
                self._frame_length = payloadlen
                self._read_frame_body()
            elif payloadlen == 126:
                self.stream.read_bytes(2, self._on_frame_length_16)
            elif payloadlen == 127:
                self.stream.read_bytes(8, self._on_frame_length_64)
        except StreamClosedError:
            self._abort()

    def _on_frame_length_16(self, data):
        self._wire_bytes_in += len(data)
        self._frame_length = struct.unpack("!H", data)[0]
        try:
            self._read_frame_body()
        except StreamClosedError:
            self._abort()

    def _on_frame_length_64(self, data):
        self._wire_bytes_in += len(data)
        self._frame_length = struct.unpack("!Q", data)[0]
        try:
            self._read_frame_body()
        except StreamClosedError:
            self._abort()

    def _on_masking_key(self, data):
        self._wire_bytes_in += len(data)
        self._frame_mask = data
        try:
            self.stream.read_bytes(self._frame_length,
                                   self._on_masked_frame_data)
        except StreamClosedError:
            self._abort()

    def _on_masked_frame_data(self, data):
        # Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
        self._on_frame_data(_websocket_mask(self._frame_mask, data))

    def _on_frame_data(self, data):
        self._wire_bytes_in += len(data)
        if self._frame_opcode_is_control:
            # control frames may be interleaved with a series of fragmented
            # data frames, so control frames must not interact with
            # self._fragmented_*
            if not self._final_frame:
                # control frames must not be fragmented
                self._abort()
                return
            opcode = self._frame_opcode
        elif self._frame_opcode == 0:  # continuation frame
            if self._fragmented_message_buffer is None:
                # nothing to continue
                self._abort()
                return
            self._fragmented_message_buffer += data
            if self._final_frame:
                opcode = self._fragmented_message_opcode
                data = self._fragmented_message_buffer
                self._fragmented_message_buffer = None
        else:  # start of new data message
            if self._fragmented_message_buffer is not None:
                # can't start new message until the old one is finished
                self._abort()
                return
            if self._final_frame:
                opcode = self._frame_opcode
            else:
                self._fragmented_message_opcode = self._frame_opcode
                self._fragmented_message_buffer = data

        # ``opcode`` is only bound on the final-frame paths above, which is
        # exactly when it is used here.
        if self._final_frame:
            self._handle_message(opcode, data)

        if not self.client_terminated:
            self._receive_frame()

    def _handle_message(self, opcode, data):
        """Dispatch one complete message or control frame by opcode."""
        if self.client_terminated:
            return

        if self._frame_compressed:
            data = self._decompressor.decompress(data)

        if opcode == 0x1:
            # UTF-8 data
            self._message_bytes_in += len(data)
            try:
                decoded = data.decode("utf-8")
            except UnicodeDecodeError:
                self._abort()
                return
            self._run_callback(self.handler.on_message, decoded)
        elif opcode == 0x2:
            # Binary data
            self._message_bytes_in += len(data)
            self._run_callback(self.handler.on_message, data)
        elif opcode == 0x8:
            # Close
            self.client_terminated = True
            if len(data) >= 2:
                self.handler.close_code = struct.unpack('>H', data[:2])[0]
            if len(data) > 2:
                self.handler.close_reason = to_unicode(data[2:])
            # Echo the received close code, if any (RFC 6455 section 5.5.1).
            self.close(self.handler.close_code)
        elif opcode == 0x9:
            # Ping
            self._write_frame(True, 0xA, data)
        elif opcode == 0xA:
            # Pong
            self._run_callback(self.handler.on_pong, data)
        else:
            self._abort()

    def close(self, code=None, reason=None):
        """Closes the WebSocket connection."""
        if not self.server_terminated:
            if not self.stream.closed():
                if code is None and reason is not None:
                    code = 1000  # "normal closure" status code
                if code is None:
                    close_data = b''
                else:
                    close_data = struct.pack('>H', code)
                if reason is not None:
                    close_data += utf8(reason)
                self._write_frame(True, 0x8, close_data)
            self.server_terminated = True
        if self.client_terminated:
            if self._waiting is not None:
                self.stream.io_loop.remove_timeout(self._waiting)
                self._waiting = None
            self.stream.close()
        elif self._waiting is None:
            # Give the client a few seconds to complete a clean shutdown,
            # otherwise just close the connection.
            self._waiting = self.stream.io_loop.add_timeout(
                self.stream.io_loop.time() + 5, self._abort)
class WebSocketClientConnection(simple_httpclient._HTTPConnection):
    """WebSocket client connection.

    This class should not be instantiated directly; use the
    `websocket_connect` function instead.
    """
    def __init__(self, io_loop, request, on_message_callback=None,
                 compression_options=None):
        self.compression_options = compression_options
        self.connect_future = TracebackFuture()
        self.protocol = None
        self.read_future = None
        self.read_queue = collections.deque()
        self.key = base64.b64encode(os.urandom(16))
        self._on_message_callback = on_message_callback
        self.close_code = self.close_reason = None

        # Rewrite ws:// and wss:// to the matching HTTP scheme for the
        # underlying HTTP connection; any other scheme raises KeyError.
        scheme, sep, rest = request.url.partition(':')
        scheme = {'ws': 'http', 'wss': 'https'}[scheme]
        request.url = scheme + sep + rest
        request.headers.update({
            'Upgrade': 'websocket',
            'Connection': 'Upgrade',
            'Sec-WebSocket-Key': self.key,
            'Sec-WebSocket-Version': '13',
        })
        if self.compression_options is not None:
            # Always offer to let the server set our max_wbits (and even though
            # we don't offer it, we will accept a client_no_context_takeover
            # from the server).
            # TODO: set server parameters for deflate extension
            # if requested in self.compression_options.
            request.headers['Sec-WebSocket-Extensions'] = (
                'permessage-deflate; client_max_window_bits')

        self.tcp_client = TCPClient(io_loop=io_loop)
        # NOTE(review): the positional size arguments (104857600 = 100 MiB,
        # 65536 = 64 KiB) are presumably buffer/header/body limits -- confirm
        # against the simple_httpclient._HTTPConnection signature.
        super(WebSocketClientConnection, self).__init__(
            io_loop, None, request, lambda: None, self._on_http_response,
            104857600, self.tcp_client, 65536, 104857600)

    def close(self, code=None, reason=None):
        """Closes the websocket connection.

        ``code`` and ``reason`` are documented under
        `WebSocketHandler.close`.

        .. versionadded:: 3.2

        .. versionchanged:: 4.0

           Added the ``code`` and ``reason`` arguments.
        """
        if self.protocol is not None:
            self.protocol.close(code, reason)
            self.protocol = None

    def on_connection_close(self):
        """Fail a pending connect, deliver a ``None`` message and clean up."""
        if not self.connect_future.done():
            self.connect_future.set_exception(StreamClosedError())
        # A message of None tells the reader the connection is closed.
        self.on_message(None)
        self.tcp_client.close()
        super(WebSocketClientConnection, self).on_connection_close()

    def _on_http_response(self, response):
        """Final HTTP callback: reaching it means the upgrade did not happen,
        so a still-pending connect future is failed.
        """
        if not self.connect_future.done():
            if response.error:
                self.connect_future.set_exception(response.error)
            else:
                self.connect_future.set_exception(WebSocketError(
                    "Non-websocket response"))

    def headers_received(self, start_line, headers):
        """Take over the connection on a 101 response; otherwise defer to
        the regular HTTP handling of the base class.
        """
        if start_line.code != 101:
            return super(WebSocketClientConnection, self).headers_received(
                start_line, headers)

        self.headers = headers
        self.protocol = self.get_websocket_protocol()
        self.protocol._process_server_headers(self.key, self.headers)
        self.protocol._receive_frame()

        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

        self.stream = self.connection.detach()
        self.stream.set_close_callback(self.on_connection_close)
        # Once we've taken over the connection, clear the final callback
        # we set on the http request.  This deactivates the error handling
        # in simple_httpclient that would otherwise interfere with our
        # ability to see exceptions.
        self.final_callback = None

        self.connect_future.set_result(self)

    def write_message(self, message, binary=False):
        """Sends a message to the WebSocket server.

        Raises `WebSocketClosedError` if the connection has no active
        protocol (not yet established, or already closed).
        """
        if self.protocol is None:
            raise WebSocketClosedError()
        return self.protocol.write_message(message, binary)

    def read_message(self, callback=None):
        """Reads a message from the WebSocket server.

        If on_message_callback was specified at WebSocket
        initialization, this function will never return messages.

        Returns a future whose result is the message, or None
        if the connection is closed.  If a callback argument
        is given it will be called with the future when it is
        ready.

        Only one read may be outstanding at a time.
        """
        assert self.read_future is None
        future = TracebackFuture()
        if self.read_queue:
            future.set_result(self.read_queue.popleft())
        else:
            self.read_future = future
        if callback is not None:
            self.io_loop.add_future(future, callback)
        return future

    def on_message(self, message):
        """Route a message to the callback, a pending read, or the queue."""
        if self._on_message_callback:
            self._on_message_callback(message)
        elif self.read_future is not None:
            self.read_future.set_result(message)
            self.read_future = None
        else:
            self.read_queue.append(message)

    def on_pong(self, data):
        pass

    def get_websocket_protocol(self):
        """Return the client-side protocol object (outgoing frames masked)."""
        return WebSocketProtocol13(self, mask_outgoing=True,
                                   compression_options=self.compression_options)
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
                      on_message_callback=None, compression_options=None):
    """Client-side websocket support.

    Takes a url and returns a Future whose result is a
    `WebSocketClientConnection`.

    ``compression_options`` is interpreted in the same way as the
    return value of `.WebSocketHandler.get_compression_options`.

    The connection supports two styles of operation. In the coroutine
    style, the application typically calls
    `~.WebSocketClientConnection.read_message` in a loop::

        conn = yield websocket_connect(url)
        while True:
            msg = yield conn.read_message()
            if msg is None: break
            # Do something with msg

    In the callback style, pass an ``on_message_callback`` to
    ``websocket_connect``. In both styles, a message of ``None``
    indicates that the connection has been closed.

    .. versionchanged:: 3.2
       Also accepts ``HTTPRequest`` objects in place of urls.

    .. versionchanged:: 4.1
       Added ``compression_options`` and ``on_message_callback``.
    """
    # NOTE(review): the previous docstring also flagged an argument as
    # deprecated under 4.1 (it named ``compression_options``); confirm which
    # argument is meant against the upstream Tornado documentation.
    loop = IOLoop.current() if io_loop is None else io_loop

    if not isinstance(url, httpclient.HTTPRequest):
        request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
    else:
        # A timeout must be set on the request object itself in this case.
        assert connect_timeout is None
        request = url
        # Copy and convert the headers dict/object (see comments in
        # AsyncHTTPClient.fetch)
        request.headers = httputil.HTTPHeaders(request.headers)
    proxied_request = httpclient._RequestProxy(
        request, httpclient.HTTPRequest._DEFAULTS)

    connection = WebSocketClientConnection(
        loop, proxied_request,
        on_message_callback=on_message_callback,
        compression_options=compression_options)
    connect_future = connection.connect_future
    if callback is not None:
        loop.add_future(connect_future, callback)
    return connect_future