Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General message engineering in Producer #559

Merged
merged 11 commits into from
Jul 25, 2023
199 changes: 193 additions & 6 deletions python/rocketmq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,137 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import threading
from typing import Set

from protocol import service_pb2
from protocol.service_pb2 import QueryRouteRequest
from protocol import definition_pb2, service_pb2
from protocol.service_pb2 import HeartbeatRequest, QueryRouteRequest
from rocketmq.client_config import ClientConfig
from rocketmq.client_id_encoder import ClientIdEncoder
from rocketmq.definition import TopicRouteData
from rocketmq.definition import Resource, TopicRouteData
from rocketmq.log import logger
from rocketmq.rpc_client import Endpoints, RpcClient
from rocketmq.session import Session
from rocketmq.signature import Signature


class ScheduleWithFixedDelay:
def __init__(self, action, delay, period):
self.action = action
self.delay = delay
self.period = period
self.task = None

async def start(self):
# await asyncio.sleep(self.delay)
while True:
await self.action()
await asyncio.sleep(self.period)

def schedule(self):
loop1 = asyncio.new_event_loop()
asyncio.set_event_loop(loop1)
self.task = asyncio.create_task(self.start())

def cancel(self):
if self.task:
self.task.cancel()


class Client:
"""
Main client class which handles interaction with the server.
"""
def __init__(self, client_config: ClientConfig, topics: Set[str]):
"""
Initialization method for the Client class.

:param client_config: Client configuration.
:param topics: Set of topics that the client is subscribed to.
"""
self.client_config = client_config
self.client_id = ClientIdEncoder.generate()
self.endpoints = client_config.endpoints
self.topics = topics

#: A cache to store topic routes.
self.topic_route_cache = {}

#: A table to store session information.
self.sessions_table = {}
self.sessionsLock = threading.Lock()
self.client_manager = ClientManager(self)

#: A dictionary to store isolated items.
self.isolated = dict()

async def start(self):
"""
Start method which initiates fetching of topic routes and schedules heartbeats.
"""
# get topic route
for topic in self.topics:
self.topic_route_cache[topic] = await self.fetch_topic_route(topic)
scheduler = ScheduleWithFixedDelay(self.heartbeat, 3, 12)
scheduler.schedule()

async def heartbeat(self):
"""
Asynchronous method that sends a heartbeat to the server.
"""
try:
endpoints = self.GetTotalRouteEndpoints()
request = HeartbeatRequest()
request.client_type = definition_pb2.PRODUCER
topic = Resource()
topic.name = "normal_topic"
invocations = {}
logger.info(len(endpoints))
# Collect task into a map.
for item in endpoints:
task = await self.client_manager.heartbeat(item, request, self.client_config.request_timeout)
invocations[item] = task
logger.info(task)
logger.info("finish")
break
except Exception as e:
logger.error(f"[Bug] unexpected exception raised during heartbeat, clientId={self.client_id}, Exception: {str(e)}")

