Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions mq_http_sdk/mq_async_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# coding=utf-8
import asyncio
import ssl
from logging import Logger
from typing import Optional, Union

import certifi

import aiohttp
from aiohttp.http_exceptions import BadStatusLine

from .mq_exception import *

DEFAULT_CONNECTION_TIMEOUT = 60
DEFAULT_READ_TIMEOUT = 100


class RequestInternal:
def __init__(self, method="", uri="", header=None, data=""):
if header == None:
header = {}
self.method = method
self.uri = uri
self.header = header
self.data = data

def __str__(self):
return "Method: %s\nUri: %s\nHeader: %s\nData: %s\n" % \
(self.method, self.uri, "\n".join(["%s: %s" % (k, v) for k, v in list(self.header.items())]), self.data)


class ResponseInternal:
def __init__(self, status=0, header=None, data=""):
if header == None:
header = {}
self.status = status
self.header = header
self.data = data

def __str__(self):
return "Status: %s\nHeader: %s\nData: %s\n" % \
(self.status, "\n".join(["%s: %s" % (k, v) for k, v in list(self.header.items())]), self.data)


class MQHTTPAsyncConnection:
def __init__(self, host: str):
self.host = host
self.session = None

async def async_request(self, req_inter, timeout) -> ResponseInternal:
if self.session and not self.session.closed:
return await self._async_do_action(self.session, req_inter, timeout)
async with aiohttp.ClientSession(base_url=f"http://{self.host}", raise_for_status=False) as session:
return await self._async_do_action(session, req_inter, timeout)

@staticmethod
async def _async_do_action(session, req_inter, timeout) -> ResponseInternal:
async with session.request(method=req_inter.method,
url=req_inter.uri,
data=req_inter.data,
headers=req_inter.header, ssl=False,
timeout=timeout) as http_resp:
return ResponseInternal(status=http_resp.status, header=http_resp.headers, data=await http_resp.text())

async def renew_session(self):
if self.session and not self.session.closed:
await self.session.close()
self.session = aiohttp.ClientSession(base_url=f"http://{self.host}", raise_for_status=False)

async def close(self):
if self.session and not self.session.closed:
await self.session.close()


class MQHTTPSAsyncConnection:
def __init__(self, host: str, ca_cert: str):
self.host = host
self.ca_cert = ca_cert
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_verify_locations(ca_cert)
self.connector = aiohttp.TCPConnector(ssl=ssl_context)
self.session = None

async def async_request(self, req_inter, timeout) -> ResponseInternal:
if self.session and not self.session.closed:
return await self._async_do_action(self.session, req_inter, timeout)
async with aiohttp.ClientSession(base_url=f"https://{self.host}", raise_for_status=False,
connector=self.connector) as session:
return await self._async_do_action(session, req_inter, timeout)

@staticmethod
async def _async_do_action(session, req_inter, timeout) -> ResponseInternal:
async with session.request(method=req_inter.method,
url=req_inter.uri,
data=req_inter.data,
headers=req_inter.header, ssl=True,
timeout=timeout) as http_resp:
return ResponseInternal(status=http_resp.status, header=http_resp.headers, data=await http_resp.text())

async def renew_session(self):
if self.session and not self.session.closed:
await self.session.close()
self.session = aiohttp.ClientSession(base_url=f"https://{self.host}", raise_for_status=False, connector=self.connector)

async def close(self):
if self.session and not self.session.closed:
await self.session.close()


class MQAsyncHttp:
def __init__(self, host: str, connection_timeout: Optional[int] = DEFAULT_CONNECTION_TIMEOUT, logger: Optional[Logger] = None, is_https: Optional[bool] = False, read_timeout: Optional[int] = DEFAULT_READ_TIMEOUT):
ca_cert = certifi.where()
self.connector = None
self.is_https = ca_cert and is_https
self.timeout = aiohttp.ClientTimeout(sock_read=read_timeout, sock_connect=connection_timeout)
self.host = host
self.is_https = is_https
self.connection_timeout = connection_timeout
self.read_timeout = read_timeout
self.logger = logger
if self.is_https:
self.conn = MQHTTPSAsyncConnection(host, ca_cert=ca_cert)
else:
self.conn = MQHTTPAsyncConnection(host)
if self.logger:
self.logger.info("InitOnsAHttp ConnectionTime:%s" % self.connection_timeout)

