【Python】基於asyncio的聊天伺服器
阿新 • • 發佈:2019-01-22
機器環境
- win10
- python3.6
關鍵點
- 自定義訊息協議解決粘包問題:訊息頭+訊息體(struct 格式為 `!I` + `ns`,即一個 4 位元組的無符號整數後接 n 位元組的訊息體),訊息頭定長4個位元組,以大端(網路位元組序)存放訊息體的長度
- 定義session類,管理連線進來的socket連線
- 定義handler類,管理訊息觸發對應的處理方法
Server端程式碼
- session.py
__all__ = ['Session']


class Session(object):
    """In-memory registry mapping a client identifier to its transport.

    The mapping lives in ``self.clients``; a client counts as "online" for
    as long as it has an entry there.
    """

    def __init__(self):
        # client identifier -> transport registered for that client
        self.clients = {}

    def get(self, client):
        """Return the transport registered for *client*, or ``None``.

        :param client: identifier used as the key at registration time
        :return: the registered transport, or ``None`` if *client* is unknown
        """
        return self.clients.get(client)

    def __contains__(self, client):
        """Return ``True`` if *client* is currently registered."""
        return client in self.clients

    def __repr__(self):
        """Return the textual form of the underlying client mapping."""
        return "{}".format(self.clients)

    __str__ = __repr__

    def register(self, client, transport):
        """Register *client* with *transport*, replacing any previous entry."""
        self.clients[client] = transport

    def unregister(self, client):
        """Remove *client* from the session; unknown clients are ignored."""
        self.clients.pop(client, None)


if __name__ == '__main__':
    Session()
- handlers.py
__all__ = ['MessageHandler']
import asyncio
import json
from struct import pack
from session import Session
class MetaHandler(type):
    """Metaclass that records handler classes in a shared registry.

    The first class created with this metaclass (the root of a handler
    hierarchy) receives an empty ``_msg_handlers`` dict.  Every class
    created afterwards that exposes a ``__msgtype__`` attribute is stored
    in that same dict under its ``__msgtype__`` value.
    """

    def __init__(cls, name, bases, _dict):
        super().__init__(name, bases, _dict)
        if not hasattr(cls, '_msg_handlers'):
            # Root of the hierarchy: create the registry; the root itself
            # is not registered.
            cls._msg_handlers = {}
            return
        msgtype = getattr(cls, '__msgtype__', None)
        if msgtype is not None:
            cls._msg_handlers[msgtype] = cls
        # A subclass without ``__msgtype__`` is simply not registered.  It
        # keeps sharing the inherited registry instead of receiving a fresh
        # empty dict that would hide the handlers registered so far.
class MessageHandler(metaclass=MetaHandler):
    """Base class of all message handlers; also acts as the dispatcher.

    Subclasses declare a ``__msgtype__`` and are collected by the
    metaclass into ``_msg_handlers``; ``handle`` on this base class looks
    the handler up by ``msg['type']`` and delegates to it.
    """

    # One Session shared by every handler class (class attribute).
    _session = Session()

    def handle(self, msg, transport):
        """Dispatch *msg* to the handler registered for ``msg['type']``.

        :param msg: decoded message, indexed by ``'type'``
        :param transport: transport of the connection that sent *msg*
        :return: the scheduled future when the handler returns a coroutine,
            otherwise whatever the handler returned
        """
        try:
            handler_cls = self._msg_handlers[msg['type']]
        except KeyError:
            # Raised both when 'type' is missing from msg and when no
            # handler is registered for it.
            return ErrorHandler().handle(msg)
        result = handler_cls().handle(msg, transport)
        if asyncio.iscoroutine(result):
            # Run coroutine handlers as tasks on the event loop.
            # ensure_future() replaces the deprecated asyncio.async(), whose
            # name is a syntax error since ``async`` became a keyword.
            return asyncio.ensure_future(result)
        return result
class ErrorHandler(MessageHandler):
    """Handler that reports a message whose type is not recognised."""

    __msgtype__ = 'unknown'

    def handle(self, msg, transport=None):
        """Print *msg* as an unknown message.

        :param msg: the message that could not be dispatched
        :param transport: accepted and ignored, so the method can be called
            with the same ``(msg, transport)`` arguments as the other
            handlers as well as with ``msg`` alone
        :return: None
        """
        print("Unknown message type: {}".format(msg))