def GetTotalRouteEndpoints(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use snake case naming rather than camel naming.

"""
Method that returns all route endpoints.
"""
endpoints = set()
for item in self.topic_route_cache.items():
for endpoint in [mq.broker.endpoints for mq in item[1].message_queues]:
endpoints.add(endpoint)
return endpoints

async def get_route_data(self, topic):
"""
Asynchronous method that fetches route data for a given topic.

:param topic: The topic to fetch route data for.
"""
if topic in self.topic_route_cache:
return self.topic_route_cache[topic]
topic_route_data = await self.fetch_topic_route(topic=topic)
return topic_route_data

def get_client_config(self):
"""
Method to return client configuration.
"""
return self.client_config

async def OnTopicRouteDataFetched(self, topic, topicRouteData):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the naming issue...

"""
Asynchronous method that handles the process once the topic route data is fetched.

:param topic: The topic for which the route data is fetched.
:param topicRouteData: The fetched topic route data.
"""
route_endpoints = set()
for mq in topicRouteData.message_queues:
route_endpoints.add(mq.broker.endpoints)
Expand All @@ -70,9 +159,13 @@ async def OnTopicRouteDataFetched(self, topic, topicRouteData):
await session.sync_settings(True)

self.topic_route_cache[topic] = topicRouteData
# self.OnTopicRouteDataUpdated0(topic, topicRouteData)

async def fetch_topic_route0(self, topic):
"""
Asynchronous method that fetches the topic route.

:param topic: The topic to fetch the route for.
"""
req = QueryRouteRequest()
req.topic.name = topic
address = req.endpoints.addresses.add()
Expand All @@ -84,13 +177,22 @@ async def fetch_topic_route0(self, topic):
message_queues = response.message_queues
return TopicRouteData(message_queues)

# return topic data
async def fetch_topic_route(self, topic):
"""
Asynchronous method that fetches the topic route and updates the data.

:param topic: The topic to fetch the route for.
"""
topic_route_data = await self.fetch_topic_route0(topic)
await self.OnTopicRouteDataFetched(topic, topic_route_data)
return topic_route_data

async def GetSession(self, endpoints):
"""
Asynchronous method that gets the session for a given endpoint.

:param endpoints: The endpoints to get the session for.
"""
self.sessionsLock.acquire()
try:
# Session exists, return in advance.
Expand All @@ -114,12 +216,29 @@ async def GetSession(self, endpoints):


class ClientManager:
"""Manager class for RPC Clients in a thread-safe manner.
Each instance is created by a specific client and can manage
multiple RPC clients.
"""

def __init__(self, client: Client):
#: The client that instantiated this manager.
self.__client = client

#: A dictionary that maps endpoints to the corresponding RPC clients.
self.__rpc_clients = {}

#: A lock used to ensure thread safety when accessing __rpc_clients.
self.__rpc_clients_lock = threading.Lock()

def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool):
"""Retrieve the RPC client corresponding to the given endpoints.
If not present, a new RPC client is created and stored in __rpc_clients.

:param endpoints: The endpoints associated with the RPC client.
:param ssl_enabled: A flag indicating whether SSL is enabled.
:return: The RPC client associated with the given endpoints.
"""
with self.__rpc_clients_lock:
rpc_client = self.__rpc_clients.get(endpoints)
if rpc_client:
Expand All @@ -134,10 +253,16 @@ async def query_route(
request: service_pb2.QueryRouteRequest,
timeout_seconds: int,
):
"""Query the routing information.

:param endpoints: The endpoints to query.
:param request: The request containing the details of the query.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the query.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)

metadata = Signature.sign(self.__client.client_config, self.__client.client_id)
return await rpc_client.query_route(request, metadata, timeout_seconds)

Expand All @@ -147,6 +272,13 @@ async def heartbeat(
request: service_pb2.HeartbeatRequest,
timeout_seconds: int,
):
"""Send a heartbeat to the server to indicate that the client is still alive.

:param endpoints: The endpoints to send the heartbeat to.
:param request: The request containing the details of the heartbeat.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the heartbeat.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -159,6 +291,13 @@ async def send_message(
request: service_pb2.SendMessageRequest,
timeout_seconds: int,
):
"""Send a message to the server.

:param endpoints: The endpoints to send the message to.
:param request: The request containing the details of the message.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the message sending operation.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -171,6 +310,13 @@ async def query_assignment(
request: service_pb2.QueryAssignmentRequest,
timeout_seconds: int,
):
"""Query the assignment information.

:param endpoints: The endpoints to query.
:param request: The request containing the details of the query.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the query.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -183,6 +329,13 @@ async def ack_message(
request: service_pb2.AckMessageRequest,
timeout_seconds: int,
):
"""Send an acknowledgment for a message to the server.

:param endpoints: The endpoints to send the acknowledgment to.
:param request: The request containing the details of the acknowledgment.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the acknowledgment.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -195,6 +348,13 @@ async def forward_message_to_dead_letter_queue(
request: service_pb2.ForwardMessageToDeadLetterQueueRequest,
timeout_seconds: int,
):
"""Forward a message to the dead letter queue.

:param endpoints: The endpoints to send the request to.
:param request: The request containing the details of the message to forward.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the forward operation.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -209,6 +369,13 @@ async def end_transaction(
request: service_pb2.EndTransactionRequest,
timeout_seconds: int,
):
"""Ends a transaction.

:param endpoints: The endpoints to send the request to.
:param request: The request to end the transaction.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the end transaction operation.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -221,6 +388,13 @@ async def notify_client_termination(
request: service_pb2.NotifyClientTerminationRequest,
timeout_seconds: int,
):
"""Notify server about client termination.

:param endpoints: The endpoints to send the notification to.
:param request: The request containing the details of the termination.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the notification operation.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -235,6 +409,13 @@ async def change_invisible_duration(
request: service_pb2.ChangeInvisibleDurationRequest,
timeout_seconds: int,
):
"""Change the invisible duration of a message.

:param endpoints: The endpoints to send the request to.
:param request: The request containing the new invisible duration.
:param timeout_seconds: The maximum time to wait for a response.
:return: The result of the change operation.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand All @@ -248,6 +429,12 @@ def telemetry(
endpoints: Endpoints,
timeout_seconds: int,
):
"""Fetch telemetry information.

:param endpoints: The endpoints to send the request to.
:param timeout_seconds: The maximum time to wait for a response.
:return: The telemetry information.
"""
rpc_client = self.__get_rpc_client(
endpoints, self.__client.client_config.ssl_enabled
)
Expand Down
Loading