def set_log_level(self, log_level: Union[str, int]):
if self.logger:
self.logger.setLevel(log_level)

def close_log(self):
self.logger = None

def set_connection_timeout(self, connection_timeout: int):
self.connection_timeout = connection_timeout
self.timeout = aiohttp.ClientTimeout(sock_read=self.read_timeout, sock_connect=connection_timeout)

async def send_request(self, req_inter: RequestInternal) -> ResponseInternal:
try:
if self.logger:
self.logger.debug("SendRequest %s" % req_inter)
try:
resp_inter = await self.conn.async_request(req_inter=req_inter, timeout=self.timeout)
except BadStatusLine:
await self.conn.renew_session()
resp_inter = await self.conn.async_request(req_inter=req_inter, timeout=self.timeout)
if self.logger:
self.logger.debug("GetResponse %s" % resp_inter)
return resp_inter
except Exception as e:
raise MQClientNetworkException("NetWorkException", str(e)) # raise netException
104 changes: 104 additions & 0 deletions mq_http_sdk/mq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hmac
import platform
from . import pkg_info
from .mq_async_http import MQAsyncHttp
from .mq_xml_handler import *
from .mq_tool import *
from .mq_http import *
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(self, host, access_id, access_key, security_token="", debug=False,
self.logger = logger
self.debug = debug
self.http = MQHttp(self.host, logger=logger, is_https=self.is_https)
self.async_http = MQAsyncHttp(self.host, logger=logger, is_https=self.is_https)
if self.logger:
self.logger.info("InitClient Host:%s Version:%s" % (host, self.version))

Expand Down Expand Up @@ -99,20 +101,32 @@ def set_log_level(self, log_level):
MQLogger.validate_loglevel(log_level)
self.logger.setLevel(log_level)
self.http.set_log_level(log_level)
self.async_http.set_log_level(log_level)

def close_log(self):
self.logger = None
self.http.close_log()
self.async_http.close_log()

def set_connection_timeout(self, connection_timeout):
self.http.set_connection_timeout(connection_timeout)
self.async_http.set_connection_timeout(connection_timeout)

def set_keep_alive(self, keep_alive):
self.http.set_keep_alive(keep_alive)

def close_connection(self):
self.http.conn.close()

async def async_close_connection(self):
await self.async_http.conn.close()

async def start_async_session(self):
await self.async_http.conn.renew_session()

async def close_async_session(self):
await self.async_http.conn.close()

def consume_message(self, req, resp):
# check parameter
ConsumeMessageValidator.validate(req)
Expand Down Expand Up @@ -149,6 +163,42 @@ def consume_message(self, req, resp):
(msg.message_id, msg.message_body_md5, msg.next_consume_time, msg.receipt_handle,
msg.publish_time, msg.consumed_times) for msg in resp.message_list])))

async def async_consume_message(self, req, resp):
# check parameter
ConsumeMessageValidator.validate(req)

# make request internal
req_url = "/%s/%s/%s?consumer=%s&numOfMessages=%s" % (URI_SEC_TOPIC, req.topic_name, URI_SEC_MESSAGE, req.consumer, req.batch_size)
if req.instance_id != "":
req_url += "&ns=%s" % req.instance_id
if req.wait_seconds != -1:
req_url += "&waitseconds=%s" % req.wait_seconds
if req.message_tag != "":
req_url += "&tag=%s" % req.message_tag
if req.trans != "":
req_url += "&trans=%s" % req.trans

req_inter = RequestInternal(req.method, req_url)
self.build_header(req, req_inter)

# send request
resp_inter = await self.async_http.send_request(req_inter)