class Register(MessageHandler):
    """
    Registry handler for handling clients registry.
    Message body should like this:
    {'type': 'register', 'uid': 'unique-user-id'}
    """

    __msgtype__ = 'register'

    def __init__(self):
        # Filled in by handle() with the uid and transport being registered.
        self.current_uid = None
        self.transport = None

    # Native coroutine: the ``@asyncio.coroutine`` decorator is deprecated
    # since Python 3.8 and was later removed.
    async def handle(self, msg, transport):
        """Record ``msg['uid']`` -> *transport* in the shared session.

        :param msg: register message; must contain the key ``'uid'``
        :param transport: transport of the connection that sent *msg*
        :raises KeyError: if *msg* has no ``'uid'`` key
        :return: None
        """
        self.current_uid = msg['uid']
        self.transport = transport
        print("registe uid: {}".format(self.current_uid))
        # Register user in global session
        self._session.register(self.current_uid, self.transport)
class SendTextMsg(MessageHandler):
    """
    Send message to others.
    Message body should like this:
    {'type': 'text', 'sender': 'Jack', 'receiver': 'Rose', 'content': 'I love you forever'}
    """

    __msgtype__ = 'text'  # Text message

    # Native coroutine: the ``@asyncio.coroutine`` decorator is deprecated
    # since Python 3.8 and was later removed.
    async def handle(self, msg, _):
        """Forward *msg* to its receiver if the receiver is registered.

        The message is re-serialised as JSON and written to the receiver's
        transport as a length-prefixed frame.  If the receiver has no
        transport in the session, the message is dropped; nothing is
        persisted by this handler.

        :param msg: text message; must contain the key ``'receiver'``
        :param _: transport of the sender (unused)
        :raises KeyError: if *msg* has no ``'receiver'`` key
        :return: None
        """
        print("send data...{}".format(msg))
        transport = self._session.get(msg['receiver'])
        if not transport:
            return
        # Encode first so the prefix always counts bytes, not characters.
        payload = json.dumps(msg).encode('utf-8')
        # Length-prefixed frame: 4-byte unsigned big-endian length + body.
        transport.write(pack("!I", len(payload)) + payload)
class Unregister(MessageHandler):
    """
    Unregister user from global session
    Message body should like this:
    {'type': 'unregister', 'uid': 'unique-user-id'}
    """

    __msgtype__ = 'unregister'

    # Native coroutine: the ``@asyncio.coroutine`` decorator is deprecated
    # since Python 3.8 and was later removed.
    async def handle(self, msg, _):
        """Remove ``msg['uid']`` from the shared session.

        :param msg: unregister message; must contain the key ``'uid'``
        :param _: transport of the sender (unused)
        :raises KeyError: if *msg* has no ``'uid'`` key
        :return: None
        """
        self._session.unregister(msg['uid'])
- server.py
import asyncio
import json
from handlers import MessageHandler
# Length of the message header in bytes; it is decoded as an unsigned int.
_MESSAGE_PREFIX_LENGTH = 4
# Byte order used to decode the header.
_BYTE_ORDER = 'big'


class myImProtocol(asyncio.Protocol):
    """Protocol that splits the byte stream into length-prefixed messages.

    Each message is a ``_MESSAGE_PREFIX_LENGTH``-byte header holding the
    body length, followed by the body.  Every complete body is passed to
    ``message_received``, which subclasses must override.
    """

    _buffer = b''    # Bytes received but not yet consumed
    _msg_len = None  # Body length of the current message; None until the header is parsed

    def data_received(self, data):
        """Feed *data* to ``process_data`` until nothing is left over."""
        while data:
            data = self.process_data(data)

    def process_data(self, data):
        """Append *data* to the buffer and dispatch at most one message.

        :param data: bytes object with newly received data
        :return: the bytes left after the dispatched message (to be fed
            back in by the caller), or ``None`` when more data is needed
            to complete the header or the body
        """
        self._buffer += data
        if self._msg_len is None:
            # Wait until the whole header has arrived.
            if len(self._buffer) < _MESSAGE_PREFIX_LENGTH:
                return None
            self._msg_len = int.from_bytes(
                self._buffer[:_MESSAGE_PREFIX_LENGTH], byteorder=_BYTE_ORDER)
            # Whatever follows the header belongs to the message body.
            self._buffer = self._buffer[_MESSAGE_PREFIX_LENGTH:]
        if len(self._buffer) < self._msg_len:
            # Body incomplete: keep buffering.
            return None
        body = self._buffer[:self._msg_len]
        rest = self._buffer[self._msg_len:]
        # Consume the frame *before* dispatching.  If message_received
        # raises, the offending frame is already gone and the remainder
        # stays buffered, so the next read continues with the next frame
        # instead of re-dispatching the same one.
        self._buffer = rest
        self._msg_len = None
        self.message_received(body)
        # Hand the remainder back to the caller, which feeds it in again.
        self._buffer = b''
        return rest

    def message_received(self, msg):
        """
        Must override in subclass
        :param msg: the full message
        :return: None
        """
        raise NotImplementedError()
