diff --git a/bittensor/axon.py b/bittensor/axon.py index 34ce9e51f1..be7711c611 100644 --- a/bittensor/axon.py +++ b/bittensor/axon.py @@ -466,27 +466,19 @@ def verify_custom(synapse: MyCustomSynapse): offered by this method allows developers to tailor the Axon's behavior to specific requirements and use cases. """ - - # Assert 'forward_fn' has exactly one argument forward_sig = signature(forward_fn) - assert ( - len(list(forward_sig.parameters)) == 1 - ), "The passed function must have exactly one argument" - - # Obtain the class of the first argument of 'forward_fn' - request_class = forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation + try: + first_param = next(iter(forward_sig.parameters.values())) + except StopIteration: + raise ValueError( + "The forward_fn first argument must be a subclass of bittensor.Synapse, but it has no arguments" + ) - # Assert that the first argument of 'forward_fn' is a subclass of 'bittensor.Synapse' + param_class = first_param.annotation assert issubclass( - request_class, bittensor.Synapse - ), "The argument of forward_fn must inherit from bittensor.Synapse" - - # Obtain the class name of the first argument of 'forward_fn' - request_name = forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation.__name__ + param_class, bittensor.Synapse + ), "The first argument of forward_fn must inherit from bittensor.Synapse" + request_name = param_class.__name__ # Add the endpoint to the router, making it available on both GET and POST methods self.router.add_api_route( @@ -497,58 +489,34 @@ def verify_custom(synapse: MyCustomSynapse): ) self.app.include_router(self.router) - # Expected signatures for 'blacklist_fn', 'priority_fn' and 'verify_fn' - blacklist_sig = Signature( - [ - Parameter( - "synapse", - Parameter.POSITIONAL_OR_KEYWORD, - annotation=forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation, - ) - ], - return_annotation=Tuple[bool, str], - ) - priority_sig = Signature( - [ - Parameter( - "synapse", - Parameter.POSITIONAL_OR_KEYWORD, - annotation=forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation, - ) - ], - return_annotation=float, - ) - verify_sig = Signature( - [ - Parameter( - "synapse", - Parameter.POSITIONAL_OR_KEYWORD, - annotation=forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation, - ) - ], - return_annotation=None, - ) - # Check the signature of blacklist_fn, priority_fn and verify_fn if they are provided + expected_params = [ + Parameter( + "synapse", + Parameter.POSITIONAL_OR_KEYWORD, + annotation=forward_sig.parameters[ + list(forward_sig.parameters)[0] + ].annotation, + ) + ] if blacklist_fn: + blacklist_sig = Signature( + expected_params, return_annotation=Tuple[bool, str] + ) assert ( signature(blacklist_fn) == blacklist_sig ), "The blacklist_fn function must have the signature: blacklist( synapse: {} ) -> Tuple[bool, str]".format( request_name ) if priority_fn: + priority_sig = Signature(expected_params, return_annotation=float) assert ( signature(priority_fn) == priority_sig ), "The priority_fn function must have the signature: priority( synapse: {} ) -> float".format( request_name ) if verify_fn: + verify_sig = Signature(expected_params, return_annotation=None) assert ( signature(verify_fn) == verify_sig ), "The verify_fn function must have the signature: verify( synapse: {} ) -> None".format( @@ -556,9 +524,7 @@ def verify_custom(synapse: MyCustomSynapse): ) # Store functions in appropriate attribute dictionaries - self.forward_class_types[request_name] = forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation + self.forward_class_types[request_name] = param_class self.blacklist_fns[request_name] = blacklist_fn self.priority_fns[request_name] = priority_fn self.verify_fns[request_name] = ( @@ -567,7 +533,7 @@ def verify_custom(synapse: MyCustomSynapse): self.forward_fns[request_name] = forward_fn # Parse required hash fields from the forward function protocol defaults - required_hash_fields = request_class.__dict__["__fields__"][ + required_hash_fields = param_class.__dict__["__fields__"][ "required_hash_fields" ].default self.required_hash_fields[request_name] = required_hash_fields