# handle result, make response
resp.status = resp_inter.status
resp.header = resp_inter.header
self.check_status(resp_inter, resp)
if resp.error_data == "":
resp.message_list = ConsumeMessageDecoder.decode(resp_inter.data, resp.get_req_id())
if self.logger:
self.logger.info("ConsumeMessage RequestId:%s TopicName:%s WaitSeconds:%s BatchSize:%s Tag:%s MessageCount:%s \
MessagesInfo\n%s" % (
resp.get_req_id(), req.topic_name, req.wait_seconds, req.batch_size, req.message_tag, len(resp.message_list), \
"\n".join([
"MessageId:%s MessageBodyMD5:%s NextConsumeTime:%s ReceiptHandle:%s PublishTime:%s ConsumedTimes:%s" % \
(msg.message_id, msg.message_body_md5, msg.next_consume_time, msg.receipt_handle,
msg.publish_time, msg.consumed_times) for msg in resp.message_list])))

def ack_message(self, req, resp):
# check parameter
AckMessageValidator.validate(req)
Expand All @@ -175,6 +225,32 @@ def ack_message(self, req, resp):
self.logger.info("AckMessage RequestId:%s TopicName:%s ReceiptHandles\n%s" % \
(resp.get_req_id(), req.topic_name, "\n".join(req.receipt_handle_list)))

async def async_ack_message(self, req, resp):
# check parameter
AckMessageValidator.validate(req)

# make request internal
req_url = "/%s/%s/%s?consumer=%s" % (URI_SEC_TOPIC, req.topic_name, URI_SEC_MESSAGE, req.consumer)
if req.instance_id != "":
req_url += "&ns=%s" % req.instance_id
if req.trans != "":
req_url += "&trans=%s" % req.trans

req_inter = RequestInternal(req.method, req_url)
req_inter.data = ReceiptHandlesEncoder.encode(req.receipt_handle_list)
self.build_header(req, req_inter)

# send request
resp_inter = await self.async_http.send_request(req_inter)

# handle result, make response
resp.status = resp_inter.status
resp.header = resp_inter.header
self.check_status(resp_inter, resp, AckMessageDecoder)
if self.logger:
self.logger.info("AckMessage RequestId:%s TopicName:%s ReceiptHandles\n%s" % \
(resp.get_req_id(), req.topic_name, "\n".join(req.receipt_handle_list)))

def publish_message(self, req, resp):
# check parameter
PublishMessageValidator.validate(req)
Expand Down Expand Up @@ -202,11 +278,39 @@ def publish_message(self, req, resp):
self.logger.info("PublishMessage RequestId:%s TopicName:%s MessageId:%s MessageBodyMD5:%s" % \
(resp.get_req_id(), req.topic_name, resp.message_id, resp.message_body_md5))

async def async_publish_message(self, req, resp):
# check parameter
PublishMessageValidator.validate(req)

# make request internal
req_url = "/%s/%s/%s" % (URI_SEC_TOPIC, req.topic_name, URI_SEC_MESSAGE)
if req.instance_id != "":
req_url += "?ns=%s" % req.instance_id

req_inter = RequestInternal(req.method, req_url)
req_inter.data = TopicMessageEncoder.encode(req)
self.build_header(req, req_inter)

# send request
resp_inter = await self.async_http.send_request(req_inter)

# handle result, make response
resp.status = resp_inter.status
resp.header = resp_inter.header
self.check_status(resp_inter, resp)
if resp.error_data == "":
resp.message_id, resp.message_body_md5, resp.receipt_handle = PublishMessageDecoder.decode(resp_inter.data,
resp.get_req_id())
if self.logger:
self.logger.info("PublishMessage RequestId:%s TopicName:%s MessageId:%s MessageBodyMD5:%s" % \
(resp.get_req_id(), req.topic_name, resp.message_id, resp.message_body_md5))

###################################################################################################
# ----------------------internal-------------------------------------------------------------------#
def build_header(self, req, req_inter):
if self.http.is_keep_alive():
req_inter.header["Connection"] = "Keep-Alive"
req_inter.header["content-type"] = ""
if req_inter.data != "":
req_inter.header["content-type"] = "text/xml;charset=UTF-8"
req_inter.header["x-mq-version"] = self.version
Expand Down
Loading