class myIm(myImProtocol):
    """Chat protocol: decodes each framed message as JSON and dispatches it."""

    def __init__(self):
        self.handler = MessageHandler()
        # Set in connection_made(); passed along with every message.
        self.transport = None

    def connection_made(self, transport):
        """Remember the transport of the new connection."""
        self.transport = transport

    def message_received(self, msg):
        """
        The real message handler
        :param msg: a full message without prefix length
        :return: whatever the dispatcher returns, or None when the message
            could not be decoded
        """
        # Convert bytes msg to python dictionary
        try:
            msg = json.loads(msg.decode("utf-8"))
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError are both
            # ValueError subclasses.  Drop the undecodable message instead
            # of letting the exception escape into the read path.
            print("discard malformed msg...{!r}".format(msg))
            return None
        print("receive msg...{}".format(msg))
        # Handler msg
        return self.handler.handle(msg, self.transport)
class myImServer(object):
    """Runs a protocol factory as a listening server on the event loop."""

    def __init__(self, protocol_factory, host, port):
        """
        :param protocol_factory: callable returning a protocol instance,
            passed to ``loop.create_server``
        :param host: host to listen on
        :param port: port to listen on
        """
        self.host = host
        self.port = port
        self.protocol_factory = protocol_factory

    def start(self):
        """Start listening and run the event loop until it is stopped.

        The listening server is closed when ``run_forever`` returns or is
        interrupted by an exception.
        """
        loop = asyncio.get_event_loop()
        server = loop.run_until_complete(
            loop.create_server(self.protocol_factory, self.host, self.port))
        try:
            loop.run_forever()
        finally:
            # Release the listening socket instead of leaking it.
            server.close()
            loop.run_until_complete(server.wait_closed())
if __name__ == '__main__':
    # Script entry point: serve the myIm protocol on localhost:2222.
    myImServer(myIm, 'localhost', 2222).start()
客戶端程式碼
- client.py
import asyncio
import socket
import json
import random
from struct import pack
# Sample text message; sender and receiver are the same uid, so the server
# delivers it back to this client once the uid is registered.
_MESSAGE_TEXT = {
    'type': 'text',
    'sender': 'niuminguo',
    'receiver': 'niuminguo',
    'content': 'love',
}
# Sample registration message, sent first.
_MESSAGE_REG = {
    'uid': 'niuminguo',
    'type': 'register',
}


class EchoClientProtocol(asyncio.Protocol):
    """Test client: registers, sends one text message, prints what it gets."""

    def __init__(self, loop):
        """
        :param loop: event loop to stop when the connection is lost
        """
        self.loop = loop

    @staticmethod
    def _frame(message):
        """Serialise *message* as JSON and prepend its 4-byte length.

        The length is an unsigned big-endian integer and counts the
        encoded bytes, not the characters of the JSON text.
        """
        payload = json.dumps(message).encode('utf-8')
        return pack("!I", len(payload)) + payload

    def connection_made(self, transport):
        """Send the register message followed by the text message."""
        print("send data......")
        for message in (_MESSAGE_REG, _MESSAGE_TEXT):
            transport.write(self._frame(message))

    def data_received(self, data):
        """Print the received bytes as text.

        *data* as written by this file's server starts with a binary length
        prefix, which need not be valid UTF-8; undecodable bytes are
        replaced rather than raising ``UnicodeDecodeError``.
        """
        print('Data received: {!r}'.format(data.decode('utf-8', errors='replace')))

    def connection_lost(self, exc):
        """Stop the event loop once the connection is gone."""
        print('The server closed the connection')
        print('Stop the event loop')
        self.loop.stop()
# Client script: runs at import time (there is no __main__ guard).
loop = asyncio.get_event_loop()
# Coroutine that connects to 127.0.0.1:2222 and builds one
# EchoClientProtocol bound to this loop.
coro = loop.create_connection(lambda: EchoClientProtocol(loop),
'127.0.0.1', 2222)
# Establish the connection first ...
loop.run_until_complete(coro)
# ... then keep the loop running until EchoClientProtocol.connection_lost
# calls loop.stop().
loop.run_forever()
loop.close()
執行測試
- python3 server.py
- python3 client.py