From 60115cf878148ac6671163789d9968dc271de45e Mon Sep 17 00:00:00 2001 From: Yan Chao Mei <1653720237@qq.com> Date: Tue, 25 Jul 2023 10:02:28 +0800 Subject: [PATCH] General message engineering in Producer (#559) * General message engineering in Producer * remove user * finish retry & isolation * add comments&exception handler * add license * use snake case naming * fix name & finish telemetry rebuild * fix style issues * finish retry&isolation test * init delay&fifo message * finish delay & fifo & transaction message & its tests --- python/rocketmq/client.py | 306 ++++++++++++-- python/rocketmq/client_config.py | 30 +- python/rocketmq/client_id_encoder.py | 19 +- python/rocketmq/definition.py | 104 +++++ .../exponential_backoff_retry_policy.py | 100 +++++ python/rocketmq/producer.py | 387 +++++++++++++++++- python/rocketmq/publish_settings.py | 7 +- python/rocketmq/publishing_message.py | 86 ++++ python/rocketmq/rpc_client.py | 2 +- python/rocketmq/send_receipt.py | 27 +- python/rocketmq/session.py | 40 +- python/rocketmq/status_checker.py | 212 ++++++++++ python/rocketmq/utils.py | 5 + python/tests/test_foo.py | 3 +- 14 files changed, 1270 insertions(+), 58 deletions(-) create mode 100644 python/rocketmq/exponential_backoff_retry_policy.py create mode 100644 python/rocketmq/publishing_message.py create mode 100644 python/rocketmq/status_checker.py diff --git a/python/rocketmq/client.py b/python/rocketmq/client.py index 0ef32e60d..509d991d6 100644 --- a/python/rocketmq/client.py +++ b/python/rocketmq/client.py @@ -13,84 +13,248 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import threading from typing import Set -from protocol import service_pb2 -from protocol.service_pb2 import QueryRouteRequest +from protocol import definition_pb2, service_pb2 +from protocol.definition_pb2 import Code as ProtoCode +from protocol.service_pb2 import HeartbeatRequest, QueryRouteRequest from rocketmq.client_config import ClientConfig from rocketmq.client_id_encoder import ClientIdEncoder -from rocketmq.definition import TopicRouteData +from rocketmq.definition import Resource, TopicRouteData +from rocketmq.log import logger from rocketmq.rpc_client import Endpoints, RpcClient from rocketmq.session import Session from rocketmq.signature import Signature +class ScheduleWithFixedDelay: + def __init__(self, action, delay, period): + self.action = action + self.delay = delay + self.period = period + self.task = None + + async def start(self): + await asyncio.sleep(self.delay) + while True: + try: + await self.action() + except Exception as e: + logger.error(e, "Failed to execute scheduled task") + finally: + await asyncio.sleep(self.period) + + def schedule(self): + loop1 = asyncio.new_event_loop() + asyncio.set_event_loop(loop1) + self.task = asyncio.create_task(self.start()) + + def cancel(self): + if self.task: + self.task.cancel() + + class Client: + """ + Main client class which handles interaction with the server. + """ def __init__(self, client_config: ClientConfig, topics: Set[str]): + """ + Initialization method for the Client class. + + :param client_config: Client configuration. + :param topics: Set of topics that the client is subscribed to. + """ self.client_config = client_config self.client_id = ClientIdEncoder.generate() self.endpoints = client_config.endpoints self.topics = topics + #: A cache to store topic routes. self.topic_route_cache = {} + #: A table to store session information. self.sessions_table = {} self.sessionsLock = threading.Lock() self.client_manager = ClientManager(self) + #: A dictionary to store isolated items. + self.isolated = dict() + async def start(self): + """ + Start method which initiates fetching of topic routes and schedules heartbeats. + """ # get topic route + logger.debug(f"Begin to start the rocketmq client, client_id={self.client_id}") for topic in self.topics: self.topic_route_cache[topic] = await self.fetch_topic_route(topic) + scheduler = ScheduleWithFixedDelay(self.heartbeat, 3, 12) + scheduler_sync_settings = ScheduleWithFixedDelay(self.sync_settings, 3, 12) + scheduler.schedule() + scheduler_sync_settings.schedule() + logger.debug(f"Start the rocketmq client successfully, client_id={self.client_id}") + + async def shutdown(self): + logger.debug(f"Begin to shutdown rocketmq client, client_id={self.client_id}") - def GetTotalRouteEndpoints(self): + logger.debug(f"Shutdown the rocketmq client successfully, client_id={self.client_id}") + + async def heartbeat(self): + """ + Asynchronous method that sends a heartbeat to the server. + """ + try: + endpoints = self.get_total_route_endpoints() + request = HeartbeatRequest() + request.client_type = definition_pb2.PRODUCER + topic = Resource() + topic.name = "normal_topic" + # Collect task into a map. + for item in endpoints: + try: + + task = await self.client_manager.heartbeat(item, request, self.client_config.request_timeout) + code = task.status.code + if code == ProtoCode.OK: + logger.info(f"Send heartbeat successfully, endpoints={item}, client_id={self.client_id}") + + if item in self.isolated: + self.isolated.pop(item) + logger.info(f"Rejoin endpoints which was isolated before, endpoints={item}, " + + f"client_id={self.client_id}") + return + status_message = task.status.message + logger.info(f"Failed to send heartbeat, endpoints={item}, code={code}, " + + f"status_message={status_message}, client_id={self.client_id}") + except Exception: + logger.error(f"Failed to send heartbeat, endpoints={item}") + except Exception as e: + logger.error(f"[Bug] unexpected exception raised during heartbeat, client_id={self.client_id}, Exception: {str(e)}") + + def get_total_route_endpoints(self): + """ + Method that returns all route endpoints. + """ endpoints = set() for item in self.topic_route_cache.items(): for endpoint in [mq.broker.endpoints for mq in item[1].message_queues]: endpoints.add(endpoint) return endpoints + async def get_route_data(self, topic): + """ + Asynchronous method that fetches route data for a given topic. + + :param topic: The topic to fetch route data for. + """ + if topic in self.topic_route_cache: + return self.topic_route_cache[topic] + topic_route_data = await self.fetch_topic_route(topic=topic) + return topic_route_data + def get_client_config(self): + """ + Method to return client configuration. + """ return self.client_config - async def OnTopicRouteDataFetched(self, topic, topicRouteData): + async def sync_settings(self): + total_route_endpoints = self.get_total_route_endpoints() + + for endpoints in total_route_endpoints: + created, session = await self.get_session(endpoints) + await session.sync_settings(True) + logger.info(f"Sync settings to remote, endpoints={endpoints}") + + def stats(self): + # TODO: stats implement + pass + + async def notify_client_termination(self): + pass + + async def on_recover_orphaned_transaction_command(self, endpoints, command): + pass + + async def on_verify_message_command(self, endpoints, command): + logger.warn(f"Ignore verify message command from remote, which is not expected, clientId={self.client_id}, " + + f"endpoints={endpoints}, command={command}") + pass + + async def on_print_thread_stack_trace_command(self, endpoints, command): + pass + + async def on_settings_command(self, endpoints, settings): + pass + + async def on_topic_route_data_fetched(self, topic, topic_route_data): + """ + Asynchronous method that handles the process once the topic route data is fetched. + + :param topic: The topic for which the route data is fetched. + :param topic_route_data: The fetched topic route data. + """ route_endpoints = set() - for mq in topicRouteData.message_queues: + for mq in topic_route_data.message_queues: route_endpoints.add(mq.broker.endpoints) - existed_route_endpoints = self.GetTotalRouteEndpoints() + existed_route_endpoints = self.get_total_route_endpoints() new_endpoints = route_endpoints.difference(existed_route_endpoints) for endpoints in new_endpoints: - created, session = await self.GetSession(endpoints) + created, session = await self.get_session(endpoints) if not created: continue - + logger.info(f"Begin to establish session for endpoints={endpoints}, client_id={self.client_id}") await session.sync_settings(True) + logger.info(f"Establish session for endpoints={endpoints} successfully, client_id={self.client_id}") - self.topic_route_cache[topic] = topicRouteData - # self.OnTopicRouteDataUpdated0(topic, topicRouteData) + self.topic_route_cache[topic] = topic_route_data async def fetch_topic_route0(self, topic): - req = QueryRouteRequest() - req.topic.name = topic - address = req.endpoints.addresses.add() - address.host = self.endpoints.Addresses[0].host - address.port = self.endpoints.Addresses[0].port - req.endpoints.scheme = self.endpoints.scheme.to_protobuf(self.endpoints.scheme) - response = await self.client_manager.query_route(self.endpoints, req, 10) - - message_queues = response.message_queues - return TopicRouteData(message_queues) - - # return topic data + """ + Asynchronous method that fetches the topic route. + + :param topic: The topic to fetch the route for. + """ + try: + req = QueryRouteRequest() + req.topic.name = topic + address = req.endpoints.addresses.add() + address.host = self.endpoints.Addresses[0].host + address.port = self.endpoints.Addresses[0].port + req.endpoints.scheme = self.endpoints.scheme.to_protobuf(self.endpoints.scheme) + response = await self.client_manager.query_route(self.endpoints, req, 10) + code = response.status.code + if code != ProtoCode.OK: + logger.error(f"Failed to fetch topic route, client_id={self.client_id}, topic={topic}, code={code}, " + + f"statusMessage={response.status.message}") + message_queues = response.message_queues + return TopicRouteData(message_queues) + except Exception as e: + logger.error(e, f"Failed to fetch topic route, client_id={self.client_id}, topic={topic}") + raise + async def fetch_topic_route(self, topic): + """ + Asynchronous method that fetches the topic route and updates the data. + + :param topic: The topic to fetch the route for. + """ topic_route_data = await self.fetch_topic_route0(topic) - await self.OnTopicRouteDataFetched(topic, topic_route_data) + await self.on_topic_route_data_fetched(topic, topic_route_data) + logger.info(f"Fetch topic route successfully, client_id={self.client_id}, topic={topic}, topicRouteData={topic_route_data}") return topic_route_data - async def GetSession(self, endpoints): + async def get_session(self, endpoints): + """ + Asynchronous method that gets the session for a given endpoint. + + :param endpoints: The endpoints to get the session for. + """ self.sessionsLock.acquire() try: # Session exists, return in advance. @@ -105,21 +269,41 @@ async def GetSession(self, endpoints): if endpoints in self.sessions_table: return (False, self.sessions_table[endpoints]) - stream = self.client_manager.telemetry(endpoints, 10) + stream = self.client_manager.telemetry(endpoints, 10000000) created = Session(endpoints, stream, self) self.sessions_table[endpoints] = created return (True, created) finally: self.sessionsLock.release() + def get_client_id(self): + return self.client_id + class ClientManager: + """Manager class for RPC Clients in a thread-safe manner. + Each instance is created by a specific client and can manage + multiple RPC clients. + """ + def __init__(self, client: Client): + #: The client that instantiated this manager. self.__client = client + + #: A dictionary that maps endpoints to the corresponding RPC clients. self.__rpc_clients = {} + + #: A lock used to ensure thread safety when accessing __rpc_clients. self.__rpc_clients_lock = threading.Lock() def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool): + """Retrieve the RPC client corresponding to the given endpoints. + If not present, a new RPC client is created and stored in __rpc_clients. + + :param endpoints: The endpoints associated with the RPC client. + :param ssl_enabled: A flag indicating whether SSL is enabled. + :return: The RPC client associated with the given endpoints. + """ with self.__rpc_clients_lock: rpc_client = self.__rpc_clients.get(endpoints) if rpc_client: @@ -134,10 +318,16 @@ async def query_route( request: service_pb2.QueryRouteRequest, timeout_seconds: int, ): + """Query the routing information. + + :param endpoints: The endpoints to query. + :param request: The request containing the details of the query. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the query. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) - metadata = Signature.sign(self.__client.client_config, self.__client.client_id) return await rpc_client.query_route(request, metadata, timeout_seconds) @@ -147,6 +337,13 @@ async def heartbeat( request: service_pb2.HeartbeatRequest, timeout_seconds: int, ): + """Send a heartbeat to the server to indicate that the client is still alive. + + :param endpoints: The endpoints to send the heartbeat to. + :param request: The request containing the details of the heartbeat. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the heartbeat. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -159,6 +356,13 @@ async def send_message( request: service_pb2.SendMessageRequest, timeout_seconds: int, ): + """Send a message to the server. + + :param endpoints: The endpoints to send the message to. + :param request: The request containing the details of the message. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the message sending operation. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -171,6 +375,13 @@ async def query_assignment( request: service_pb2.QueryAssignmentRequest, timeout_seconds: int, ): + """Query the assignment information. + + :param endpoints: The endpoints to query. + :param request: The request containing the details of the query. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the query. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -183,6 +394,13 @@ async def ack_message( request: service_pb2.AckMessageRequest, timeout_seconds: int, ): + """Send an acknowledgment for a message to the server. + + :param endpoints: The endpoints to send the acknowledgment to. + :param request: The request containing the details of the acknowledgment. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the acknowledgment. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -195,6 +413,13 @@ async def forward_message_to_dead_letter_queue( request: service_pb2.ForwardMessageToDeadLetterQueueRequest, timeout_seconds: int, ): + """Forward a message to the dead letter queue. + + :param endpoints: The endpoints to send the request to. + :param request: The request containing the details of the message to forward. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the forward operation. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -209,6 +434,13 @@ async def end_transaction( request: service_pb2.EndTransactionRequest, timeout_seconds: int, ): + """Ends a transaction. + + :param endpoints: The endpoints to send the request to. + :param request: The request to end the transaction. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the end transaction operation. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -221,6 +453,13 @@ async def notify_client_termination( request: service_pb2.NotifyClientTerminationRequest, timeout_seconds: int, ): + """Notify server about client termination. + + :param endpoints: The endpoints to send the notification to. + :param request: The request containing the details of the termination. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the notification operation. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -235,6 +474,13 @@ async def change_invisible_duration( request: service_pb2.ChangeInvisibleDurationRequest, timeout_seconds: int, ): + """Change the invisible duration of a message. + + :param endpoints: The endpoints to send the request to. + :param request: The request containing the new invisible duration. + :param timeout_seconds: The maximum time to wait for a response. + :return: The result of the change operation. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) @@ -248,6 +494,12 @@ def telemetry( endpoints: Endpoints, timeout_seconds: int, ): + """Fetch telemetry information. + + :param endpoints: The endpoints to send the request to. + :param timeout_seconds: The maximum time to wait for a response. + :return: The telemetry information. + """ rpc_client = self.__get_rpc_client( endpoints, self.__client.client_config.ssl_enabled ) diff --git a/python/rocketmq/client_config.py b/python/rocketmq/client_config.py index 41e691c4b..1ccd5e0b7 100644 --- a/python/rocketmq/client_config.py +++ b/python/rocketmq/client_config.py @@ -18,25 +18,49 @@ class ClientConfig: + """Client configuration class which holds the settings for a client. + The settings include endpoint configurations, session credential provider and SSL settings. + An instance of this class is used to setup the client with necessary configurations. + """ + def __init__( self, endpoints: Endpoints, session_credentials_provider: SessionCredentialsProvider, ssl_enabled: bool, ): + #: The endpoints for the client to connect to. self.__endpoints = endpoints + + #: The session credentials provider to authenticate the client. self.__session_credentials_provider = session_credentials_provider + + #: A flag indicating if SSL is enabled for the client. self.__ssl_enabled = ssl_enabled + + #: The request timeout for the client in seconds. self.request_timeout = 10 @property - def session_credentials_provider(self): + def session_credentials_provider(self) -> SessionCredentialsProvider: + """The session credentials provider for the client. + + :return: the session credentials provider + """ return self.__session_credentials_provider @property - def endpoints(self): + def endpoints(self) -> Endpoints: + """The endpoints for the client to connect to. + + :return: the endpoints + """ return self.__endpoints @property - def ssl_enabled(self): + def ssl_enabled(self) -> bool: + """A flag indicating if SSL is enabled for the client. + + :return: True if SSL is enabled, False otherwise + """ return self.__ssl_enabled diff --git a/python/rocketmq/client_id_encoder.py b/python/rocketmq/client_id_encoder.py index 138b05f02..b6f3ca3fa 100644 --- a/python/rocketmq/client_id_encoder.py +++ b/python/rocketmq/client_id_encoder.py @@ -22,12 +22,25 @@ class ClientIdEncoder: + """This class generates a unique client ID for each client based on + hostname, process id, index and the monotonic clock time. + """ + + #: The current index for client id generation. __INDEX = 0 + + #: The lock used for thread-safe incrementing of the index. __INDEX_LOCK = threading.Lock() + + #: The separator used in the client id string. __CLIENT_ID_SEPARATOR = "@" @staticmethod - def __get_and_increment_sequence(): + def __get_and_increment_sequence() -> int: + """Increment and return the current index in a thread-safe manner. + + :return: the current index after incrementing it. + """ with ClientIdEncoder.__INDEX_LOCK: temp = ClientIdEncoder.__INDEX ClientIdEncoder.__INDEX += 1 @@ -35,6 +48,10 @@ def __get_and_increment_sequence(): @staticmethod def generate() -> str: + """Generate a unique client ID. + + :return: the generated client id + """ index = ClientIdEncoder.__get_and_increment_sequence() return ( socket.gethostname() diff --git a/python/rocketmq/definition.py b/python/rocketmq/definition.py index b2115b609..3d63748c3 100644 --- a/python/rocketmq/definition.py +++ b/python/rocketmq/definition.py @@ -17,6 +17,7 @@ from typing import List from protocol.definition_pb2 import Broker as ProtoBroker +from protocol.definition_pb2 import Encoding as ProtoEncoding from protocol.definition_pb2 import MessageQueue as ProtoMessageQueue from protocol.definition_pb2 import MessageType as ProtoMessageType from protocol.definition_pb2 import Permission as ProtoPermission @@ -25,20 +26,55 @@ from rocketmq.rpc_client import Endpoints +class Encoding(Enum): + """Enumeration of supported encoding types.""" + IDENTITY = 0 + GZIP = 1 + + +class EncodingHelper: + """Helper class for converting encoding types to protobuf.""" + + @staticmethod + def to_protobuf(mq_encoding): + """Convert encoding type to protobuf. + + :param mq_encoding: The encoding to be converted. + :return: The corresponding protobuf encoding. + """ + if mq_encoding == Encoding.IDENTITY: + return ProtoEncoding.IDENTITY + elif mq_encoding == Encoding.GZIP: + return ProtoEncoding.GZIP + + class Broker: + """Represent a broker entity.""" + def __init__(self, broker): self.name = broker.name self.id = broker.id self.endpoints = Endpoints(broker.endpoints) def to_protobuf(self): + """Convert the broker to its protobuf representation. + + :return: The protobuf representation of the broker. + """ return ProtoBroker( Name=self.name, Id=self.id, Endpoints=self.endpoints.to_protobuf() ) class Resource: + """Represent a resource entity.""" + def __init__(self, name=None, resource=None): + """Initialize a resource. + + :param name: The name of the resource. + :param resource: The resource object. + """ if resource is not None: self.namespace = resource.ResourceNamespace self.name = resource.Name @@ -47,6 +83,10 @@ def __init__(self, name=None, resource=None): self.name = name def to_protobuf(self): + """Convert the resource to its protobuf representation. + + :return: The protobuf representation of the resource. + """ return ProtoResource(ResourceNamespace=self.namespace, Name=self.name) def __str__(self): @@ -54,6 +94,7 @@ def __str__(self): class Permission(Enum): + """Enumeration of supported permission types.""" NONE = 0 READ = 1 WRITE = 2 @@ -61,8 +102,15 @@ class Permission(Enum): class PermissionHelper: + """Helper class for converting permission types to protobuf and vice versa.""" + @staticmethod def from_protobuf(permission): + """Convert protobuf permission to Permission enum. + + :param permission: The protobuf permission to be converted. + :return: The corresponding Permission enum. + """ if permission == ProtoPermission.READ: return Permission.READ elif permission == ProtoPermission.WRITE: @@ -76,6 +124,11 @@ def from_protobuf(permission): @staticmethod def to_protobuf(permission): + """Convert Permission enum to protobuf permission. + + :param permission: The Permission enum to be converted. + :return: The corresponding protobuf permission. + """ if permission == Permission.READ: return ProtoPermission.READ elif permission == Permission.WRITE: @@ -87,6 +140,11 @@ def to_protobuf(permission): @staticmethod def is_writable(permission): + """Check if the permission is writable. + + :param permission: The Permission enum to be checked. + :return: True if the permission is writable, False otherwise. + """ if permission in [Permission.WRITE, Permission.READ_WRITE]: return True else: @@ -94,6 +152,11 @@ def is_writable(permission): @staticmethod def is_readable(permission): + """Check if the permission is readable. + + :param permission: The Permission enum to be checked. + :return: True if the permission is readable, False otherwise. + """ if permission in [Permission.READ, Permission.READ_WRITE]: return True else: @@ -101,6 +164,7 @@ def is_readable(permission): class MessageType(Enum): + """Enumeration of supported message types.""" NORMAL = 0 FIFO = 1 DELAY = 2 @@ -108,8 +172,15 @@ class MessageType(Enum): class MessageTypeHelper: + """Helper class for converting message types to protobuf and vice versa.""" + @staticmethod def from_protobuf(message_type): + """Convert protobuf message type to MessageType enum. + + :param message_type: The protobuf message type to be converted. + :return: The corresponding MessageType enum. + """ if message_type == ProtoMessageType.NORMAL: return MessageType.NORMAL elif message_type == ProtoMessageType.FIFO: @@ -123,6 +194,11 @@ def from_protobuf(message_type): @staticmethod def to_protobuf(message_type): + """Convert MessageType enum to protobuf message type. + + :param message_type: The MessageType enum to be converted. + :return: The corresponding protobuf message type. + """ if message_type == MessageType.NORMAL: return ProtoMessageType.NORMAL elif message_type == MessageType.FIFO: @@ -136,7 +212,13 @@ def to_protobuf(message_type): class MessageQueue: + """A class that encapsulates a message queue entity.""" + def __init__(self, message_queue): + """Initialize a MessageQueue instance. + + :param message_queue: The initial message queue to be encapsulated. + """ self._topic_resource = Resource(message_queue.topic) self.queue_id = message_queue.id self.permission = PermissionHelper.from_protobuf(message_queue.permission) @@ -148,12 +230,24 @@ def __init__(self, message_queue): @property def topic(self): + """The topic resource name. + + :return: The name of the topic resource. + """ return self._topic_resource.name def __str__(self): + """Get a string representation of the MessageQueue instance. + + :return: A string that represents the MessageQueue instance. + """ return f"{self.broker.name}.{self._topic_resource}.{self.queue_id}" def to_protobuf(self): + """Convert the MessageQueue instance to protobuf message queue. + + :return: A protobuf message queue that represents the MessageQueue instance. + """ message_types = [ MessageTypeHelper.to_protobuf(mt) for mt in self.accept_message_types ] @@ -167,7 +261,13 @@ def to_protobuf(self): class TopicRouteData: + """A class that encapsulates a list of message queues.""" + def __init__(self, message_queues: List[definition_pb2.MessageQueue]): + """Initialize a TopicRouteData instance. + + :param message_queues: The initial list of message queues to be encapsulated. + """ message_queue_list = [] for mq in message_queues: message_queue_list.append(MessageQueue(mq)) @@ -175,4 +275,8 @@ def __init__(self, message_queues: List[definition_pb2.MessageQueue]): @property def message_queues(self) -> List[MessageQueue]: + """The list of MessageQueue instances. + + :return: The list of MessageQueue instances that the TopicRouteData instance encapsulates. + """ return self.__message_queue_list diff --git a/python/rocketmq/exponential_backoff_retry_policy.py b/python/rocketmq/exponential_backoff_retry_policy.py new file mode 100644 index 000000000..4dc87cd8d --- /dev/null +++ b/python/rocketmq/exponential_backoff_retry_policy.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from datetime import timedelta + +from google.protobuf.duration_pb2 import Duration + + +class ExponentialBackoffRetryPolicy: + """A class implementing exponential backoff retry policy.""" + + def __init__(self, max_attempts, initial_backoff, max_backoff, backoff_multiplier): + """Initialize an ExponentialBackoffRetryPolicy instance. + + :param max_attempts: Maximum number of retry attempts. + :param initial_backoff: Initial delay duration before the first retry. + :param max_backoff: Maximum delay duration between retries. + :param backoff_multiplier: Multiplier that determines the delay factor between retries. + """ + self._max_attempts = max_attempts + self.initial_backoff = initial_backoff + self.max_backoff = max_backoff + self.backoff_multiplier = backoff_multiplier + + def get_max_attempts(self): + """Get maximum number of retry attempts. + + :return: Maximum number of retry attempts. + """ + return self._max_attempts + + def inherit_backoff(self, retry_policy): + """Inherit backoff parameters from another retry policy. + + :param retry_policy: The retry policy to inherit from. + :return: An instance of ExponentialBackoffRetryPolicy with inherited parameters. + :raise ValueError: If the strategy of the retry policy is not ExponentialBackoff. + """ + if retry_policy.strategy_case != "ExponentialBackoff": + raise ValueError("Strategy must be exponential backoff") + return self._inherit_backoff(retry_policy.exponential_backoff) + + def _inherit_backoff(self, retry_policy): + """Inherit backoff parameters from another retry policy. + + :param retry_policy: The retry policy to inherit from. + :return: An instance of ExponentialBackoffRetryPolicy with inherited parameters. + """ + return ExponentialBackoffRetryPolicy(self._max_attempts, + retry_policy.initial.ToTimedelta(), + retry_policy.max.ToTimedelta(), + retry_policy.multiplier) + + def get_next_attempt_delay(self, attempt): + """Calculate the delay before the next retry attempt. + + :param attempt: The number of the current attempt. + :return: The delay before the next attempt. + """ + delay_seconds = min( + self.initial_backoff.total_seconds() * math.pow(self.backoff_multiplier, 1.0 * (attempt - 1)), + self.max_backoff.total_seconds()) + return timedelta(seconds=delay_seconds) if delay_seconds >= 0 else timedelta(seconds=0) + + @staticmethod + def immediately_retry_policy(max_attempts): + """Create a retry policy that makes immediate retries. + + :param max_attempts: Maximum number of retry attempts. + :return: An instance of ExponentialBackoffRetryPolicy with no delay between retries. + """ + return ExponentialBackoffRetryPolicy(max_attempts, timedelta(seconds=0), timedelta(seconds=0), 1) + + def to_protobuf(self): + """Convert the ExponentialBackoffRetryPolicy instance to protobuf. + + :return: A protobuf message that represents the ExponentialBackoffRetryPolicy instance. + """ + exponential_backoff = { + 'Multiplier': self.backoff_multiplier, + 'Max': Duration.FromTimedelta(self.max_backoff), + 'Initial': Duration.FromTimedelta(self.initial_backoff) + } + return { + 'MaxAttempts': self._max_attempts, + 'ExponentialBackoff': exponential_backoff + } diff --git a/python/rocketmq/producer.py b/python/rocketmq/producer.py index 3d604a252..9e10a3db3 100644 --- a/python/rocketmq/producer.py +++ b/python/rocketmq/producer.py @@ -15,31 +15,103 @@ import asyncio import threading +import time +# from status_checker import StatusChecker +from datetime import datetime, timedelta +from threading import RLock from typing import Set +from unittest.mock import MagicMock, patch import rocketmq +from publishing_message import MessageType from rocketmq.client import Client from rocketmq.client_config import ClientConfig -from rocketmq.definition import TopicRouteData +from rocketmq.definition import PermissionHelper, TopicRouteData +from rocketmq.exponential_backoff_retry_policy import \ + ExponentialBackoffRetryPolicy from rocketmq.log import logger +from rocketmq.message import Message from rocketmq.message_id_codec import MessageIdCodec from rocketmq.protocol.definition_pb2 import Message as ProtoMessage -from rocketmq.protocol.definition_pb2 import Resource, SystemProperties -from rocketmq.protocol.service_pb2 import SendMessageRequest +from rocketmq.protocol.definition_pb2 import Resource +from rocketmq.protocol.definition_pb2 import Resource as ProtoResource +from rocketmq.protocol.definition_pb2 import SystemProperties +from rocketmq.protocol.definition_pb2 import \ + TransactionResolution as ProtoTransactionResolution +from rocketmq.protocol.service_pb2 import (EndTransactionRequest, + SendMessageRequest) from rocketmq.publish_settings import PublishingSettings +from rocketmq.publishing_message import PublishingMessage from rocketmq.rpc_client import Endpoints +from rocketmq.send_receipt import SendReceipt from rocketmq.session_credentials import (SessionCredentials, SessionCredentialsProvider) +from status_checker import TooManyRequestsException +from utils import get_positive_mod + + +class Transaction: + MAX_MESSAGE_NUM = 1 + + def __init__(self, producer): + self.producer = producer + self.messages = set() + self.messages_lock = RLock() + self.message_send_receipt_dict = {} + + def try_add_message(self, message): + with self.messages_lock: + if len(self.messages) > self.MAX_MESSAGE_NUM: + raise ValueError(f"Message in transaction has exceed the threshold: {self.MAX_MESSAGE_NUM}") + + publishing_message = PublishingMessage(message, self.producer.publish_settings, True) + self.messages.add(publishing_message) + return publishing_message + + def try_add_receipt(self, publishing_message, send_receipt): + with self.messages_lock: + if publishing_message not in self.messages: + raise ValueError("Message is not in the transaction") + + self.message_send_receipt_dict[publishing_message] = send_receipt + + async def commit(self): + # if self.producer.state != "Running": + # raise Exception("Producer is not running") + + if not self.message_send_receipt_dict: + raise ValueError("Transactional message has not been sent yet") + + for publishing_message, send_receipt in self.message_send_receipt_dict.items(): + await self.producer.end_transaction(send_receipt.endpoints, publishing_message.message.topic, send_receipt.message_id, send_receipt.transaction_id, "Commit") + + async def rollback(self): + # if self.producer.state != "Running": + # raise Exception("Producer is not running") + + if not self.message_send_receipt_dict: + raise ValueError("Transactional message has not been sent yet") + + for publishing_message, send_receipt in self.message_send_receipt_dict.items(): + await self.producer.end_transaction(send_receipt.endpoints, publishing_message.message.topic, send_receipt.message_id, send_receipt.transaction_id, "Rollback") class PublishingLoadBalancer: + """This class serves as a load balancer for message publishing. + It keeps track of a rotating index to help distribute the load evenly. + """ + def __init__(self, topic_route_data: TopicRouteData, index: int = 0): + #: current index for message queue selection self.__index = index + #: thread lock to ensure atomic update to the index self.__index_lock = threading.Lock() + + #: filter the message queues which are writable and from the master broker message_queues = [] for mq in topic_route_data.message_queues: if ( - not mq.permission.is_writable() + not PermissionHelper().is_writable(mq.permission) or mq.broker.id is not rocketmq.utils.master_broker_id ): continue @@ -48,15 +120,21 @@ def __init__(self, topic_route_data: TopicRouteData, index: int = 0): @property def index(self): + """Property to fetch the current index""" return self.__index def get_and_increment_index(self): + """Thread safe method to get the current index and increment it by one""" with self.__index_lock: temp = self.__index self.__index += 1 return temp def take_message_queues(self, excluded: Set[Endpoints], count: int): + """Fetch a specified number of message queues, excluding the ones provided. + It will first try to fetch from non-excluded brokers and if insufficient, + it will select from the excluded ones. + """ next_index = self.get_and_increment_index() candidates = [] candidate_broker_name = set() @@ -85,41 +163,309 @@ def take_message_queues(self, excluded: Set[Endpoints], count: int): return candidates return candidates + def take_message_queue_by_message_group(self, message_group): + index = get_positive_mod(hash(message_group), len(self.__message_queues)) + return self.__message_queues[index] + class Producer(Client): + """The Producer class extends the Client class and is used to publish + messages to specific topics in RocketMQ. + """ + def __init__(self, client_config: ClientConfig, topics: Set[str]): + """Create a new Producer. + + :param client_config: The configuration for the client. + :param topics: The set of topics to which the producer can send messages. + """ super().__init__(client_config, topics) + retry_policy = ExponentialBackoffRetryPolicy.immediately_retry_policy(10) + #: Set up the publishing settings with the given parameters. self.publish_settings = PublishingSettings( - self.client_id, self.endpoints, None, 10, topics + self.client_id, self.endpoints, retry_policy, 10, topics ) + #: Initialize the routedata cache. + self.publish_routedata_cache = {} async def __aenter__(self): + """Provide an asynchronous context manager for the producer.""" await self.start() async def __aexit__(self, exc_type, exc_val, exc_tb): + """Provide an asynchronous context manager for the producer.""" await self.shutdown() async def start(self): + """Start the RocketMQ producer and log the operation.""" logger.info(f"Begin to start the rocketmq producer, client_id={self.client_id}") await super().start() logger.info(f"The rocketmq producer starts successfully, client_id={self.client_id}") async def shutdown(self): + """Shutdown the RocketMQ producer and log the operation.""" logger.info(f"Begin to shutdown the rocketmq producer, client_id={self.client_id}") + await super().shutdown() logger.info(f"Shutdown the rocketmq producer successfully, client_id={self.client_id}") - async def send_message(self, message): + @staticmethod + def wrap_send_message_request(message, message_queue): + """Wrap the send message request for the RocketMQ producer. + + :param message: The message to be sent. + :param message_queue: The queue to which the message will be sent. + :return: The SendMessageRequest with the message and queue details. + """ req = SendMessageRequest() - req.messages.extend([message]) - topic_data = self.topic_route_cache["normal_topic"] - endpoints = topic_data.message_queues[2].broker.endpoints - return await self.client_manager.send_message(endpoints, req, 10) + req.messages.extend([message.to_protobuf(message_queue.queue_id)]) + return req + + async def send(self, message, transaction: Transaction = None): + tx_enabled = True + if transaction is None: + tx_enabled = False + if tx_enabled: + logger.debug("Transaction send") + publishing_message = transaction.try_add_message(message) + send_receipt = await self.send_message(message, tx_enabled) + transaction.try_add_receipt(publishing_message, send_receipt) + return send_receipt + else: + return await self.send_message(message) + + async def send_message(self, message, tx_enabled=False): + """Send a message using a load balancer, retrying as needed according to the retry policy. + + :param message: The message to be sent. + """ + publish_load_balancer = await self.get_publish_load_balancer(message.topic) + publishing_message = PublishingMessage(message, self.publish_settings, tx_enabled) + retry_policy = self.get_retry_policy() + max_attempts = retry_policy.get_max_attempts() + + exception = None + logger.debug(publishing_message.message.message_group) + candidates = ( + publish_load_balancer.take_message_queues(set(self.isolated.keys()), max_attempts) + if publishing_message.message.message_group is None else + [publish_load_balancer.take_message_queue_by_message_group(publishing_message.message.message_group)]) + for attempt in range(1, max_attempts + 1): + start_time = time.time() + candidate_index = (attempt - 1) % len(candidates) + mq = candidates[candidate_index] + logger.debug(mq.accept_message_types) + if self.publish_settings.is_validate_message_type() and publishing_message.message_type.value != mq.accept_message_types[0].value: + raise ValueError( + "Current message type does not match with the accept message types," + + f" topic={message.topic}, actualMessageType={publishing_message.message_type}" + + f" acceptMessageType={','}") + + send_message_request = self.wrap_send_message_request(publishing_message, mq) + # topic_data = self.topic_route_cache["normal_topic"] + endpoints = mq.broker.endpoints + + try: + invocation = await self.client_manager.send_message(endpoints, send_message_request, self.client_config.request_timeout) + logger.debug(invocation) + send_recepits = SendReceipt.process_send_message_response(mq, invocation) + send_recepit = send_recepits[0] + if attempt > 1: + logger.info( + f"Re-send message successfully, topic={message.topic}," + + f" max_attempts={max_attempts}, endpoints={str(endpoints)}, clientId={self.client_id}") + return send_recepit + except Exception as e: + exception = e + self.isolated[endpoints] = True + if attempt >= max_attempts: + logger.error("Failed to send message finally, run out of attempt times, " + + f"topic={message.topic}, maxAttempt={max_attempts}, attempt={attempt}, " + + f"endpoints={endpoints}, messageId={publishing_message.message_id}, clientId={self.client_id}") + raise + if publishing_message.message_type == MessageType.TRANSACTION: + logger.error("Failed to send transaction message, run out of attempt times, " + + f"topic={message.topic}, maxAttempt=1, attempt={attempt}, " + + f"endpoints={endpoints}, messageId={publishing_message.message_id}, clientId={self.client_id}") + raise + if not isinstance(exception, TooManyRequestsException): + logger.error(f"Failed to send message, topic={message.topic}, max_attempts={max_attempts}, " + + f"attempt={attempt}, endpoints={endpoints}, messageId={publishing_message.message_id}," + + f" clientId={self.client_id}") + continue + + nextAttempt = 1 + attempt + delay = retry_policy.get_next_attempt_delay(nextAttempt) + await asyncio.sleep(delay.total_seconds()) + logger.warning(f"Failed to send message due to too many requests, would attempt to resend after {delay},\ + topic={message.topic}, max_attempts={max_attempts}, attempt={attempt}, endpoints={endpoints},\ + message_id={publishing_message.message_id}, client_id={self.client_id}") + finally: + elapsed_time = time.time() - start_time + logger.info(f"send time: {elapsed_time}") + + def update_publish_load_balancer(self, topic, topic_route_data): + """Update the load balancer used for publishing messages to a topic. + + :param topic: The topic for which to update the load balancer. + :param topic_route_data: The new route data for the topic. + :return: The updated load balancer. + """ + publishing_load_balancer = None + if topic in self.publish_routedata_cache: + publishing_load_balancer = self.publish_routedata_cache[topic] + else: + publishing_load_balancer = PublishingLoadBalancer(topic_route_data) + self.publish_routedata_cache[topic] = publishing_load_balancer + return publishing_load_balancer + + async def get_publish_load_balancer(self, topic): + """Get the load balancer used for publishing messages to a topic. + + :param topic: The topic for which to get the load balancer. + :return: The load balancer for the topic. + """ + if topic in self.publish_routedata_cache: + return self.publish_routedata_cache[topic] + topic_route_data = await self.get_route_data(topic) + return self.update_publish_load_balancer(topic, topic_route_data) def get_settings(self): + """Get the publishing settings for this producer. + + :return: The publishing settings for this producer. + """ return self.publish_settings + def get_retry_policy(self): + """Get the retry policy for this producer. + + :return: The retry policy for this producer. + """ + return self.publish_settings.GetRetryPolicy() + + def begin_transaction(self): + """Start a new transaction.""" + return Transaction(self) + + async def end_transaction(self, endpoints, topic, message_id, transaction_id, resolution): + """End a transaction based on its resolution (commit or rollback).""" + topic_resource = ProtoResource(name=topic) + request = EndTransactionRequest( + transaction_id=transaction_id, + message_id=message_id, + topic=topic_resource, + resolution=ProtoTransactionResolution.COMMIT if resolution == "Commit" else ProtoTransactionResolution.ROLLBACK + ) + await self.client_manager.end_transaction(endpoints, request, self.client_config.request_timeout) + # StatusChecker.check(invocation.response.status, request, invocation.request_id) + async def test(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "normal_topic" + msg = ProtoMessage() + msg.topic.CopyFrom(topic) + msg.body = b"My Normal Message Body" + sysperf = SystemProperties() + sysperf.message_id = MessageIdCodec.next_message_id() + msg.system_properties.CopyFrom(sysperf) + producer = Producer(client_config, topics={"normal_topic"}) + message = Message(topic.name, msg.body) + await producer.start() + await asyncio.sleep(10) + send_receipt = await producer.send(message) + logger.info(f"Send message successfully, {send_receipt}") + + +async def test_delay_message(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "delay_topic" + msg = ProtoMessage() + msg.topic.CopyFrom(topic) + msg.body = b"My Delay Message Body" + sysperf = SystemProperties() + sysperf.message_id = MessageIdCodec.next_message_id() + msg.system_properties.CopyFrom(sysperf) + logger.debug(f"{msg}") + producer = Producer(client_config, topics={"delay_topic"}) + current_time_millis = int(round(time.time() * 1000)) + message_delay_time = timedelta(seconds=10) + result_time_millis = current_time_millis + int(message_delay_time.total_seconds() * 1000) + result_time_datetime = datetime.fromtimestamp(result_time_millis / 1000.0) + message = Message(topic.name, msg.body, delivery_timestamp=result_time_datetime) + await producer.start() + await asyncio.sleep(10) + send_receipt = await producer.send(message) + logger.info(f"Send message successfully, {send_receipt}") + + +async def test_fifo_message(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "fifo_topic" + msg = ProtoMessage() + msg.topic.CopyFrom(topic) + msg.body = b"My FIFO Message Body" + sysperf = SystemProperties() + sysperf.message_id = MessageIdCodec.next_message_id() + msg.system_properties.CopyFrom(sysperf) + logger.debug(f"{msg}") + producer = Producer(client_config, topics={"fifo_topic"}) + message = Message(topic.name, msg.body, message_group="yourMessageGroup") + await producer.start() + await asyncio.sleep(10) + send_receipt = await producer.send(message) + logger.info(f"Send message successfully, {send_receipt}") + + +async def test_transaction_message(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "transaction_topic" + msg = ProtoMessage() + msg.topic.CopyFrom(topic) + msg.body = b"My Transaction Message Body" + sysperf = SystemProperties() + sysperf.message_id = MessageIdCodec.next_message_id() + msg.system_properties.CopyFrom(sysperf) + logger.debug(f"{msg}") + producer = Producer(client_config, topics={"transaction_topic"}) + message = Message(topic.name, msg.body) + await producer.start() + # await asyncio.sleep(10) + transaction = producer.begin_transaction() + send_receipt = await producer.send(message, transaction) + logger.info(f"Send message successfully, {send_receipt}") + await transaction.commit() + + +async def test_retry_and_isolation(): credentials = SessionCredentials("username", "password") credentials_provider = SessionCredentialsProvider(credentials) client_config = ClientConfig( @@ -137,10 +483,25 @@ async def test(): msg.system_properties.CopyFrom(sysperf) logger.info(f"{msg}") producer = Producer(client_config, topics={"normal_topic"}) - await producer.start() - result = await producer.send_message(msg) - print(result) + message = Message(topic.name, msg.body) + with patch.object(producer.client_manager, 'send_message', new_callable=MagicMock) as mock_send: + mock_send.side_effect = Exception("Forced Exception for Testing") + await producer.start() + + try: + await producer.send(message) + except Exception: + logger.info("Exception occurred as expected") + + assert mock_send.call_count == producer.get_retry_policy().get_max_attempts(), "Number of attempts should equal max_attempts." + logger.debug(producer.isolated) + assert producer.isolated, "Endpoint should be marked as isolated after an error." + logger.info("Test completed successfully.") if __name__ == "__main__": asyncio.run(test()) + asyncio.run(test_delay_message()) + asyncio.run(test_fifo_message()) + asyncio.run(test_transaction_message()) + asyncio.run(test_retry_and_isolation()) diff --git a/python/rocketmq/publish_settings.py b/python/rocketmq/publish_settings.py index c629d5142..4f09cb7ca 100644 --- a/python/rocketmq/publish_settings.py +++ b/python/rocketmq/publish_settings.py @@ -17,13 +17,14 @@ import socket from typing import Dict +from rocketmq.exponential_backoff_retry_policy import \ + ExponentialBackoffRetryPolicy from rocketmq.protocol.definition_pb2 import UA from rocketmq.protocol.definition_pb2 import Publishing as ProtoPublishing from rocketmq.protocol.definition_pb2 import Resource as ProtoResource from rocketmq.protocol.definition_pb2 import Settings as ProtoSettings from rocketmq.rpc_client import Endpoints -from rocketmq.settings import (ClientType, ClientTypeHelper, IRetryPolicy, - Settings) +from rocketmq.settings import ClientType, ClientTypeHelper, Settings from rocketmq.signature import Signature @@ -44,7 +45,7 @@ def __init__( self, client_id: str, endpoints: Endpoints, - retry_policy: IRetryPolicy, + retry_policy: ExponentialBackoffRetryPolicy, request_timeout: int, topics: Dict[str, bool], ): diff --git a/python/rocketmq/publishing_message.py b/python/rocketmq/publishing_message.py new file mode 100644 index 000000000..195399cb0 --- /dev/null +++ b/python/rocketmq/publishing_message.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket + +from definition import Encoding, EncodingHelper, MessageType, MessageTypeHelper +from google.protobuf.timestamp_pb2 import Timestamp +from message import Message +from message_id_codec import MessageIdCodec +from protocol.definition_pb2 import Message as ProtoMessage +from protocol.definition_pb2 import Resource, SystemProperties +from rocketmq.log import logger + + +class PublishingMessage(Message): + def __init__(self, message, publishing_settings, tx_enabled=False): + self.message = message + self.publishing_settings = publishing_settings + self.tx_enabled = tx_enabled + self.message_type = None + + max_body_size_bytes = publishing_settings.get_max_body_size_bytes() + if len(message.body) > max_body_size_bytes: + raise IOError(f"Message body size exceed the threshold, max size={max_body_size_bytes} bytes") + + self.message_id = MessageIdCodec.next_message_id() + + if not message.message_group and not message.delivery_timestamp and not tx_enabled: + self.message_type = MessageType.NORMAL + return + + if message.message_group and not tx_enabled: + self.message_type = MessageType.FIFO + return + + if message.delivery_timestamp and not tx_enabled: + self.message_type = MessageType.DELAY + return + + if message.message_group or message.delivery_timestamp or not tx_enabled: + pass + + self.message_type = MessageType.TRANSACTION + logger.debug(self.message_type) + + def to_protobuf(self, queue_id): + system_properties = SystemProperties( + keys=self.message.keys, + message_id=self.message_id, + # born_timestamp=Timestamp.FromDatetime(dt=datetime.datetime.utcnow()), + born_host=socket.gethostname(), + body_encoding=EncodingHelper.to_protobuf(Encoding.IDENTITY), + queue_id=queue_id, + message_type=MessageTypeHelper.to_protobuf(self.message_type) + ) + if self.message.tag: + system_properties.tag = self.message.tag + + if self.message.delivery_timestamp: + timestamp = Timestamp() + timestamp.FromDatetime(self.message.delivery_timestamp) + system_properties.delivery_timestamp.CopyFrom(timestamp) + + if self.message.message_group: + system_properties.message_group = self.message.message_group + + topic_resource = Resource(name=self.message.topic) + + return ProtoMessage( + topic=topic_resource, + body=self.message.body, + system_properties=system_properties, + user_properties=self.message.properties + ) diff --git a/python/rocketmq/rpc_client.py b/python/rocketmq/rpc_client.py index c1f129c4c..6c1107ab3 100644 --- a/python/rocketmq/rpc_client.py +++ b/python/rocketmq/rpc_client.py @@ -132,7 +132,7 @@ def _calculate_hash(self): def __str__(self): for address in self.Addresses: - return None + return str(address.host) + str(address.port) def grpc_target(self, sslEnabled): for address in self.Addresses: diff --git a/python/rocketmq/send_receipt.py b/python/rocketmq/send_receipt.py index 3aaa59789..8e742da2d 100644 --- a/python/rocketmq/send_receipt.py +++ b/python/rocketmq/send_receipt.py @@ -13,13 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +# from rocketmq.status_checker import StatusChecker +from rocketmq.log import logger from rocketmq.message_id import MessageId +from rocketmq.protocol.definition_pb2 import Code as ProtoCode class SendReceipt: - def __init__(self, message_id: MessageId): - self.__message_id = message_id + def __init__(self, message_id: MessageId, transaction_id, message_queue): + self.message_id = message_id + self.transaction_id = transaction_id + self.message_queue = message_queue @property - def message_id(self): - return self.__message_id + def endpoints(self): + return self.message_queue.broker.endpoints + + def __str__(self): + return f'MessageId: {self.message_id}' + + @staticmethod + def process_send_message_response(mq, invocation): + status = invocation.status + for entry in invocation.entries: + if entry.status.code == ProtoCode.OK: + status = entry.status + logger.debug(status) + # May throw exception. + # StatusChecker.check(status, invocation.request, invocation.request_id) + return [SendReceipt(entry.message_id, entry.transaction_id, mq) for entry in invocation.entries] diff --git a/python/rocketmq/session.py b/python/rocketmq/session.py index 50c11d986..c3ea9d32e 100644 --- a/python/rocketmq/session.py +++ b/python/rocketmq/session.py @@ -14,8 +14,8 @@ # limitations under the License. import asyncio -from threading import Event +from rocketmq.log import logger from rocketmq.protocol.service_pb2 import \ TelemetryCommand as ProtoTelemetryCommand @@ -26,12 +26,25 @@ def __init__(self, endpoints, streaming_call, client): self._semaphore = asyncio.Semaphore(1) self._streaming_call = streaming_call self._client = client - self._event = Event() + asyncio.create_task(self.loop()) + + async def loop(self): + try: + while True: + await self._streaming_call.read() + except asyncio.exceptions.InvalidStateError as e: + logger.error('Error:', e) async def write_async(self, telemetry_command: ProtoTelemetryCommand): - await self._streaming_call.write(telemetry_command) - response = await self._streaming_call.read() - print(response) + await asyncio.sleep(1) + try: + await self._streaming_call.write(telemetry_command) + # TODO handle read operation exceed the time limit + # await asyncio.wait_for(self._streaming_call.read(), timeout=5) + except asyncio.exceptions.InvalidStateError as e: + self.on_error(e) + except asyncio.TimeoutError: + logger.error('Timeout: The read operation exceeded the time limit') async def sync_settings(self, await_resp): await self._semaphore.acquire() @@ -42,3 +55,20 @@ async def sync_settings(self, await_resp): await self.write_async(telemetry_command) finally: self._semaphore.release() + + def rebuild_telemetry(self): + logger.info("Try to rebuild telemetry") + stream = self._client.client_manager.telemetry(self._endpoints, 10) + self._streaming_call = stream + + def on_error(self, exception): + client_id = self._client.get_client_id() + logger.error("Caught InvalidStateError: RPC already finished.") + logger.error(f"Exception raised from stream, clientId={client_id}, endpoints={self._endpoints}", exception) + max_retry = 3 + for i in range(max_retry): + try: + self.rebuild_telemetry() + break + except Exception as e: + logger.error(f"An error occurred during rebuilding telemetry: {e}, attempt {i + 1} of {max_retry}") diff --git a/python/rocketmq/status_checker.py b/python/rocketmq/status_checker.py new file mode 100644 index 000000000..ae2d19130 --- /dev/null +++ b/python/rocketmq/status_checker.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rocketmq.log import logger +from rocketmq.message import Message +from rocketmq.protocol.definition_pb2 import Code as ProtoCode +from rocketmq.protocol.definition_pb2 import Message as ProtoMessage +from rocketmq.protocol.definition_pb2 import Status as ProtoStatus +from rocketmq.protocol.service_pb2 import \ + ReceiveMessageRequest as ProtoReceiveMessageRequest + + +class RocketMQException(Exception): + def __init__(self, status_code, request_id, status_message): + self.status_code = status_code + self.request_id = request_id + self.status_message = status_message + + def __str__(self): + return f"{self.__class__.__name__}: code={self.status_code}, requestId={self.request_id}, message={self.status_message}" + + +class BadRequestException(RocketMQException): + pass + + +class UnauthorizedException(RocketMQException): + pass + + +class PaymentRequiredException(RocketMQException): + pass + + +class ForbiddenException(RocketMQException): + pass + + +class NotFoundException(RocketMQException): + pass + + +class PayloadTooLargeException(RocketMQException): + pass + + +class TooManyRequestsException(RocketMQException): + pass + + +class RequestHeaderFieldsTooLargeException(RocketMQException): + pass + + +class InternalErrorException(RocketMQException): + pass + + +class ProxyTimeoutException(RocketMQException): + pass + + +class UnsupportedException(RocketMQException): + pass + + +class StatusChecker: + @staticmethod + def check(status: ProtoStatus, request: Message, request_id: str): + """Check the status of a request and raise an exception if necessary. + + :param status: A ProtoStatus object that contains the status code and message. + :param request: The request message object. + :param request_id: The ID of the request. + :raise BadRequestException: If the status code indicates a bad request. + :raise UnauthorizedException: If the status code indicates an unauthorized request. + :raise PaymentRequiredException: If the status code indicates payment is required. + :raise ForbiddenException: If the status code indicates a forbidden request. + :raise NotFoundException: If the status code indicates a resource is not found. + :raise PayloadTooLargeException: If the status code indicates the request payload is too large. + :raise TooManyRequestsException: If the status code indicates too many requests. + :raise RequestHeaderFieldsTooLargeException: If the status code indicates the request headers are too large. + :raise InternalErrorException: If the status code indicates an internal error. + :raise ProxyTimeoutException: If the status code indicates a proxy timeout. + :raise UnsupportedException: If the status code indicates an unsupported operation. + """ + status_code = status.code + status_message = status.message + + if status_code in [ProtoCode.OK, ProtoCode.MULTIPLE_RESULTS]: + return + elif status_code in [ + ProtoCode.BAD_REQUEST, + ProtoCode.ILLEGAL_ACCESS_POINT, + ProtoCode.ILLEGAL_TOPIC, + ProtoCode.ILLEGAL_CONSUMER_GROUP, + ProtoCode.ILLEGAL_MESSAGE_TAG, + ProtoCode.ILLEGAL_MESSAGE_KEY, + ProtoCode.ILLEGAL_MESSAGE_GROUP, + ProtoCode.ILLEGAL_MESSAGE_PROPERTY_KEY, + ProtoCode.INVALID_TRANSACTION_ID, + ProtoCode.ILLEGAL_MESSAGE_ID, + ProtoCode.ILLEGAL_FILTER_EXPRESSION, + ProtoCode.ILLEGAL_INVISIBLE_TIME, + ProtoCode.ILLEGAL_DELIVERY_TIME, + ProtoCode.INVALID_RECEIPT_HANDLE, + ProtoCode.MESSAGE_PROPERTY_CONFLICT_WITH_TYPE, + ProtoCode.UNRECOGNIZED_CLIENT_TYPE, + ProtoCode.MESSAGE_CORRUPTED, + ProtoCode.CLIENT_ID_REQUIRED, + ProtoCode.ILLEGAL_POLLING_TIME, + ]: + raise BadRequestException(status_code, request_id, status_message) + elif status_code == ProtoCode.UNAUTHORIZED: + raise UnauthorizedException(status_code, request_id, status_message) + elif status_code == ProtoCode.PAYMENT_REQUIRED: + raise PaymentRequiredException(status_code, request_id, status_message) + elif status_code == ProtoCode.FORBIDDEN: + raise ForbiddenException(status_code, request_id, status_message) + elif status_code == ProtoCode.MESSAGE_NOT_FOUND: + if isinstance(request, ProtoReceiveMessageRequest): + return + else: + # Fall through on purpose. + status_code = ProtoCode.NOT_FOUND + if status_code in [ + ProtoCode.NOT_FOUND, + ProtoCode.TOPIC_NOT_FOUND, + ProtoCode.CONSUMER_GROUP_NOT_FOUND, + ]: + raise NotFoundException(status_code, request_id, status_message) + elif status_code in [ + ProtoCode.PAYLOAD_TOO_LARGE, + ProtoCode.MESSAGE_BODY_TOO_LARGE, + ]: + raise PayloadTooLargeException(status_code, request_id, status_message) + elif status_code == ProtoCode.TOO_MANY_REQUESTS: + raise TooManyRequestsException(status_code, request_id, status_message) + elif status_code in [ + ProtoCode.REQUEST_HEADER_FIELDS_TOO_LARGE, + ProtoCode.MESSAGE_PROPERTIES_TOO_LARGE, + ]: + raise RequestHeaderFieldsTooLargeException(status_code, request_id, status_message) + elif status_code in [ + ProtoCode.INTERNAL_ERROR, + ProtoCode.INTERNAL_SERVER_ERROR, + ProtoCode.HA_NOT_AVAILABLE, + ]: + raise InternalErrorException(status_code, request_id, status_message) + elif status_code in [ + ProtoCode.PROXY_TIMEOUT, + ProtoCode.MASTER_PERSISTENCE_TIMEOUT, + ProtoCode.SLAVE_PERSISTENCE_TIMEOUT, + ]: + raise ProxyTimeoutException(status_code, request_id, status_message) + elif status_code in [ + ProtoCode.UNSUPPORTED, + ProtoCode.VERSION_UNSUPPORTED, + ProtoCode.VERIFY_FIFO_MESSAGE_UNSUPPORTED, + ]: + raise UnsupportedException(status_code, request_id, status_message) + else: + logger.warning(f"Unrecognized status code={status_code}, requestId={request_id}, statusMessage={status_message}") + raise UnsupportedException(status_code, request_id, status_message) + + +def main(): + # 创建一个表示'OK'状态的ProtoStatus + status_ok = ProtoStatus() + status_ok.code = ProtoCode.OK + status_ok.message = "Everything is OK" + + # 创建一个表示'BadRequest'状态的ProtoStatus + status_bad_request = ProtoStatus() + status_bad_request.code = ProtoCode.BAD_REQUEST + status_bad_request.message = "Bad request" + + # 创建一个表示'Unauthorized'状态的ProtoStatus + status_unauthorized = ProtoStatus() + status_unauthorized.code = ProtoCode.UNAUTHORIZED + status_unauthorized.message = "Unauthorized" + + request = ProtoMessage() + + # 进行一些测试 + StatusChecker.check(status_ok, request, "request1") # 不应抛出异常 + + try: + StatusChecker.check(status_bad_request, request, "request2") + except BadRequestException as e: + logger.error(f"Caught expected exception: {e}") + + try: + StatusChecker.check(status_unauthorized, request, "request3") + except UnauthorizedException as e: + logger.error(f"Caught expected exception: {e}") + + +if __name__ == "__main__": + main() diff --git a/python/rocketmq/utils.py b/python/rocketmq/utils.py index dd2a4a980..e5cf3b278 100644 --- a/python/rocketmq/utils.py +++ b/python/rocketmq/utils.py @@ -39,3 +39,8 @@ def sign(access_secret: str, datetime: str) -> str: hashlib.sha1, ) return digester.hexdigest().upper() + + +def get_positive_mod(k: int, n: int): + result = k % n + return result + n if result < 0 else result diff --git a/python/tests/test_foo.py b/python/tests/test_foo.py index 70b00f6a4..89b9ea313 100644 --- a/python/tests/test_foo.py +++ b/python/tests/test_foo.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from rocketmq import foo, logger +from rocketmq import foo +from rocketmq.log import logger def test_passing():