Skip to content

Commit

Permalink
allow forward_fn to accept more arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
mjurbanski-reef committed Apr 21, 2024
1 parent a90c319 commit 55704e9
Showing 1 changed file with 24 additions and 60 deletions.
84 changes: 24 additions & 60 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,27 +466,19 @@ def verify_custom(synapse: MyCustomSynapse):
offered by this method allows developers to tailor the Axon's behavior to specific requirements and
use cases.
"""

# Assert 'forward_fn' has exactly one argument
forward_sig = signature(forward_fn)
assert (
len(list(forward_sig.parameters)) == 1
), "The passed function must have exactly one argument"

# Obtain the class of the first argument of 'forward_fn'
request_class = forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation
try:
first_param = next(iter(forward_sig.parameters.values()))
except StopIteration:
raise ValueError(
"The forward function first argument must be a subclass of bittensor.Synapse, but it has no arguments"
)

# Assert that the first argument of 'forward_fn' is a subclass of 'bittensor.Synapse'
param_class = first_param.annotation
assert issubclass(
request_class, bittensor.Synapse
), "The argument of forward_fn must inherit from bittensor.Synapse"

# Obtain the class name of the first argument of 'forward_fn'
request_name = forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation.__name__
param_class, bittensor.Synapse
), "The first argument of forward_fn must inherit from bittensor.Synapse"
request_name = param_class.__name__

# Add the endpoint to the router, making it available on both GET and POST methods
self.router.add_api_route(
Expand All @@ -497,68 +489,40 @@ def verify_custom(synapse: MyCustomSynapse):
)
self.app.include_router(self.router)

# Expected signatures for 'blacklist_fn', 'priority_fn' and 'verify_fn'
blacklist_sig = Signature(
[
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
],
return_annotation=Tuple[bool, str],
)
priority_sig = Signature(
[
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
],
return_annotation=float,
)
verify_sig = Signature(
[
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
],
return_annotation=None,
)

# Check the signature of blacklist_fn, priority_fn and verify_fn if they are provided
expected_params = [
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
]
if blacklist_fn:
blacklist_sig = Signature(expected_params, return_annotation=Tuple[bool, str])
assert (
signature(blacklist_fn) == blacklist_sig
), "The blacklist_fn function must have the signature: blacklist( synapse: {} ) -> Tuple[bool, str]".format(
request_name
)
if priority_fn:
priority_sig = Signature(expected_params, return_annotation=float)
assert (
signature(priority_fn) == priority_sig
), "The priority_fn function must have the signature: priority( synapse: {} ) -> float".format(
request_name
)
if verify_fn:
verify_sig = Signature(expected_params, return_annotation=None)
assert (
signature(verify_fn) == verify_sig
), "The verify_fn function must have the signature: verify( synapse: {} ) -> None".format(
request_name
)

# Store functions in appropriate attribute dictionaries
self.forward_class_types[request_name] = forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation
self.forward_class_types[request_name] = param_class
self.blacklist_fns[request_name] = blacklist_fn
self.priority_fns[request_name] = priority_fn
self.verify_fns[request_name] = (
Expand All @@ -567,7 +531,7 @@ def verify_custom(synapse: MyCustomSynapse):
self.forward_fns[request_name] = forward_fn

# Parse required hash fields from the forward function protocol defaults
required_hash_fields = request_class.__dict__["__fields__"][
required_hash_fields = param_class.__dict__["__fields__"][
"required_hash_fields"
].default
self.required_hash_fields[request_name] = required_hash_fields
Expand Down

0 comments on commit 55704e9

Please sign in to comment.