Skip to content

Commit

Permalink
🍺support interface dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
BlueGlassBlock committed Nov 14, 2022
1 parent 1419e4b commit 81d3f68
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
28 changes: 21 additions & 7 deletions src/graia/ariadne/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from graia.broadcast.entities.dispatcher import BaseDispatcher as AbstractDispatcher
from graia.broadcast.entities.signatures import Force
from graia.broadcast.interfaces.dispatcher import DispatcherInterface
from launart import ExportInterface, Launart

from .message.chain import MessageChain
from .message.element import Quote, Source
Expand Down Expand Up @@ -47,6 +48,23 @@ async def catch(interface: DispatcherInterface):
return Ariadne.current()


class LaunartInterfaceDispatcher(AbstractDispatcher):
@staticmethod
async def catch(interface: DispatcherInterface):
from graia.ariadne.typing import Unions, get_args, get_origin

if isinstance(interface.annotation, type) and issubclass(interface.annotation, ExportInterface):
manager = Launart.current()
with contextlib.suppress(ValueError):
return manager.get_interface(interface.annotation)
elif get_origin(interface.annotation) in Unions and (types := get_args(interface.annotation)):
manager = Launart.current()
for anno in types:
if isinstance(anno, type) and issubclass(anno, ExportInterface):
with contextlib.suppress(ValueError):
return manager.get_interface(anno)


class NoneDispatcher(AbstractDispatcher):
"""给 Optional[...] 提供 None 的 Dispatcher"""

Expand Down Expand Up @@ -79,7 +97,7 @@ async def catch(interface: DispatcherInterface):
if isinstance(interface.event, (MessageEvent, ActiveMessage)) and generic_issubclass(
Quote, interface.annotation
):
return interface.event.quote or await NoneDispatcher.catch(interface)
return interface.event.quote


class SenderDispatcher(AbstractDispatcher):
Expand Down Expand Up @@ -151,7 +169,7 @@ class OperatorDispatcher(AbstractDispatcher):
@staticmethod
async def catch(interface: DispatcherInterface):
if generic_issubclass(Member, interface.annotation):
return interface.event.operator or await NoneDispatcher.catch(interface)
return interface.event.operator
elif generic_issubclass(Group, interface.annotation):
# NOTE: operator 不为 None。因为 operator 可为 None 的事件必有 group 属性,
# 会由 dispatcher 之前的 GroupDispatcher 处理,不可能进入此处。
Expand All @@ -164,10 +182,6 @@ class OperatorMemberDispatcher(AbstractDispatcher):
@staticmethod
async def catch(interface: DispatcherInterface):
if generic_issubclass(Member, interface.annotation):
return (
interface.event.operator or await NoneDispatcher.catch(interface)
if interface.name == "operator"
else interface.event.member
)
return interface.event.operator if interface.name == "operator" else interface.event.member
elif generic_issubclass(Group, interface.annotation):
return interface.event.member.group
6 changes: 1 addition & 5 deletions src/graia/ariadne/event/mirai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
FriendDispatcher,
GroupDispatcher,
MemberDispatcher,
NoneDispatcher,
OperatorDispatcher,
OperatorMemberDispatcher,
)
Expand Down Expand Up @@ -646,10 +645,7 @@ class Dispatcher(AbstractDispatcher):
@staticmethod
async def catch(interface: DispatcherInterface["MemberJoinEvent"]):
if interface.name == "inviter" and generic_issubclass(Member, interface.annotation):
if inviter := interface.event.inviter:
return inviter
elif result := await NoneDispatcher.catch(interface):
return result
return interface.event.inviter


class MemberLeaveEventKick(GroupEvent):
Expand Down
9 changes: 3 additions & 6 deletions src/graia/ariadne/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
HttpClientConnection,
)
from .connection._info import HttpClientInfo, U_Info
from .dispatcher import ContextDispatcher, NoneDispatcher
from .dispatcher import ContextDispatcher, LaunartInterfaceDispatcher, NoneDispatcher
from .exception import AriadneConfigurationError

ARIADNE_ASCII_LOGO = r"""
Expand Down Expand Up @@ -119,6 +119,8 @@ def __init__(self) -> None:

if ContextDispatcher not in self.broadcast.prelude_dispatchers:
self.broadcast.prelude_dispatchers.append(ContextDispatcher)
if LaunartInterfaceDispatcher not in self.broadcast.prelude_dispatchers:
self.broadcast.prelude_dispatchers.append(LaunartInterfaceDispatcher)
if NoneDispatcher not in self.broadcast.finale_dispatchers:
self.broadcast.finale_dispatchers.append(NoneDispatcher)

Expand Down Expand Up @@ -214,11 +216,6 @@ async def launch(self, mgr: Launart):
self.base_telemetry()
async with self.stage("preparing"):
self.http_interface = mgr.get_interface(AiohttpClientInterface)
if self.broadcast:
if asyncio.get_running_loop() is not self.loop:
raise AriadneConfigurationError("Broadcast is attached to a different loop")
else:
self.broadcast = Broadcast(loop=self.loop)
if "default_account" in Ariadne.options:
app = Ariadne.current()
with enter_context(app=app):
Expand Down
5 changes: 3 additions & 2 deletions src/test_old/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from graia.saya.context import channel_instance
from loguru import logger

from graia.ariadne.connection import ConnectionInterface
from graia.ariadne.entry import *
from graia.ariadne.message.exp import MessageChain as ExpMessageChain
from graia.ariadne.message.parser.base import RegexGroup, StartsWith
Expand Down Expand Up @@ -200,10 +201,10 @@ async def regex(app: Ariadne, chain: Annotated[MessageChain, RegexGroup("args")]
await app.send_friend_message(target, chain)

@bcc.receiver(GroupMessage, decorators=[StartsWith(".test exp")])
async def exp(app: Ariadne, ev: GroupMessage, exp_c: ExpMessageChain):
async def exp(app: Ariadne, ev: GroupMessage, exp_c: ExpMessageChain, interf: ConnectionInterface):
await app.send_message(ev, repr(exp_c.content))
res = await app.send_message(ev, repr(ev))
await app.send_message(ev, repr(res))
await app.send_message(ev, [repr(res), repr(interf)])

@bcc.receiver(ApplicationLaunch)
async def m(app: Ariadne):
Expand Down

0 comments on commit 81d3f68

Please sign in to comment.