diff --git a/opentaxii/auth/sqldb/api.py b/opentaxii/auth/sqldb/api.py index 82904fa3..29808341 100644 --- a/opentaxii/auth/sqldb/api.py +++ b/opentaxii/auth/sqldb/api.py @@ -47,7 +47,7 @@ def __init__( def authenticate(self, username, password): try: - account = Account.query.filter_by(username=username).one() + account = self.db.session.query(Account).filter_by(username=username).one() except exc.NoResultFound: return if not account.is_password_valid(password): @@ -55,7 +55,7 @@ def authenticate(self, username, password): return self._generate_token(account.id, ttl=self.token_ttl_secs) def create_account(self, username, password, is_admin=False): - account = Account(username=username, is_admin=is_admin) + account = Account(username=username, is_admin=is_admin, permissions={}) account.set_password(password) self.db.session.add(account) self.db.session.commit() @@ -65,13 +65,13 @@ def get_account(self, token): account_id = self._get_account_id(token) if not account_id: return - account = Account.query.get(account_id) + account = self.db.session.query(Account).get(account_id) if not account: return return account_to_account_entity(account) def delete_account(self, username): - account = Account.query.filter_by(username=username).one_or_none() + account = self.db.session.query(Account).filter_by(username=username).one_or_none() if account: self.db.session.delete(account) self.db.session.commit() @@ -79,10 +79,10 @@ def delete_account(self, username): def get_accounts(self): return [ account_to_account_entity(account) - for account in Account.query.all()] + for account in self.db.session.query(Account).all()] def update_account(self, obj, password=None): - account = Account.query.filter_by(username=obj.username).one_or_none() + account = self.db.session.query(Account).filter_by(username=obj.username).one_or_none() if not account: account = Account(username=obj.username) self.db.session.add(account) diff --git a/opentaxii/cli/auth.py b/opentaxii/cli/auth.py index 513d793d..428b81fe 100644 --- a/opentaxii/cli/auth.py +++ b/opentaxii/cli/auth.py @@ -1,6 +1,5 @@ -import sys - import argparse +import sys from opentaxii.cli import app @@ -61,7 +60,7 @@ def update_account(argv=None): return if args.field == 'admin': account.is_admin = is_truely(args.value) - app.taxii_server.auth.update_account(account) + account = app.taxii_server.auth.update_account(account, None) if account.is_admin: print('now user is admin') else: diff --git a/opentaxii/cli/persistence.py b/opentaxii/cli/persistence.py index bbe436e1..8cfeaaae 100644 --- a/opentaxii/cli/persistence.py +++ b/opentaxii/cli/persistence.py @@ -1,30 +1,34 @@ import argparse + import structlog import yaml - -from opentaxii.entities import Account from opentaxii.cli import app +from opentaxii.entities import Account from opentaxii.local import context from opentaxii.utils import sync_conf_dict_into_db log = structlog.getLogger(__name__) -local_admin = Account( - id=None, username="local-admin", permissions=None, is_admin=True) +local_admin = Account(id=None, username="local-admin", permissions=None, is_admin=True) def sync_data_configuration(): parser = argparse.ArgumentParser( description="Create services/collections/accounts", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument( - "config", help="YAML file with data configuration") + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("config", help="YAML file with data configuration") parser.add_argument( - "-f", "--force-delete", 
dest="force_deletion", + "-f", + "--force-delete", + dest="force_deletion", action="store_true", - help=("force deletion of collections and their content blocks " - "if collection is not defined in configuration file"), - required=False) + help=( + "force deletion of collections and their content blocks " + "if collection is not defined in configuration file" + ), + required=False, + ) args = parser.parse_args() with open(args.config) as stream: config = yaml.safe_load(stream=stream) @@ -33,9 +37,8 @@ def sync_data_configuration(): # run as admin with full access context.account = local_admin sync_conf_dict_into_db( - app.taxii_server, - config, - force_collection_deletion=args.force_deletion) + app.taxii_server, config, force_collection_deletion=args.force_deletion + ) def delete_content_blocks(): @@ -43,31 +46,106 @@ def delete_content_blocks(): parser = argparse.ArgumentParser( description=( "Delete content blocks from specified collections " - "with timestamp labels matching defined time window"), - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + "with timestamp labels matching defined time window" + ), + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( - "-c", "--collection", action="append", dest="collection", - help="Collection to remove content blocks from", required=True) + "-c", + "--collection", + action="append", + dest="collection", + help="Collection to remove content blocks from", + required=True, + ) parser.add_argument( - "-m", "--with-messages", dest="delete_inbox_messages", + "-m", + "--with-messages", + dest="delete_inbox_messages", action="store_true", help=("delete inbox messages associated with deleted content blocks"), - required=False) + required=False, + ) parser.add_argument( - "--begin", dest="begin", + "--begin", + dest="begin", help="exclusive beginning of time window as ISO8601 formatted date", - required=True) + required=True, + ) parser.add_argument( - "--end", dest="end", - help="inclusive ending of time window as ISO8601 formatted date") + "--end", + dest="end", + help="inclusive ending of time window as ISO8601 formatted date", + ) args = parser.parse_args() with app.app_context(): start_time = args.begin end_time = args.end for collection in args.collection: - app.taxii_server.persistence.delete_content_blocks( + app.taxii_server.servers.taxii1.persistence.delete_content_blocks( collection, with_messages=args.delete_inbox_messages, start_time=start_time, - end_time=end_time) + end_time=end_time, + ) + + +def add_api_root(): + """CLI command to add taxii2 api root to database.""" + parser = argparse.ArgumentParser( + description=("Add a new taxii2 ApiRoot object."), + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-t", "--title", required=True, help="Title of the api root") + parser.add_argument( + "-d", "--description", required=False, help="Description of the api root" + ) + parser.add_argument( + "--default", action="store_true", help="Set as default api root" + ) + + args = parser.parse_args() + with app.app_context(): + app.taxii_server.servers.taxii2.persistence.api.add_api_root( + title=args.title, description=args.description, default=args.default + ) + + +def add_collection(): + """CLI command to add taxii2 collection to database.""" + existing_api_root_ids = [ + str(api_root.id) + for api_root in app.taxii_server.servers.taxii2.persistence.api.get_api_roots() + ] + parser = argparse.ArgumentParser( + description=("Add a new taxii2 Collection object."), + 
formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-r", + "--rootid", + choices=existing_api_root_ids, + required=True, + help="Api root id of the collection", + ) + parser.add_argument("-t", "--title", required=True, help="Title of the collection") + parser.add_argument( + "-d", "--description", required=False, help="Description of the collection" + ) + parser.add_argument("-a", "--alias", required=False, help="alias of the collection") + + args = parser.parse_args() + with app.app_context(): + app.taxii_server.servers.taxii2.persistence.api.add_collection( + api_root_id=args.rootid, + title=args.title, + description=args.description, + alias=args.alias, + ) + + +def job_cleanup(): + """CLI command to clean up taxii2 job logs that are >24h old.""" + number_removed = app.taxii_server.servers.taxii2.persistence.api.job_cleanup() + print(f"{number_removed} removed") diff --git a/opentaxii/common/entities.py b/opentaxii/common/entities.py index 4976d419..2d3a5e4a 100644 --- a/opentaxii/common/entities.py +++ b/opentaxii/common/entities.py @@ -1,8 +1,29 @@ +def sorted_dicts(obj): + """ + sort all dicts contained in obj, for repeatable repr + """ + if isinstance(obj, dict): + response = {} + for key, value in sorted(obj.items()): + value = sorted_dicts(value) + response[key] = value + elif isinstance(obj, (list, tuple)): + response = type(obj)(sorted_dicts(item) for item in obj) + else: + response = obj + return response + class Entity: '''Abstract TAXII entity class. ''' def __repr__(self): - pairs = ["%s=%s" % (k, v) for k, v in sorted(self.__dict__.items())] + pairs = ["%s=%s" % (k, v) for k, v in sorted(sorted_dicts(self.__dict__).items())] return "%s(%s)" % (self.__class__.__name__, ", ".join(pairs)) + + def to_dict(self): + return {key: value for key, value in self.__dict__.items()} + + def __eq__(self, other): + return repr(self) == repr(other) diff --git a/opentaxii/config.py b/opentaxii/config.py index 8621bbb5..b721aa2e 100644 --- a/opentaxii/config.py +++ b/opentaxii/config.py @@ -57,7 +57,12 @@ class ServerConfig(dict): "unauthorized_status", "hooks", ) - VALID_TAXII2_OPTIONS = ("max_content_length",) + VALID_TAXII2_OPTIONS = ( + "contact", + "description", + "max_content_length", + "title", + ) ALL_VALID_OPTIONS = VALID_BASE_OPTIONS + VALID_TAXII_OPTIONS + VALID_TAXII1_OPTIONS def __init__(self, optional_env_var=CONFIG_ENV_VAR, extra_configs=None): diff --git a/opentaxii/middleware.py b/opentaxii/middleware.py index 443f2ca0..02f8ebb1 100644 --- a/opentaxii/middleware.py +++ b/opentaxii/middleware.py @@ -2,6 +2,9 @@ import structlog from flask import Flask, request +from marshmallow.exceptions import \ + ValidationError as MarshmallowValidationError +from werkzeug.exceptions import HTTPException from .exceptions import InvalidAuthHeader from .local import context, release_context @@ -31,13 +34,15 @@ def create_app(server): "/", "opentaxii_services_view", server.handle_request, - methods=["POST", "OPTIONS"], + methods=["GET", "POST", "OPTIONS", "DELETE"], ) app.register_blueprint(management, url_prefix="/management") app.register_error_handler(500, server.handle_internal_error) app.register_error_handler(StatusMessageException, server.handle_status_exception) + app.register_error_handler(HTTPException, server.handle_http_exception) + app.register_error_handler(MarshmallowValidationError, server.handle_validation_exception) app.before_request(functools.partial(create_context_before_request, server)) app.after_request(cleanup_context) return app diff 
--git a/opentaxii/persistence/api.py b/opentaxii/persistence/api.py index d1ede28f..84bd33ba 100644 --- a/opentaxii/persistence/api.py +++ b/opentaxii/persistence/api.py @@ -1,16 +1,23 @@ +import datetime +from typing import Dict, List, Optional, Tuple + +from opentaxii.taxii2.entities import (ApiRoot, Collection, Job, JobDetail, + ManifestRecord, STIXObject, + VersionRecord) + class OpenTAXIIPersistenceAPI: - '''Abstract class that represents OpenTAXII Persistence API. + """Abstract class that represents OpenTAXII Persistence API. This class defines required methods that need to exist in a specific Persistence API implementation. - ''' + """ def init_app(self, app): pass def create_service(self, service_entity): - '''Create a service. + """Create a service. NOTE: Additional data management method that is not used in TAXII server logic but only in helper scripts. @@ -19,11 +26,11 @@ def create_service(self, service_entity): service entity in question :return: updated service entity, with ID field not None :rtype: :py:class:`opentaxii.taxii.entities.ServiceEntity` - ''' + """ raise NotImplementedError() def create_collection(self, collection_entity): - '''Create a collection. + """Create a collection. NOTE: Additional data management method that is not used in TAXII server logic but only in helper scripts. @@ -32,11 +39,11 @@ def create_collection(self, collection_entity): collection entity in question :return: updated collection entity, with ID field not None :rtype: :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ raise NotImplementedError() def attach_collection_to_services(self, collection_id, service_ids): - '''Attach collection to the services. + """Attach collection to the services. NOTE: Additional data management method that is not used in TAXII server logic but only in helper scripts. @@ -45,33 +52,33 @@ def attach_collection_to_services(self, collection_id, service_ids): :param list service_ids: collection entity in question :return: updated collection entity, with ID field not None :rtype: :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ raise NotImplementedError() def get_services(self, collection_id=None): - '''Get configured services. + """Get configured services. :param str collection_id: get only services assigned to collection with provided ID :return: list of service entities. :rtype: list of :py:class:`opentaxii.taxii.entities.ServiceEntity` - ''' + """ raise NotImplementedError() def get_collections(self, service_id=None): - '''Get the collections. If `service_id` is provided, return collection + """Get the collections. If `service_id` is provided, return collection attached to a service. :param str service_id: ID of a service in question :return: list of collection entities. :rtype: list of :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ raise NotImplementedError() def get_collection(self, collection_name, service_id=None): - '''Get a collection by name and service ID. + """Get a collection by name and service ID. Collection name is unique globally, so can be used as a key. Method retrieves collection entity using collection name @@ -82,39 +89,40 @@ def get_collection(self, collection_name, service_id=None): :return: collection entity :rtype: :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ raise NotImplementedError() def update_collection(self, collection_entity): - '''Update collection. + """Update collection. 
:param `opentaxii.taxii.entities.CollectionEntity` collection_entity: collection entity object :return: updated collection entity :rtype: :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ raise NotImplementedError() def delete_collection(self, collection_name): - '''Delete collection. + """Delete collection. :param int collection_id: id of a collection object - ''' + """ pass def create_inbox_message(self, inbox_message_entity): - '''Create an inbox message. + """Create an inbox message. :param `opentaxii.taxii.entities.InboxMessageEntity` \ inbox_message_entity: inbox message entity in question :return: updated inbox message entity :rtype: :py:class:`opentaxii.taxii.entities.InboxMessageEntity` - ''' + """ raise NotImplementedError() - def create_content_block(self, content_block_entity, collection_ids=None, - service_id=None): - '''Create a content block. + def create_content_block( + self, content_block_entity, collection_ids=None, service_id=None + ): + """Create a content block. :param `opentaxii.taxii.entities.ContentBlockEntity` \ content_block_entity: content block in question @@ -124,12 +132,13 @@ def create_content_block(self, content_block_entity, collection_ids=None, :return: updated content block entity :rtype: :py:class:`opentaxii.taxii.entities.ContentBlockEntity` - ''' + """ raise NotImplementedError() - def get_content_blocks_count(self, collection_id, start_time=None, - end_time=None, bindings=None): - '''Get a count of the content blocks associated with a collection. + def get_content_blocks_count( + self, collection_id, start_time=None, end_time=None, bindings=None + ): + """Get a count of the content blocks associated with a collection. :param str collection_id: ID fo a collection in question :param datetime start_time: start of a time frame @@ -139,12 +148,19 @@ def get_content_blocks_count(self, collection_id, start_time=None, :return: content block count :rtype: int - ''' + """ raise NotImplementedError() - def get_content_blocks(self, collection_id, start_time=None, end_time=None, - bindings=None, offset=0, limit=10): - '''Get the content blocks associated with a collection. + def get_content_blocks( + self, + collection_id, + start_time=None, + end_time=None, + bindings=None, + offset=0, + limit=10, + ): + """Get the content blocks associated with a collection. :param str collection_id: ID fo a collection in question :param datetime start_time: start of a time frame @@ -156,91 +172,92 @@ def get_content_blocks(self, collection_id, start_time=None, end_time=None, :return: content blocks list :rtype: list of :py:class:`opentaxii.taxii.entities.ContentBlockEntity` - ''' + """ raise NotImplementedError() def create_result_set(self, result_set_entity): - '''Create a result set. + """Create a result set. :param `opentaxii.taxii.entities.ResultSetEntity` result_set_entity: result set entity in question :return: updated result set entity :rtype: :py:class:`opentaxii.taxii.entities.ResultSetEntity` - ''' + """ raise NotImplementedError() def get_result_set(self, result_set_id): - '''Get a result set entity by ID. + """Get a result set entity by ID. :param str result_set_id: ID of a result set. :return: result set entity :rtype: :py:class:`opentaxii.taxii.entities.ResultSetEntity` - ''' + """ raise NotImplementedError() def create_subscription(self, subscription_entity): - '''Create a subscription. + """Create a subscription. :param `opentaxii.taxii.entities.SubscriptionEntity` \ subscription_entity: subscription entity in question. 
:return: updated subscription entity :rtype: :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ raise NotImplementedError() def get_subscription(self, subscription_id): - '''Get a subscription entity by ID. + """Get a subscription entity by ID. :param str subscription_id: ID of a subscription :return: subscription entity :rtype: :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ raise NotImplementedError() def get_subscriptions(self, service_id): - '''Get the subscriptions attached to/created via a service. + """Get the subscriptions attached to/created via a service. :param str service_id: ID of a service :return: list of subscription entities :rtype: list of :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ raise NotImplementedError() def update_subscription(self, subscription_entity): - '''Update a subscription status. + """Update a subscription status. :param `opentaxii.taxii.entities.SubscriptionEntity` \ subscription_entity: subscription entity in question :return: updated subscription entity :rtype: :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ raise NotImplementedError() def get_domain(self, service_id): - '''Get configured domain needed to create absolute URLs. + """Get configured domain needed to create absolute URLs. Returns `None` by default. :param str service_id: ID of a service - ''' + """ return None - def delete_content_blocks(self, collection_name, start_time, - end_time=None, with_messages=False): - '''Delete content blocks in a specified collection with + def delete_content_blocks( + self, collection_name, start_time, end_time=None, with_messages=False + ): + """Delete content blocks in a specified collection with timestamp label in a specified time frame. :param str collection_name: collection name :param datetime start_time: exclusive beginning of a timeframe :param datetime end_time: inclusive end of a timeframe :param bool with_messages: delete related inbox messages - ''' + """ pass @@ -250,4 +267,88 @@ class OpenTAXII2PersistenceAPI: Stub, pending implementation. 
""" - pass + + def get_api_roots(self) -> List[ApiRoot]: + raise NotImplementedError + + def get_api_root(self, api_root_id: str) -> Optional[ApiRoot]: + raise NotImplementedError + + def get_job_and_details( + self, api_root_id: str, job_id: str + ) -> Tuple[Optional[Job], List[JobDetail]]: + raise NotImplementedError + + def get_collections(self, api_root_id: str) -> List[Collection]: + raise NotImplementedError + + def get_collection( + self, api_root_id: str, collection_id_or_alias: str + ) -> Optional[Collection]: + raise NotImplementedError + + def get_manifest( + self, + collection_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[ManifestRecord], bool]: + raise NotImplementedError + + def get_objects( + self, + collection_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[STIXObject], bool]: + raise NotImplementedError + + def add_objects(self, api_root_id: str, collection_id: str, objects: List[Dict]) -> Tuple[Job, List[JobDetail]]: + raise NotImplementedError + + def get_object( + self, + collection_id: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[Optional[List[STIXObject]], bool]: + """ + Get all versions of single object from database. + + Should return `None` when object matching object_id doesn't exist. 
+ """ + raise NotImplementedError + + def delete_object( + self, + collection_id: str, + object_id: str, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> None: + raise NotImplementedError + + def get_versions( + self, + collection_id: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[VersionRecord], bool]: + raise NotImplementedError diff --git a/opentaxii/persistence/exceptions.py b/opentaxii/persistence/exceptions.py index a0ca4c89..55394fbf 100644 --- a/opentaxii/persistence/exceptions.py +++ b/opentaxii/persistence/exceptions.py @@ -1,3 +1,19 @@ class ResultsNotReady(Exception): pass + + +class DoesNotExist(Exception): + pass + + +class NoReadPermission(Exception): + pass + + +class NoWritePermission(Exception): + pass + + +class NoReadNoWritePermission(NoReadPermission, NoWritePermission): + pass diff --git a/opentaxii/persistence/manager.py b/opentaxii/persistence/manager.py index dcb9e4f4..e29705a3 100644 --- a/opentaxii/persistence/manager.py +++ b/opentaxii/persistence/manager.py @@ -1,7 +1,17 @@ +import datetime +from typing import Dict, List, NamedTuple, Optional, Tuple + import structlog from opentaxii.local import context +from opentaxii.persistence.exceptions import (DoesNotExist, + NoReadNoWritePermission, + NoReadPermission, + NoWritePermission) from opentaxii.signals import (CONTENT_BLOCK_CREATED, INBOX_MESSAGE_CREATED, SUBSCRIPTION_CREATED) +from opentaxii.taxii2.entities import (ApiRoot, Collection, Job, JobDetail, + ManifestRecord, STIXObject, + VersionRecord) log = structlog.getLogger(__name__) @@ -11,120 +21,120 @@ class BasePersistenceManager: class Taxii1PersistenceManager(BasePersistenceManager): - '''Manager responsible for persisting and retrieving data. + """Manager responsible for persisting and retrieving data. Manager uses API instance ``api`` for basic data CRUD operations and provides additional logic on top. :param `opentaxii.persistence.api.OpenTAXIIPersistenceAPI` api: instance of persistence API class - ''' + """ def __init__(self, server, api): self.server = server self.api = api def create_service(self, service_entity): - '''Create service. + """Create service. :param `opentaxii.taxii.entities.ServiceEntity` service_entity: service entity object :return: created collection entity :rtype: :py:class:`opentaxii.taxii.entities.ServiceEntity` - ''' + """ return self.api.create_service(service_entity) def update_service(self, service_entity): - '''Update service. + """Update service. :param `opentaxii.taxii.entities.ServiceEntity` service_entity: service entity object :return: created collection entity :rtype: :py:class:`opentaxii.taxii.entities.ServiceEntity` - ''' + """ return self.api.update_service(service_entity) def delete_service(self, service_id): - '''Delete service. + """Delete service. :param `opentaxii.taxii.entities.ServiceEntity` service_entity: service entity object - ''' + """ return self.api.delete_service(service_id) def delete_collection(self, collection_name): - '''Delete cllection. + """Delete cllection. :param str collection_name: name of a collection to delete - ''' + """ return self.api.delete_collection(collection_name) def set_collection_services(self, collection_id, service_ids): - '''Set collection's services. + """Set collection's services. 
NOTE: Additional method that is only used in the helper scripts shipped with OpenTAXII. - ''' - return self.api.set_collection_services( - collection_id, service_ids) + """ + return self.api.set_collection_services(collection_id, service_ids) def create_collection(self, entity): - '''Create a collection. + """Create a collection. :param `opentaxii.taxii.entities.CollectionEntity` collection_entity: collection entity object :return: created collection entity :rtype: :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ collection = self.api.create_collection(entity) return collection def get_services(self): - '''Get configured services. + """Get configured services. Methods loads services entities via persistence API. :return: list of service entities. :rtype: list of :py:class:`opentaxii.taxii.entities.ServiceEntity` - ''' + """ return self.api.get_services() def get_services_for_collection(self, collection): - '''Get the services associated with a collection. + """Get the services associated with a collection. :param `opentaxii.taxii.entities.CollectionEntity` collection: collection entity in question :return: list of service entities. :rtype: list of :py:class:`opentaxii.taxii.entities.ServiceEntity` - ''' + """ if context.account.can_read(collection.name): services = self.api.get_services(collection_id=collection.id) if context.account.can_modify(collection.name): return services else: - return list(filter(lambda s: s.type != 'inbox', services)) + return list(filter(lambda s: s.type != "inbox", services)) def get_collections(self, service_id=None): - '''Get the collections. If `service_id` is provided, return collection + """Get the collections. If `service_id` is provided, return collection attached to a service. :param str service_id: ID of the service in question :return: list of collection entities. :rtype: list of :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ collections = [ collection for collection in self.api.get_collections(service_id=service_id) - if context.account.can_read(collection.name)] + if context.account.can_read(collection.name) + ] return collections def get_collection(self, name, service_id=None): - '''Get a collection by name and service ID. + """Get a collection by name and service ID. Collection name is unique globally, so can be used as a key. Method retrieves collection entity using collection name @@ -135,24 +145,24 @@ def get_collection(self, name, service_id=None): :return: collection entity :rtype: :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ collection = self.api.get_collection(name, service_id=service_id) if collection and context.account.can_read(collection.name): return collection def update_collection(self, collection): - '''Update a collection + """Update a collection :param `opentaxii.taxii.entities.CollectionEntity` collection_entity: collection entity object :return: updated collection entity :rtype: :py:class:`opentaxii.taxii.entities.CollectionEntity` - ''' + """ return self.api.update_collection(collection) def create_inbox_message(self, entity): - '''Create an inbox message. + """Create an inbox message. Methods emits :py:const:`opentaxii.signals.INBOX_MESSAGE_CREATED` signal. 
@@ -161,17 +171,18 @@ def create_inbox_message(self, entity): inbox message entity in question :return: updated inbox message entity :rtype: :py:class:`opentaxii.taxii.entities.InboxMessageEntity` - ''' + """ - if self.server.config['save_raw_inbox_messages']: + if self.server.config["save_raw_inbox_messages"]: entity = self.api.create_inbox_message(entity) INBOX_MESSAGE_CREATED.send(self, inbox_message=entity) return entity - def create_content(self, content, service_id=None, inbox_message_id=None, - collections=None): - '''Create a content block. + def create_content( + self, content, service_id=None, inbox_message_id=None, collections=None + ): + """Create a content block. Methods emits :py:const:`opentaxii.signals.CONTENT_BLOCK_CREATED` signal. @@ -186,7 +197,7 @@ def create_content(self, content, service_id=None, inbox_message_id=None, :py:class:`opentaxii.taxii.entities.CollectionEntity` :return: updated content block entity :rtype: :py:class:`opentaxii.taxii.entities.ContentBlockEntity` - ''' + """ if inbox_message_id: content.inbox_message_id = inbox_message_id @@ -194,25 +205,32 @@ def create_content(self, content, service_id=None, inbox_message_id=None, collection_ids = [ collection.id for collection in collections - if context.account.can_modify(collection.name)] + if context.account.can_modify(collection.name) + ] if collection_ids: content = self.api.create_content_block( - content, collection_ids=collection_ids, service_id=service_id) + content, collection_ids=collection_ids, service_id=service_id + ) CONTENT_BLOCK_CREATED.send( - self, content_block=content, - collection_ids=collection_ids, service_id=service_id) + self, + content_block=content, + collection_ids=collection_ids, + service_id=service_id, + ) else: log.warning( "create_content.unknown_collections", collections=[c.name for c in collections], - user=context.account) + user=context.account, + ) return content - def get_content_blocks_count(self, collection_id, start_time=None, - end_time=None, bindings=None): - '''Get a count of the content blocks associated with a collection. + def get_content_blocks_count( + self, collection_id, start_time=None, end_time=None, bindings=None + ): + """Get a count of the content blocks associated with a collection. :param str collection_id: ID fo a collection in question :param datetime start_time: start of a time frame @@ -222,16 +240,24 @@ def get_content_blocks_count(self, collection_id, start_time=None, :return: content block count :rtype: int - ''' + """ return self.api.get_content_blocks_count( collection_id=collection_id, start_time=start_time, end_time=end_time, - bindings=bindings or []) - - def get_content_blocks(self, collection_id, start_time=None, end_time=None, - bindings=None, offset=0, limit=None): - '''Get the content blocks associated with a collection. + bindings=bindings or [], + ) + + def get_content_blocks( + self, + collection_id, + start_time=None, + end_time=None, + bindings=None, + offset=0, + limit=None, + ): + """Get the content blocks associated with a collection. 
:param str collection_id: ID fo a collection in question :param datetime start_time: start of a time frame @@ -243,7 +269,7 @@ def get_content_blocks(self, collection_id, start_time=None, end_time=None, :return: content blocks list :rtype: list of :py:class:`opentaxii.taxii.entities.ContentBlockEntity` - ''' + """ return self.api.get_content_blocks( collection_id=collection_id, @@ -251,31 +277,32 @@ def get_content_blocks(self, collection_id, start_time=None, end_time=None, end_time=end_time, bindings=bindings or [], offset=offset, - limit=limit) + limit=limit, + ) def create_result_set(self, entity): - '''Create a result set. + """Create a result set. :param `opentaxii.taxii.entities.ResultSetEntity` entity: result set entity in question :return: updated result set entity :rtype: :py:class:`opentaxii.taxii.entities.ResultSetEntity` - ''' + """ return self.api.create_result_set(entity) def get_result_set(self, result_set_id): - '''Get a result set entity by ID. + """Get a result set entity by ID. :param str result_set_id: ID of a result set. :return: result set entity :rtype: :py:class:`opentaxii.taxii.entities.ResultSetEntity` - ''' + """ return self.api.get_result_set(result_set_id) def create_subscription(self, entity): - '''Create a subscription. + """Create a subscription. Methods emits :py:const:`opentaxii.signals.SUBSCRIPTION_CREATED` signal. @@ -285,7 +312,7 @@ def create_subscription(self, entity): :return: updated subscription entity :rtype: :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ created = self.api.create_subscription(entity) @@ -294,47 +321,47 @@ def create_subscription(self, entity): return created def get_subscription(self, subscription_id): - '''Get a subscription entity by ID. + """Get a subscription entity by ID. :param str subscription_id: ID of a subscription :return: subscription entity :rtype: :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ return self.api.get_subscription(subscription_id) def get_subscriptions(self, service_id): - '''Get the subscriptions attached to/created via a service. + """Get the subscriptions attached to/created via a service. :param str service_id: ID of a service :return: list of subscription entities :rtype: list of :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ return self.api.get_subscriptions(service_id=service_id) def update_subscription(self, subscription): - '''Update a subscription status. + """Update a subscription status. :param `opentaxii.taxii.entities.SubscriptionEntity` subscription: subscription entity in question :return: updated subscription entity :rtype: :py:class:`opentaxii.taxii.entities.SubscriptionEntity` - ''' + """ return self.api.update_subscription(subscription) def get_domain(self, service_id): - '''Get configured domain name needed to create absolute URLs. + """Get configured domain name needed to create absolute URLs. :param str service_id: ID of a service - ''' + """ return self.api.get_domain(service_id) def delete_content_blocks( - self, collection_name, start_time, end_time=None, - with_messages=False): - '''Delete content blocks in a specified collection with + self, collection_name, start_time, end_time=None, with_messages=False + ): + """Delete content blocks in a specified collection with timestamp label in a specified time frame. 
:param str collection_name: collection name @@ -344,28 +371,253 @@ def delete_content_blocks( :return: the count of rows deleted :rtype: int - ''' + """ count = self.api.delete_content_blocks( - collection_name, start_time, end_time=end_time, - with_messages=with_messages) + collection_name, start_time, end_time=end_time, with_messages=with_messages + ) log.info( "collection.content_blocks.deleted", with_messages=with_messages, collection=collection_name, - count=count) + count=count, + ) return count +class JobDetailsResponse(NamedTuple): + total_count: int + success: List[JobDetail] + failure: List[JobDetail] + pending: List[JobDetail] + + class Taxii2PersistenceManager(BasePersistenceManager): - '''Manager responsible for persisting and retrieving data. + """Manager responsible for persisting and retrieving data. Manager uses API instance ``api`` for basic data CRUD operations and provides additional logic on top. :param `opentaxii.persistence.api.OpenTAXII2PersistenceAPI` api: instance of persistence API class - ''' + """ def __init__(self, server, api): self.server = server self.api = api + + def get_api_roots(self) -> Tuple[Optional[ApiRoot], List[ApiRoot]]: + """ + Get (optional) default api root and list of all api roots. + + :return: Tuple of (default_api_root, all_api_roots) + """ + api_roots = self.api.get_api_roots() + if not api_roots: + return None, [] + default_api_root = None + for api_root in api_roots: + if api_root.default: + default_api_root = api_root + break + return (default_api_root, api_roots) + + def get_api_root(self, api_root_id: str) -> ApiRoot: + api_root = self.api.get_api_root(api_root_id=api_root_id) + if api_root is None: + raise DoesNotExist() + return api_root + + def _get_job_details_response( + self, job_details: List[JobDetail] + ) -> JobDetailsResponse: + job_details_response = JobDetailsResponse( + total_count=len(job_details), success=[], failure=[], pending=[] + ) + for job_detail in job_details: + getattr(job_details_response, job_detail.status).append(job_detail) + return job_details_response + + def get_job_and_details( + self, api_root_id: str, job_id: str + ) -> Tuple[Job, JobDetailsResponse]: + job, job_details = self.api.get_job_and_details( + api_root_id=api_root_id, job_id=job_id + ) + if job is None: + raise DoesNotExist() + job_details_response = self._get_job_details_response(job_details) + return (job, job_details_response) + + def get_collections(self, api_root_id: str) -> List[Collection]: + return self.api.get_collections(api_root_id=api_root_id) + + def get_collection( + self, api_root_id: str, collection_id_or_alias: str + ) -> Collection: + collection = self.api.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + if collection is None: + raise DoesNotExist() + return collection + + def get_manifest( + self, + api_root_id: str, + collection_id_or_alias: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[ManifestRecord], bool]: + collection = self.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + if not collection.can_read(context.account): + raise NoReadPermission() + return self.api.get_manifest( + collection_id=collection.id, + limit=limit, + added_after=added_after, + 
next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + + def get_objects( + self, + api_root_id: str, + collection_id_or_alias: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[STIXObject], bool]: + collection = self.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + if not collection.can_read(context.account): + raise NoReadPermission() + return self.api.get_objects( + collection_id=collection.id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + + def add_objects( + self, + api_root_id: str, + collection_id_or_alias: str, + data: Dict, + ) -> Tuple[Job, JobDetailsResponse]: + collection = self.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + if not collection.can_write(context.account): + raise NoWritePermission() + job, job_details = self.api.add_objects( + api_root_id=api_root_id, + collection_id=collection.id, + objects=data["objects"], + ) + job_details_response = self._get_job_details_response(job_details) + return (job, job_details_response) + + def get_object( + self, + api_root_id: str, + collection_id_or_alias: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[STIXObject], bool]: + collection = self.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + if not collection.can_read(context.account): + raise NoReadPermission() + versions, more = self.api.get_object( + collection_id=collection.id, + object_id=object_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_version=match_version, + match_spec_version=match_spec_version, + ) + if versions is None: + raise DoesNotExist() + return (versions, more) + + def delete_object( + self, + api_root_id: str, + collection_id_or_alias: str, + object_id: str, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> None: + collection = self.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + if not collection: + raise DoesNotExist() + if not collection.can_read(context.account) and not collection.can_write( + context.account + ): + raise NoReadNoWritePermission + if not collection.can_read(context.account): + raise NoReadPermission + if not collection.can_write(context.account): + raise NoWritePermission + return self.api.delete_object( + collection_id=collection.id, + object_id=object_id, + match_version=match_version, + match_spec_version=match_spec_version, + ) + + def get_versions( + self, + api_root_id: str, + collection_id_or_alias: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[VersionRecord], bool]: + collection = 
self.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + if not collection: + raise DoesNotExist() + if not collection.can_read(context.account): + raise NoReadPermission + versions, more = self.api.get_versions( + collection_id=collection.id, + object_id=object_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_spec_version=match_spec_version, + ) + if versions is None: + raise DoesNotExist() + return (versions, more) diff --git a/opentaxii/persistence/sqldb/api.py b/opentaxii/persistence/sqldb/api.py index 1151684d..b3e7fe5c 100644 --- a/opentaxii/persistence/sqldb/api.py +++ b/opentaxii/persistence/sqldb/api.py @@ -1,16 +1,23 @@ +import datetime import json +import uuid +from functools import reduce +from typing import Dict, List, Optional, Tuple import six import structlog from opentaxii.common.sqldb import BaseSQLDatabaseAPI from opentaxii.persistence import (OpenTAXII2PersistenceAPI, OpenTAXIIPersistenceAPI) -from sqlalchemy import and_, func, or_ +from opentaxii.persistence.sqldb import taxii2models +from opentaxii.taxii2 import entities +from opentaxii.taxii2.utils import DATETIMEFORMAT +from sqlalchemy import and_, func, literal, or_ +from sqlalchemy.orm import Query, load_only from . import converters as conv from .models import (Base, ContentBlock, DataCollection, InboxMessage, ResultSet, Service, Subscription) -from .taxii2models import Base as Taxii2Base __all__ = ["SQLDatabaseAPI"] @@ -35,21 +42,22 @@ class SQLDatabaseAPI(BaseSQLDatabaseAPI, OpenTAXIIPersistenceAPI): :param engine_parameters=None: if defined, these arguments would be passed to sqlalchemy.create_engine """ + BASEMODEL = Base def get_services(self, collection_id=None): if collection_id: - collection = DataCollection.query.get(collection_id) + collection = self.db.session.query(DataCollection).get(collection_id) services = collection.services else: - services = Service.query.all() + services = self.db.session.query(Service).all() return [conv.to_service_entity(s) for s in services] def get_service(self, service_id): - return conv.to_service_entity(Service.query.get(service_id)) + return conv.to_service_entity(self.db.session.query(Service).get(service_id)) def update_service(self, obj): - service = Service.query.get(obj.id) if obj.id else None + service = self.db.session.query(Service).get(obj.id) if obj.id else None if service: service.type = obj.type service.properties = obj.properties @@ -64,30 +72,33 @@ def create_service(self, entity): def get_collections(self, service_id=None): if service_id: - service = Service.query.get(service_id) + service = self.db.session.query(Service).get(service_id) collections = service.collections else: - collections = DataCollection.query.all() + collections = self.db.session.query(DataCollection).all() return [conv.to_collection_entity(c) for c in collections] def get_collection(self, name, service_id=None): if service_id: collection = ( - DataCollection.query.join(Service.collections) + self.db.session.query(DataCollection) + .join(Service.collections) .filter(Service.id == service_id) .filter(DataCollection.name == name) .one_or_none() ) else: - collection = DataCollection.query.filter( - DataCollection.name == name - ).one_or_none() + collection = ( + self.db.session.query(DataCollection) + .filter(DataCollection.name == name) + .one_or_none() + ) if collection: return conv.to_collection_entity(collection) def update_collection(self, entity): _bindings = 
conv.serialize_content_bindings(entity.supported_content) - collection = DataCollection.query.get(entity.id) + collection = self.db.session.query(DataCollection).get(entity.id) if not collection: raise ValueError("DataCollection with id {} is not found".format(entity.id)) collection.name = entity.name @@ -100,14 +111,16 @@ def update_collection(self, entity): return conv.to_collection_entity(collection) def delete_collection(self, collection_name): - collection = DataCollection.query.filter( - DataCollection.name == collection_name - ).one() + collection = ( + self.db.session.query(DataCollection) + .filter(DataCollection.name == collection_name) + .one() + ) self.db.session.delete(collection) self.db.session.commit() def delete_service(self, service_id): - service = Service.query.get(service_id) + service = self.db.session.query(Service).get(service_id) self.db.session.delete(service) self.db.session.commit() @@ -122,7 +135,9 @@ def _get_content_query( if count: query = self.db.session.query(func.count(ContentBlock.id)) else: - query = ContentBlock.query.order_by(ContentBlock.timestamp_label.asc()) + query = self.db.session.query(ContentBlock).order_by( + ContentBlock.timestamp_label.asc() + ) if collection_id: query = query.join(ContentBlock.collections).filter( @@ -209,13 +224,13 @@ def create_collection(self, entity): return conv.to_collection_entity(collection) def set_collection_services(self, collection_id, service_ids): - collection = DataCollection.query.get(collection_id) + collection = self.db.session.query(DataCollection).get(collection_id) if not collection: raise ValueError( "Collection with id {} does not exist".format(collection_id) ) services = ( - Service.query.filter(Service.id.in_(service_ids)).all() + self.db.session.query(Service).filter(Service.id.in_(service_ids)).all() if service_ids else [] ) @@ -311,7 +326,7 @@ def _attach_content_to_collections(self, content_block, collection_ids): return criteria = DataCollection.id.in_(collection_ids) - new_collections = DataCollection.query.filter(criteria) + new_collections = self.db.session.query(DataCollection).filter(criteria) content_block.collections.extend(new_collections) @@ -343,15 +358,15 @@ def create_result_set(self, entity): return conv.to_result_set_entity(result_set) def get_result_set(self, result_set_id): - result_set = ResultSet.query.get(result_set_id) + result_set = self.db.session.query(ResultSet).get(result_set_id) return conv.to_result_set_entity(result_set) def get_subscription(self, subscription_id): - s = Subscription.query.get(subscription_id) + s = self.db.session.query(Subscription).get(subscription_id) return conv.to_subscription_entity(s) def get_subscriptions(self, service_id): - service = Service.query.get(service_id) + service = self.db.session.query(Service).get(service_id) return [conv.to_subscription_entity(s) for s in service.subscriptions] def update_subscription(self, entity): @@ -367,7 +382,7 @@ def update_subscription(self, entity): params = {} subscription = ( - Subscription.query.get(entity.subscription_id) + self.db.session.query(Subscription).get(entity.subscription_id) if entity.subscription_id else None ) @@ -405,7 +420,11 @@ def delete_content_blocks( self, collection_name, start_time, end_time=None, with_messages=False ): - collection = DataCollection.query.filter_by(name=collection_name).one_or_none() + collection = ( + self.db.session.query(DataCollection) + .filter_by(name=collection_name) + .one_or_none() + ) if not collection: raise ValueError( @@ -432,18 +451,24 @@ def 
delete_content_blocks( if with_messages: ( - InboxMessage.query.filter( + self.db.session.query(InboxMessage) + .filter( InboxMessage.id.in_( self.db.session.query(inbox_messages_query.subquery(name="ids")) ) - ).delete(synchronize_session=False) + ) + .delete(synchronize_session=False) ) - counter = ContentBlock.query.filter( - ContentBlock.id.in_( - self.db.session.query(content_blocks_query.subquery(name="ids")) + counter = ( + self.db.session.query(ContentBlock) + .filter( + ContentBlock.id.in_( + self.db.session.query(content_blocks_query.subquery(name="ids")) + ) ) - ).delete(synchronize_session=False) + .delete(synchronize_session=False) + ) collection.volume = ( self.db.session.query(func.count(ContentBlock.id)) @@ -457,4 +482,617 @@ def delete_content_blocks( class Taxii2SQLDatabaseAPI(BaseSQLDatabaseAPI, OpenTAXII2PersistenceAPI): - BASEMODEL = Taxii2Base + BASEMODEL = taxii2models.Base + + def get_api_roots(self) -> List[entities.ApiRoot]: + query = self.db.session.query(taxii2models.ApiRoot).order_by("title") + return [ + entities.ApiRoot( + id=obj.id, + default=obj.default, + title=obj.title, + description=obj.description, + ) + for obj in query.all() + ] + + def get_api_root(self, api_root_id: str) -> Optional[entities.ApiRoot]: + api_root = ( + self.db.session.query(taxii2models.ApiRoot) + .filter(taxii2models.ApiRoot.id == api_root_id) + .one_or_none() + ) + if api_root: + return entities.ApiRoot( + id=api_root.id, + default=api_root.default, + title=api_root.title, + description=api_root.description, + ) + else: + return None + + def add_api_root( + self, + title: str, + description: Optional[str] = None, + default: Optional[bool] = False, + ) -> entities.ApiRoot: + """ + Add a new api root. + + :param str title: Title of the new api root + :param str description: [Optional] Description of the new api root + :param bool default: [Optional, False] If the new api should be the default + + :return: The added ApiRoot entity. + """ + api_root = taxii2models.ApiRoot( + title=title, description=description, default=False + ) + self.db.session.add(api_root) + self.db.session.commit() + if default: + api_root.set_default(self.db.session) + return entities.ApiRoot( + id=api_root.id, + default=api_root.default, + title=api_root.title, + description=api_root.description, + ) + + def get_job_and_details( + self, api_root_id: str, job_id: str + ) -> Tuple[Optional[entities.Job], List[entities.JobDetail]]: + job = ( + self.db.session.query(taxii2models.Job) + .filter( + taxii2models.Job.api_root_id == api_root_id, + taxii2models.Job.id == job_id, + ) + .one_or_none() + ) + if job is None: + return None, [] + job_details = ( + self.db.session.query(taxii2models.JobDetail) + .filter( + taxii2models.JobDetail.job_id == job_id, + ) + .order_by(taxii2models.JobDetail.stix_id) + .all() + ) + return ( + entities.Job( + id=job.id, + api_root_id=job.api_root_id, + status=job.status, + request_timestamp=job.request_timestamp, + completed_timestamp=job.completed_timestamp, + ), + [ + entities.JobDetail( + id=job_detail.id, + job_id=job_detail.job_id, + stix_id=job_detail.stix_id, + version=job_detail.version, + message=job_detail.message, + status=job_detail.status, + ) + for job_detail in job_details + ], + ) + + def job_cleanup(self) -> int: + """ + Remove jobs that are >24h old. + + :return: The number of removed jobs. 
+ """ + return taxii2models.Job.cleanup(self.db.session) + + def get_collections(self, api_root_id: str) -> List[entities.Collection]: + query = ( + self.db.session.query(taxii2models.Collection) + .filter(taxii2models.Collection.api_root_id == api_root_id) + .order_by(taxii2models.Collection.title) + ) + return [ + entities.Collection( + id=obj.id, + api_root_id=obj.api_root_id, + title=obj.title, + description=obj.description, + alias=obj.alias, + ) + for obj in query.all() + ] + + def get_collection( + self, api_root_id: str, collection_id_or_alias: str + ) -> Optional[entities.Collection]: + id_or_alias_filter = taxii2models.Collection.alias == collection_id_or_alias + try: + uuid.UUID(collection_id_or_alias) + except ValueError: + pass + else: + id_or_alias_filter |= taxii2models.Collection.id == collection_id_or_alias + obj = ( + self.db.session.query(taxii2models.Collection) + .filter( + taxii2models.Collection.api_root_id == api_root_id, + id_or_alias_filter, + ) + .one_or_none() + ) + if obj is None: + return None + return entities.Collection( + id=obj.id, + api_root_id=obj.api_root_id, + title=obj.title, + description=obj.description, + alias=obj.alias, + ) + + def add_collection( + self, + api_root_id: str, + title: str, + description: Optional[str] = None, + alias: Optional[str] = None, + ) -> entities.Collection: + """ + Add a new collection. + + :param str api_root_id: ID of the api root the new collection is part of + :param str title: Title of the new collection + :param str description: [Optional] Description of the new collection + :param str alias: [Optional] Alias of the new collection + + :return: The added Collection entity. + """ + collection = taxii2models.Collection( + api_root_id=api_root_id, title=title, description=description, alias=alias + ) + self.db.session.add(collection) + self.db.session.commit() + + return entities.Collection( + id=collection.id, + api_root_id=collection.api_root_id, + title=collection.title, + description=collection.description, + alias=collection.alias, + ) + + def _objects_query(self, collection_id: str, ordered: bool) -> Query: + query = self.db.session.query(taxii2models.STIXObject).filter( + taxii2models.STIXObject.collection_id == collection_id, + ) + if ordered: + query = query.order_by( + taxii2models.STIXObject.date_added, taxii2models.STIXObject.id + ) + return query + + def _apply_added_after( + self, query: Query, added_after: Optional[datetime.datetime] = None + ) -> Query: + if added_after is not None: + query = query.filter(taxii2models.STIXObject.date_added > added_after) + return query + + def _apply_next_kwargs( + self, query: Query, next_kwargs: Optional[Dict] = None + ) -> Query: + if next_kwargs is not None: + query = query.filter( + (taxii2models.STIXObject.date_added > next_kwargs["date_added"]) + | ( + (taxii2models.STIXObject.date_added == next_kwargs["date_added"]) + & (taxii2models.STIXObject.id > next_kwargs["id"]) + ) + ) + return query + + def _apply_match_id( + self, query: Query, match_id: Optional[List[str]] = None + ) -> Query: + if match_id is not None: + query = query.filter(taxii2models.STIXObject.id.in_(match_id)) + return query + + def _apply_match_type( + self, query: Query, match_type: Optional[List[str]] = None + ) -> Query: + if match_type is not None: + query = query.filter(taxii2models.STIXObject.type.in_(match_type)) + return query + + def _apply_match_version( + self, + query: Query, + collection_id: str, + match_version: Optional[List[str]] = None, + ) -> Query: + if match_version is None: 
+ match_version = ["last"] + if "all" in match_version: + return query + version_filters = [] + for value in match_version: + if value == "first": + min_versions_subq = ( + self.db.session.query( + taxii2models.STIXObject.id, + func.min(taxii2models.STIXObject.version).label("min_version"), + ) + .filter( + taxii2models.STIXObject.collection_id == collection_id, + ) + .group_by(taxii2models.STIXObject.id) + .subquery() + ) + min_version_pks = ( + self.db.session.query(taxii2models.STIXObject.pk) + .select_from(taxii2models.STIXObject) + .join( + min_versions_subq, + ( + (taxii2models.STIXObject.id == min_versions_subq.c.id) + & ( + taxii2models.STIXObject.version + == min_versions_subq.c.min_version + ) + ), + ) + ) + version_filters.append(taxii2models.STIXObject.pk.in_(min_version_pks)) + elif value == "last": + max_versions_subq = ( + self.db.session.query( + taxii2models.STIXObject.id, + func.max(taxii2models.STIXObject.version).label("max_version"), + ) + .filter( + taxii2models.STIXObject.collection_id == collection_id, + ) + .group_by(taxii2models.STIXObject.id) + .subquery() + ) + max_version_pks = ( + self.db.session.query(taxii2models.STIXObject.pk) + .select_from(taxii2models.STIXObject) + .join( + max_versions_subq, + ( + (taxii2models.STIXObject.id == max_versions_subq.c.id) + & ( + taxii2models.STIXObject.version + == max_versions_subq.c.max_version + ) + ), + ) + ) + version_filters.append(taxii2models.STIXObject.pk.in_(max_version_pks)) + else: + version_filters.append(taxii2models.STIXObject.version == value) + query = query.filter(reduce(or_, version_filters)) + return query + + def _apply_match_spec_version( + self, query: Query, match_spec_version: Optional[List[str]] = None + ) -> Query: + if match_spec_version is not None: + query = query.filter( + taxii2models.STIXObject.spec_version.in_(match_spec_version) + ) + return query + + def _apply_limit( + self, query: Query, limit: Optional[int] = None + ) -> Tuple[Query, bool]: + if limit is not None: + more = limit < query.count() + query = query.limit(limit) + else: + more = False + return query, more + + def _filtered_objects_query( + self, + collection_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ordered: Optional[bool] = True, + ) -> Tuple[Query, bool]: + query = self._objects_query(collection_id, ordered) + query = self._apply_added_after(query, added_after) + query = self._apply_next_kwargs(query, next_kwargs) + query = self._apply_match_id(query, match_id) + query = self._apply_match_type(query, match_type) + query = self._apply_match_version(query, collection_id, match_version) + query = self._apply_match_spec_version(query, match_spec_version) + query, more = self._apply_limit(query, limit) + return query, more + + def get_manifest( + self, + collection_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[entities.ManifestRecord], bool]: + query, more = self._filtered_objects_query( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + 
match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + query = query.options( + load_only( + taxii2models.STIXObject.id, + taxii2models.STIXObject.date_added, + taxii2models.STIXObject.version, + taxii2models.STIXObject.spec_version, + ) + ) + return ( + [ + entities.ManifestRecord( + id=obj.id, + date_added=obj.date_added, + version=obj.version, + spec_version=obj.spec_version, + ) + for obj in query.all() + ], + more, + ) + + def get_objects( + self, + collection_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[entities.STIXObject], bool]: + query, more = self._filtered_objects_query( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + return ( + [ + entities.STIXObject( + id=obj.id, + collection_id=collection_id, + type=obj.type, + spec_version=obj.spec_version, + date_added=obj.date_added, + version=obj.version, + serialized_data=obj.serialized_data, + ) + for obj in query.all() + ], + more, + ) + + def add_objects( + self, api_root_id: str, collection_id: str, objects: List[Dict] + ) -> Tuple[entities.Job, List[entities.JobDetail]]: + job = taxii2models.Job( + api_root_id=api_root_id, + status="pending", + request_timestamp=datetime.datetime.now(datetime.timezone.utc), + ) + self.db.session.add(job) + self.db.session.commit() + job_details = [] + for obj in objects: + version = datetime.datetime.strptime( + obj["modified"], DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc) + if ( + not self.db.session.query(literal(True)) + .filter( + self.db.session.query(taxii2models.STIXObject) + .filter( + taxii2models.STIXObject.id == obj["id"], + taxii2models.STIXObject.collection_id == collection_id, + taxii2models.STIXObject.version == version, + ) + .exists() + ) + .scalar() + ): + self.db.session.add( + taxii2models.STIXObject( + id=obj["id"], + collection_id=collection_id, + type=obj["id"].split("--")[0], + spec_version=obj["spec_version"], + date_added=datetime.datetime.now(datetime.timezone.utc), + version=version, + serialized_data={ + key: value + for (key, value) in obj.items() + if key not in ["id", "type", "spec_version"] + }, + ) + ) + job_detail = taxii2models.JobDetail( + job_id=job.id, + stix_id=obj["id"], + version=version, + message="", + status="success", + ) + job_details.append(job_detail) + self.db.session.add(job_detail) + job.status = "complete" + job.completed_timestamp = datetime.datetime.now(datetime.timezone.utc) + self.db.session.commit() + return ( + entities.Job( + id=job.id, + api_root_id=job.api_root_id, + status=job.status, + request_timestamp=job.request_timestamp, + completed_timestamp=job.completed_timestamp, + ), + [ + entities.JobDetail( + id=job_detail.id, + job_id=job_detail.job_id, + stix_id=job_detail.stix_id, + version=job_detail.version, + message=job_detail.message, + status=job_detail.status, + ) + for job_detail in job_details + ], + ) + + def get_object( + self, + collection_id: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_version: 
Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[Optional[List[entities.STIXObject]], bool]: + """ + Get all versions of single object from database. + + Should return `None` when object matching object_id doesn't exist. + """ + if ( + not self.db.session.query(literal(True)) + .filter( + self.db.session.query(taxii2models.STIXObject) + .filter( + taxii2models.STIXObject.id == object_id, + taxii2models.STIXObject.collection_id == collection_id, + ) + .exists() + ) + .scalar() + ): + return (None, False) + query, more = self._filtered_objects_query( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=[object_id], + match_version=match_version, + match_spec_version=match_spec_version, + ) + return ( + [ + entities.STIXObject( + id=obj.id, + collection_id=collection_id, + type=obj.type, + spec_version=obj.spec_version, + date_added=obj.date_added, + version=obj.version, + serialized_data=obj.serialized_data, + ) + for obj in query.all() + ], + more, + ) + + def delete_object( + self, + collection_id: str, + object_id: str, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, + ) -> None: + if match_version is None: + match_version = ["all"] + query, _ = self._filtered_objects_query( + collection_id=collection_id, + match_id=[object_id], + match_version=match_version, + match_spec_version=match_spec_version, + ordered=False, + ) + query.delete("fetch") + + def get_versions( + self, + collection_id: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_spec_version: Optional[List[str]] = None, + ) -> Tuple[List[entities.VersionRecord], bool]: + if ( + not self.db.session.query(literal(True)) + .filter( + self.db.session.query(taxii2models.STIXObject) + .filter( + taxii2models.STIXObject.id == object_id, + taxii2models.STIXObject.collection_id == collection_id, + ) + .exists() + ) + .scalar() + ): + return (None, False) + query, more = self._filtered_objects_query( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=[object_id], + match_version=["all"], + match_spec_version=match_spec_version, + ) + query = query.options( + load_only( + taxii2models.STIXObject.date_added, + taxii2models.STIXObject.version, + ) + ) + return ( + [ + entities.VersionRecord( + date_added=obj.date_added, + version=obj.version, + ) + for obj in query.all() + ], + more, + ) diff --git a/opentaxii/persistence/sqldb/common.py b/opentaxii/persistence/sqldb/common.py new file mode 100644 index 00000000..c37cbd77 --- /dev/null +++ b/opentaxii/persistence/sqldb/common.py @@ -0,0 +1,72 @@ +"""A module to put common database helper components.""" +import uuid +from datetime import timezone + +from sqlalchemy.dialects import mysql +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.types import CHAR, DateTime, TypeDecorator + + +class GUID(TypeDecorator): + """ + Platform-independent GUID type. + + Uses PostgreSQL's UUID type, otherwise uses + CHAR(32), storing as stringified hex values. 
+ """ + + impl = CHAR + cache_ok = True + + def load_dialect_impl(self, dialect): + """Switch implementation based on database dialect.""" + if dialect.name == 'postgresql': + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(32)) + + def process_bind_param(self, value, dialect): + """Convert from python to database representation.""" + if value is None: + return value + elif dialect.name == 'postgresql': + return str(value) + else: + if not isinstance(value, uuid.UUID): + return "%.32x" % uuid.UUID(value).int + else: + # hexstring + return "%.32x" % value.int + + def process_result_value(self, value, dialect): + """Convert from database to python representation.""" + if value is None: + return value + else: + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + return value + + +class UTCDateTime(TypeDecorator): + """Platform-independent DateTime type that always stores and returns UTC.""" + + impl = DateTime + cache_ok = True + + def load_dialect_impl(self, dialect): + """Switch implementation based on database dialect.""" + if dialect.name == 'mysql': + return dialect.type_descriptor(mysql.DATETIME(fsp=6)) + else: + return dialect.type_descriptor(DateTime()) + + def process_bind_param(self, value, engine): + """Convert from python to database representation.""" + if value is not None: + return value.astimezone(timezone.utc) + + def process_result_value(self, value, engine): + """Convert from database to python representation.""" + if value is not None: + return value.replace(tzinfo=timezone.utc) diff --git a/opentaxii/persistence/sqldb/taxii2models.py b/opentaxii/persistence/sqldb/taxii2models.py index 860e5425..c95bb871 100644 --- a/opentaxii/persistence/sqldb/taxii2models.py +++ b/opentaxii/persistence/sqldb/taxii2models.py @@ -1,3 +1,147 @@ +"""Database models for taxii2 entities.""" +import datetime +import uuid + +import sqlalchemy +from opentaxii.persistence.sqldb.common import GUID, UTCDateTime +from opentaxii.taxii2 import entities +from sqlalchemy import literal from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship Base = declarative_base() + + +class ApiRoot(Base): + """Database equivalent of `entities.ApiRoot`.""" + + __tablename__ = "opentaxii_api_root" + + id = sqlalchemy.Column(GUID, primary_key=True, default=uuid.uuid4) + default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False) + title = sqlalchemy.Column(sqlalchemy.Text, nullable=False) + description = sqlalchemy.Column(sqlalchemy.Text) + + collections = relationship("Collection", back_populates="api_root") + + def set_default(self, session: sqlalchemy.orm.Session): + """Set this api root as default. 
Make sure there is always at most 1 default api root.""" + session.query(ApiRoot).filter(ApiRoot.default == literal(True)).update( + {ApiRoot.default: literal(False)} + ) + session.query(ApiRoot).filter(ApiRoot.id == self.id).update( + {ApiRoot.default: literal(True)} + ) + + @classmethod + def from_entity(cls, entity: entities.ApiRoot): + """Generate database model from input entity.""" + return cls(**entity.to_dict()) + + +class Job(Base): + """Database equivalent of `entities.Job`.""" + + __tablename__ = "opentaxii_job" + + id = sqlalchemy.Column(GUID, primary_key=True, default=uuid.uuid4) + api_root_id = sqlalchemy.Column( + GUID, sqlalchemy.ForeignKey("opentaxii_api_root.id") + ) + status = sqlalchemy.Column( + sqlalchemy.Enum("pending", "complete", name="job_status_enum") + ) + request_timestamp = sqlalchemy.Column(UTCDateTime, nullable=True) + completed_timestamp = sqlalchemy.Column(UTCDateTime, nullable=True) + + details = relationship("JobDetail", back_populates="job") + + @classmethod + def cleanup(cls, session: sqlalchemy.orm.Session) -> int: + """ + Remove jobs that are >24h old. + + :return: The number of removed jobs. + """ + return session.query(cls).filter( + cls.completed_timestamp + < ( + datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(hours=24) + ) + ).delete() + + @classmethod + def from_entity(cls, entity: entities.Job): + """Generate database model from input entity.""" + return cls(**entity.to_dict()) + + +class JobDetail(Base): + """Database equivalent of `entities.JobDetail`.""" + + __tablename__ = "opentaxii_job_detail" + + id = sqlalchemy.Column(GUID, primary_key=True, default=uuid.uuid4) + job_id = sqlalchemy.Column(GUID, sqlalchemy.ForeignKey("opentaxii_job.id")) + stix_id = sqlalchemy.Column(sqlalchemy.Text) + version = sqlalchemy.Column(UTCDateTime) + message = sqlalchemy.Column(sqlalchemy.Text) + status = sqlalchemy.Column( + sqlalchemy.Enum("success", "failure", "pending", name="job_detail_status_enum") + ) + + job = relationship("Job", back_populates="details") + + @classmethod + def from_entity(cls, entity: entities.JobDetail): + """Generate database model from input entity.""" + return cls(**entity.to_dict()) + + +class Collection(Base): + """Database equivalent of `entities.Collection`.""" + + __tablename__ = "opentaxii_collection" + __table_args__ = (sqlalchemy.UniqueConstraint("api_root_id", "alias"),) + + id = sqlalchemy.Column(GUID, primary_key=True, default=uuid.uuid4) + api_root_id = sqlalchemy.Column( + GUID, sqlalchemy.ForeignKey("opentaxii_api_root.id") + ) + title = sqlalchemy.Column(sqlalchemy.Text, nullable=False) + description = sqlalchemy.Column(sqlalchemy.Text) + alias = sqlalchemy.Column(sqlalchemy.String(100), nullable=True) + + api_root = relationship("ApiRoot", back_populates="collections") + objects = relationship("STIXObject", back_populates="collection") + + @classmethod + def from_entity(cls, entity: entities.Collection): + """Generate database model from input entity.""" + return cls(**entity.to_dict()) + + +class STIXObject(Base): + """Database equivalent of `entities.STIXObject`.""" + + __tablename__ = "opentaxii_stixobject" + __table_args__ = (sqlalchemy.UniqueConstraint("collection_id", "id", "version"),) + + pk = sqlalchemy.Column(GUID, primary_key=True, default=uuid.uuid4) + id = sqlalchemy.Column(sqlalchemy.String(100)) + collection_id = sqlalchemy.Column( + GUID, sqlalchemy.ForeignKey("opentaxii_collection.id") + ) + type = sqlalchemy.Column(sqlalchemy.Text) + spec_version = 
sqlalchemy.Column(sqlalchemy.Text) # STIX version + date_added = sqlalchemy.Column(UTCDateTime) + version = sqlalchemy.Column(UTCDateTime) + serialized_data = sqlalchemy.Column(sqlalchemy.JSON) + + collection = relationship("Collection", back_populates="objects") + + @classmethod + def from_entity(cls, entity: entities.STIXObject): + """Generate database model from input entity.""" + return cls(**entity.to_dict()) diff --git a/opentaxii/server.py b/opentaxii/server.py index 21ad146b..d679223f 100644 --- a/opentaxii/server.py +++ b/opentaxii/server.py @@ -1,9 +1,31 @@ import functools import importlib -from typing import Callable, ClassVar, NamedTuple, Optional, Type +import json + +try: + from re import Pattern +except ImportError: + from typing.re import Pattern + +from typing import Callable, ClassVar, NamedTuple, Optional, Tuple, Type import structlog -from flask import Flask, Response, abort, request +from flask import Flask, Response, request +from werkzeug.exceptions import (Forbidden, MethodNotAllowed, NotAcceptable, + NotFound, RequestEntityTooLarge, Unauthorized, + UnsupportedMediaType) + +from opentaxii.persistence.exceptions import (DoesNotExist, + NoReadNoWritePermission, + NoReadPermission, + NoWritePermission) +from opentaxii.taxii2.utils import get_next_param, taxii2_datetimeformat +from opentaxii.taxii2.validation import (validate_delete_filter_params, + validate_envelope, + validate_list_filter_params, + validate_object_filter_params, + validate_versions_filter_params) +from opentaxii.utils import register_handler from .auth import AuthManager from .config import ServerConfig @@ -12,6 +34,7 @@ from .local import context from .persistence import (BasePersistenceManager, Taxii1PersistenceManager, Taxii2PersistenceManager) +from .taxii2.http import make_taxii2_response from .taxii.bindings import (ALL_PROTOCOL_BINDINGS, MESSAGE_BINDINGS, SERVICE_BINDINGS) from .taxii.exceptions import (FailureStatus, StatusMessageException, @@ -39,6 +62,7 @@ class BaseTAXIIServer: """ PERSISTENCE_MANAGER_CLASS: ClassVar[Type[BasePersistenceManager]] + ENDPOINT_MAPPING: Tuple[(Pattern, Callable[[], Response])] app: Flask config: dict @@ -47,6 +71,16 @@ def __init__(self, config: dict): self.persistence = self.PERSISTENCE_MANAGER_CLASS( server=self, api=initialize_api(config["persistence_api"]) ) + self.setup_endpoint_mapping() + + def setup_endpoint_mapping(self): + mapping = [] + for attr_name in self.__dir__(): + attr = getattr(self, attr_name) + if hasattr(attr, "registered_url_re"): + mapping.append((attr.registered_url_re, attr)) + if mapping: + self.ENDPOINT_MAPPING = tuple(mapping) def init_app(self, app: Flask): """Connect server and persistence to flask.""" @@ -79,6 +113,17 @@ def handle_status_exception(self, error): """ return + def handle_http_exception(self, error): + return error.get_response() + + def handle_validation_exception(self, error): + """ + Handle validation exception and return appropriate response. + + Placeholder for subclasses to implement. + """ + return + def raise_unauthorized(self): """ Handle unauthorized access. 
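# How the TAXII2 endpoint plumbing introduced above fits together, as a stripped-down
# standalone sketch: the real pieces are BaseTAXIIServer.setup_endpoint_mapping above,
# TAXII2Server.get_endpoint below, and the register_handler decorator added to
# opentaxii/utils.py later in this patch. The `Demo` class here is hypothetical and only
# shows the mechanism: methods tagged with a compiled URL regex are collected into a
# mapping, and a request path is dispatched to the first handler whose regex matches,
# with named groups passed as keyword arguments.
import functools
import re


def register_handler(url_re):
    def decorator(method):
        @functools.wraps(method)
        def inner(*args, **kwargs):
            return method(*args, **kwargs)

        inner.registered_url_re = re.compile(url_re)
        return inner

    return decorator


class Demo:
    def __init__(self):
        # collect every method that carries a `registered_url_re` attribute
        self.endpoint_mapping = tuple(
            (attr.registered_url_re, attr)
            for attr in (getattr(self, name) for name in dir(self))
            if hasattr(attr, "registered_url_re")
        )

    @register_handler(r"^/(?P<api_root_id>[^/]+)/$")
    def api_root_handler(self, api_root_id):
        return f"api root: {api_root_id}"

    def get_endpoint(self, relative_path):
        for regex, handler in self.endpoint_mapping:
            match = regex.match(relative_path)
            if match:
                # named regex groups become keyword arguments of the handler
                return functools.partial(handler, **match.groupdict())


# Demo().get_endpoint("/some-root/")() returns "api root: some-root"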
@@ -148,6 +193,11 @@ def _create_services(self, service_entities): return services + def check_allowed_methods(self): + valid_methods = ["POST", "OPTIONS"] + if request.method not in valid_methods: + raise MethodNotAllowed(valid_methods=valid_methods) + def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: """Get first endpoint matching relative_path.""" for endpoint in self.get_services(): @@ -221,6 +271,7 @@ def handle_request(self, endpoint: TAXIIService): Process :class:`TAXIIService` with either :meth:`_process_with_service` or :meth:`_process_options_request`. """ + self.check_allowed_methods() if endpoint.authentication_required and context.account is None: raise UnauthorizedException( status_type=self.config["unauthorized_status"], @@ -340,6 +391,422 @@ class TAXII2Server(BaseTAXIIServer): PERSISTENCE_MANAGER_CLASS = Taxii2PersistenceManager + def handle_http_exception(self, error): + """Return JSON instead of HTML for HTTP errors.""" + # start with the correct headers and status code from the error + response = error.get_response() + # replace the body with JSON + response.data = json.dumps( + { + "code": error.code, + "name": error.name, + "description": error.description, + } + ) + response.content_type = "application/taxii+json;version=2.1" + return response + + def handle_validation_exception(self, error): + """ + Handle validation exception and return appropriate response. + """ + response = { + "code": 400, + "name": "validation error", + "description": error.messages, + } + return make_taxii2_response(response, status=400) + + def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]: + endpoint = None + for regex, handler in self.ENDPOINT_MAPPING: + match = regex.match(relative_path) + if match: + endpoint = functools.partial(handler, **match.groupdict()) + break + if endpoint: + return functools.partial(self.handle_request, endpoint) + + def check_authentication(self): + if context.account is None: + raise Unauthorized() + + def check_content_length(self): + if (request.content_length or 0) > self.config["max_content_length"] or len( + request.data + ) > self.config[ + "max_content_length" + ]: # untestable with flask + raise RequestEntityTooLarge() + + def check_headers(self, endpoint: Callable[[], Response]): + if not any( + [ + valid_accept_mimetype in request.accept_mimetypes + for valid_accept_mimetype in endpoint.func.registered_valid_accept_mimetypes + ] + ): + raise NotAcceptable() + if ( + request.method == "POST" + and request.content_type not in endpoint.func.registered_valid_content_types + ): + raise UnsupportedMediaType() + + def check_allowed_methods(self, endpoint: Callable[[], Response]): + if request.method not in endpoint.func.registered_valid_methods: + raise MethodNotAllowed(valid_methods=endpoint.func.registered_valid_methods) + + def handle_request(self, endpoint: Callable[[], Response]): + self.check_authentication() + self.check_content_length() + self.check_allowed_methods(endpoint) + self.check_headers(endpoint) + return endpoint() + + @register_handler(r"^/taxii2/$") + def discovery_handler(self): + response = { + "title": self.config["title"], + } + for key in ["description", "contact"]: + if self.config.get(key): + response[key] = self.config.get(key) + default_api_root, api_roots = self.persistence.get_api_roots() + if default_api_root: + response["default"] = f"/{default_api_root.id}/" + response["api_roots"] = [f"/{api_root.id}/" for api_root in api_roots] + return 
make_taxii2_response(response) + + @register_handler(r"^/(?P[^/]+)/$") + def api_root_handler(self, api_root_id): + try: + api_root = self.persistence.get_api_root(api_root_id=api_root_id) + except DoesNotExist: + raise NotFound() + response = { + "title": api_root.title, + "versions": ["application/taxii+json;version=2.1"], + "max_content_length": self.config["max_content_length"], + } + if api_root.description: + response["description"] = api_root.description + return make_taxii2_response(response) + + @register_handler(r"^/(?P[^/]+)/status/(?P[^/]+)/$") + def job_handler(self, api_root_id, job_id): + try: + job, job_details = self.persistence.get_job_and_details( + api_root_id=api_root_id, job_id=job_id + ) + except DoesNotExist: + raise NotFound() + response = { + "id": job.id, + "status": job.status, + "request_timestamp": taxii2_datetimeformat(job.request_timestamp), + "total_count": job_details.total_count, + "success_count": len(job_details.success), + "successes": [ + job_detail.as_taxii2_dict() for job_detail in job_details.success + ], + "failure_count": len(job_details.failure), + "failures": [ + job_detail.as_taxii2_dict() for job_detail in job_details.failure + ], + "pending_count": len(job_details.pending), + "pendings": [ + job_detail.as_taxii2_dict() for job_detail in job_details.pending + ], + } + return make_taxii2_response(response) + + @register_handler(r"^/(?P[^/]+)/collections/$") + def collections_handler(self, api_root_id): + try: + self.persistence.get_api_root(api_root_id=api_root_id) + except DoesNotExist: + raise NotFound() + collections = self.persistence.get_collections(api_root_id=api_root_id) + response = {} + if collections: + response["collections"] = [] + for collection in collections: + data = { + "id": collection.id, + "title": collection.title, + "can_read": collection.can_read(context.account), + "can_write": collection.can_write(context.account), + "media_types": ["application/stix+json;version=2.1"], + } + for key in ["description", "alias"]: + value = getattr(collection, key, None) + if value: + data[key] = value + response["collections"].append(data) + return make_taxii2_response(response) + + @register_handler( + r"^/(?P[^/]+)/collections/(?P[^/]+)/$" + ) + def collection_handler(self, api_root_id, collection_id_or_alias): + try: + collection = self.persistence.get_collection( + api_root_id=api_root_id, collection_id_or_alias=collection_id_or_alias + ) + except DoesNotExist: + raise NotFound() + response = { + "id": collection.id, + "title": collection.title, + "can_read": collection.can_read(context.account), + "can_write": collection.can_write(context.account), + "media_types": ["application/stix+json;version=2.1"], + } + for key in ["description", "alias"]: + value = getattr(collection, key, None) + if value: + response[key] = value + return make_taxii2_response(response) + + @register_handler( + r"^/(?P[^/]+)/collections/(?P[^/]+)/manifest/$" + ) + def manifest_handler(self, api_root_id, collection_id_or_alias): + filter_params = validate_list_filter_params(request.args) + try: + manifest, more = self.persistence.get_manifest( + api_root_id=api_root_id, + collection_id_or_alias=collection_id_or_alias, + **filter_params, + ) + except (DoesNotExist, NoReadPermission): + raise NotFound() + if manifest: + response = { + "more": more, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + 
for obj in manifest
+                ],
+            }
+            headers = {
+                "X-TAXII-Date-Added-First": min(
+                    obj["date_added"] for obj in response["objects"]
+                ),
+                "X-TAXII-Date-Added-Last": max(
+                    obj["date_added"] for obj in response["objects"]
+                ),
+            }
+        else:
+            response = {}
+            headers = {}
+        return make_taxii2_response(
+            response,
+            extra_headers=headers,
+        )
+
+    @register_handler(
+        r"^/(?P<api_root_id>[^/]+)/collections/(?P<collection_id_or_alias>[^/]+)/objects/$",
+        ("GET", "POST"),
+        valid_content_types=("application/taxii+json;version=2.1",),
+    )
+    def objects_handler(self, api_root_id, collection_id_or_alias):
+        if request.method == "GET":
+            return self.objects_get_handler(api_root_id, collection_id_or_alias)
+        if request.method == "POST":
+            return self.objects_post_handler(api_root_id, collection_id_or_alias)
+
+    def objects_get_handler(self, api_root_id, collection_id_or_alias):
+        filter_params = validate_list_filter_params(request.args)
+        try:
+            objects, more = self.persistence.get_objects(
+                api_root_id=api_root_id,
+                collection_id_or_alias=collection_id_or_alias,
+                **filter_params,
+            )
+        except (DoesNotExist, NoReadPermission):
+            raise NotFound()
+        if objects:
+            response = {
+                "more": more,
+                "objects": [
+                    {
+                        "id": obj.id,
+                        "type": obj.type,
+                        "spec_version": obj.spec_version,
+                        **obj.serialized_data,
+                    }
+                    for obj in objects
+                ],
+            }
+            headers = {
+                "X-TAXII-Date-Added-First": taxii2_datetimeformat(
+                    min(obj.date_added for obj in objects)
+                ),
+                "X-TAXII-Date-Added-Last": taxii2_datetimeformat(
+                    max(obj.date_added for obj in objects)
+                ),
+            }
+            if more:
+                response["next"] = get_next_param(objects[-1]).decode()
+        else:
+            response = {}
+            headers = {}
+        return make_taxii2_response(
+            response,
+            extra_headers=headers,
+        )
+
+    def objects_post_handler(self, api_root_id, collection_id_or_alias):
+        validate_envelope(request.data)
+        try:
+            job, job_details = self.persistence.add_objects(
+                api_root_id=api_root_id,
+                collection_id_or_alias=collection_id_or_alias,
+                data=request.get_json(),
+            )
+        except (DoesNotExist, NoWritePermission):
+            raise NotFound()
+        response = {
+            "id": job.id,
+            "status": job.status,
+            "request_timestamp": taxii2_datetimeformat(job.request_timestamp),
+            "total_count": job_details.total_count,
+            "success_count": len(job_details.success),
+            "successes": [
+                job_detail.as_taxii2_dict() for job_detail in job_details.success
+            ],
+            "failure_count": len(job_details.failure),
+            "failures": [
+                job_detail.as_taxii2_dict() for job_detail in job_details.failure
+            ],
+            "pending_count": len(job_details.pending),
+            "pendings": [
+                job_detail.as_taxii2_dict() for job_detail in job_details.pending
+            ],
+        }
+        headers = {}
+        return make_taxii2_response(
+            response,
+            202,
+            extra_headers=headers,
+        )
+
+    @register_handler(
+        r"^/(?P<api_root_id>[^/]+)/collections/(?P<collection_id_or_alias>[^/]+)/objects/(?P<object_id>[^/]+)/$",
+        ("GET", "DELETE"),
+    )
+    def object_handler(self, api_root_id, collection_id_or_alias, object_id):
+        if request.method == "GET":
+            return self.object_get_handler(
+                api_root_id, collection_id_or_alias, object_id
+            )
+        if request.method == "DELETE":
+            return self.object_delete_handler(
+                api_root_id, collection_id_or_alias, object_id
+            )
+
+    def object_get_handler(self, api_root_id, collection_id_or_alias, object_id):
+        filter_params = validate_object_filter_params(request.args)
+        try:
+            versions, more = self.persistence.get_object(
+                api_root_id=api_root_id,
+                collection_id_or_alias=collection_id_or_alias,
+                object_id=object_id,
+                **filter_params,
+            )
+        except (DoesNotExist, NoReadPermission):
+            raise NotFound()
+        if versions:
+            response = {
+                "more": more,
+                "objects": [
+                    {
+                        "id": obj.id,
+                        "type": obj.type,
+                        "spec_version": obj.spec_version,
+                        **obj.serialized_data,
+                    }
+                    for obj in versions
+                ],
+            }
+            headers = {
+                "X-TAXII-Date-Added-First": taxii2_datetimeformat(
+                    min(obj.date_added for obj in versions)
+                ),
+                "X-TAXII-Date-Added-Last": taxii2_datetimeformat(
+                    max(obj.date_added for obj in versions)
+                ),
+            }
+            if more:
+                response["next"] = get_next_param(versions[-1]).decode()
+        else:
+            response = {}
+            headers = {}
+        return make_taxii2_response(
+            response,
+            extra_headers=headers,
+        )
+
+    def object_delete_handler(self, api_root_id, collection_id_or_alias, object_id):
+        filter_params = validate_delete_filter_params(request.args)
+        try:
+            self.persistence.delete_object(
+                api_root_id=api_root_id,
+                collection_id_or_alias=collection_id_or_alias,
+                object_id=object_id,
+                **filter_params,
+            )
+        except (DoesNotExist, NoReadNoWritePermission):
+            raise NotFound()
+        except (NoReadPermission, NoWritePermission):
+            raise Forbidden()
+        return make_taxii2_response("")
+
+    @register_handler(
+        (
+            r"^/(?P<api_root_id>[^/]+)/collections/(?P<collection_id_or_alias>[^/]+)"
+            r"/objects/(?P<object_id>[^/]+)/versions/$"
+        ),
+    )
+    def versions_handler(self, api_root_id, collection_id_or_alias, object_id):
+        filter_params = validate_versions_filter_params(request.args)
+        try:
+            versions, more = self.persistence.get_versions(
+                api_root_id=api_root_id,
+                collection_id_or_alias=collection_id_or_alias,
+                object_id=object_id,
+                **filter_params,
+            )
+        except (DoesNotExist, NoReadPermission):
+            raise NotFound()
+        if versions:
+            response = {
+                "more": more,
+                "versions": [taxii2_datetimeformat(obj.version) for obj in versions],
+            }
+            headers = {
+                "X-TAXII-Date-Added-First": taxii2_datetimeformat(
+                    min(obj.date_added for obj in versions)
+                ),
+                "X-TAXII-Date-Added-Last": taxii2_datetimeformat(
+                    max(obj.date_added for obj in versions)
+                ),
+            }
+        else:
+            response = {}
+            headers = {}
+        return make_taxii2_response(
+            response,
+            extra_headers=headers,
+        )
+
 
 class ServerMapping(NamedTuple):
     taxii1: Optional[TAXII1Server]
@@ -381,12 +848,26 @@ def __init__(self, config: ServerConfig):
             server=self, api=initialize_api(config["auth_api"])
         )
 
+    @property
+    def real_servers(self):
+        return [server for server in self.servers if server is not None]
+
+    @property
+    def current_server(self):
+        try:
+            server = context.taxiiserver
+        except AttributeError:
+            if len(self.real_servers) == 1:
+                server = self.real_servers[0]
+            else:
+                server = None
+        return server
+
     def init_app(self, app: Flask):
         """Connect taxii1, taxii2 and auth to flask."""
         self.app = app
-        for server in self.servers:
-            if server is not None:
-                server.init_app(app)
+        for server in self.real_servers:
+            server.init_app(app)
         self.auth.api.init_app(app)
 
     def is_basic_auth_supported(self):
@@ -395,7 +876,7 @@
     def get_endpoint(self, relative_path: str) -> Optional[Callable[[], Response]]:
         """Get first endpoint matching relative_path."""
-        for server in self.servers:
+        for server in self.real_servers:
             endpoint = server.get_endpoint(relative_path)
             if endpoint:
                 endpoint.server = server
@@ -406,7 +887,7 @@
         relative_path = "/" + relative_path
         endpoint = self.get_endpoint(relative_path)
         if not endpoint:
-            abort(404)
+            raise NotFound()
         context.taxiiserver = endpoint.server
         return endpoint()
@@ -418,6 +899,22 @@ def handle_status_exception(self, error):
         """Dispatch status exception handling to appropriate taxii* server."""
         return context.taxiiserver.handle_status_exception(error)
 
+    def
handle_http_exception(self, error): + """Dispatch http exception handling to appropriate taxii* server.""" + server = self.current_server + if server: + return server.handle_http_exception(error) + else: + return error.get_response() + + def handle_validation_exception(self, error): + """Dispatch validation exception handling to appropriate taxii* server.""" + server = self.current_server + if server: + return server.handle_validation_exception(error) + else: + return error.get_response() + def raise_unauthorized(self): """Dispatch unauthorized handling to appropriate taxii* server.""" endpoint = self.get_endpoint(request.path) @@ -426,4 +923,5 @@ def raise_unauthorized(self): else: server = self.servers.taxii1 context.taxiiserver = server - server.raise_unauthorized() + # TODO: broken + return server.raise_unauthorized() diff --git a/opentaxii/sqldb_helper.py b/opentaxii/sqldb_helper.py index a534862c..ee5391d1 100644 --- a/opentaxii/sqldb_helper.py +++ b/opentaxii/sqldb_helper.py @@ -1,4 +1,8 @@ -from flask import _app_ctx_stack +try: + from greenlet import getcurrent as _ident_func +except ImportError: + from threading import get_ident as _ident_func + from sqlalchemy import engine, orm from sqlalchemy.orm.exc import UnmappedClassError @@ -26,8 +30,9 @@ class SQLAlchemyDB: def __init__(self, db_connection, base_model, session_options=None, **kwargs): self.engine = engine.create_engine(db_connection, **kwargs) self.Query = orm.Query - self.session = self.create_scoped_session(session_options) + self.session_options = session_options self.Model = self.extend_base_model(base_model) + self._session = None def extend_base_model(self, base): if not getattr(base, 'query_class', None): @@ -36,6 +41,12 @@ def extend_base_model(self, base): base.query = _QueryProperty(self) return base + @property + def session(self): + if self._session is None: + self._session = self.create_scoped_session(self.session_options) + return self._session + @property def metadata(self): return self.Model.metadata @@ -44,14 +55,18 @@ def create_scoped_session(self, options=None): options = options or {} - scopefunc = _app_ctx_stack.__ident_func__ + scopefunc = _ident_func options.setdefault('query_cls', self.Query) return orm.scoped_session( self.create_session(options), scopefunc=scopefunc) def create_session(self, options): - return orm.sessionmaker(bind=self.engine, **options) + kwargs = { + "bind": self.engine, + **options, + } + return orm.sessionmaker(**kwargs) def create_all_tables(self): self.metadata.create_all(bind=self.engine) diff --git a/opentaxii/taxii2/entities.py b/opentaxii/taxii2/entities.py index 5585eec1..c74e91fc 100644 --- a/opentaxii/taxii2/entities.py +++ b/opentaxii/taxii2/entities.py @@ -1,19 +1,23 @@ +"""Taxii2 entities.""" from datetime import datetime from opentaxii.common.entities import Entity +from opentaxii.entities import Account +from opentaxii.taxii2.utils import taxii2_datetimeformat class ApiRoot(Entity): """ TAXII2 API Root entity. 
- :param id int: id of this API root + :param str id: id of this API root :param bool default: indicator of default api root, should only be True once :param str title: human readable plain text name used to identify this API Root :param str description: human readable plain text description for this API Root """ - def __init__(self, id: int, default: bool, title: str, description: str): + def __init__(self, id: str, default: bool, title: str, description: str): + """Initialize ApiRoot.""" self.id = id self.default = default self.title = title @@ -25,21 +29,30 @@ class Collection(Entity): TAXII2 Collection entity. :param str id: id of this collection - :param int api_root_id: id of the :class:`ApiRoot` this collection belongs to + :param str api_root_id: id of the :class:`ApiRoot` this collection belongs to :param str title: human readable plain text name used to identify this collection :param str description: human readable plain text description for this collection :param str alias: human readable collection name that can be used on systems to alias a collection id """ def __init__( - self, id: str, api_root_id: int, title: str, description: str, alias: str + self, id: str, api_root_id: str, title: str, description: str, alias: str ): + """Initialize Collection.""" self.id = id self.api_root_id = api_root_id self.title = title self.description = description self.alias = alias + def can_read(self, account: Account): + """Determine if `account` is allowed to read from this collection.""" + return account.is_admin or "read" in set(account.permissions.get(self.id, [])) + + def can_write(self, account: Account): + """Determine if `account` is allowed to write to this collection.""" + return account.is_admin or "write" in set(account.permissions.get(self.id, [])) + class STIXObject(Entity): """ @@ -64,6 +77,7 @@ def __init__( version: datetime, serialized_data: dict, ): + """Initialize STIXObject.""" self.id = id self.collection_id = collection_id self.type = type @@ -73,11 +87,58 @@ def __init__( self.serialized_data = serialized_data +class ManifestRecord(Entity): + """ + TAXII2 ManifestRecord entity. + + This is a cut-down version of :class:`STIXObject`, for efficiency. + + :param str id: id of this stix object + :param datetime date_added: the date and time this object was added + :param datetime version: the version of this object + :param str spec_version: stix version this object matches + """ + + def __init__( + self, + id: str, + date_added: datetime, + version: datetime, + spec_version: str, + ): + """Initialize ManifestRecord.""" + self.id = id + self.date_added = date_added + self.version = version + self.spec_version = spec_version + + +class VersionRecord(Entity): + """ + TAXII2 VersionRecord entity. + + This is a cut-down version of :class:`STIXObject`, for efficiency. + + :param datetime date_added: the date and time this object was added + :param datetime version: the version of this object + """ + + def __init__( + self, + date_added: datetime, + version: datetime, + ): + """Initialize VersionRecord.""" + self.date_added = date_added + self.version = version + + class Job(Entity): """ TAXII2 Job entity, called a "status resource" in taxii2 docs. 
:param str id: id of this job + :param str api_root_id: id of the :class:`ApiRoot` this collection belongs to :param str status: status of this job :param datetime request_timestamp: the datetime of the request that this status resource is monitoring :param datetime completed_timestamp: the datetime of the completion of this job (used for cleanup) @@ -86,11 +147,14 @@ class Job(Entity): def __init__( self, id: str, + api_root_id: str, status: str, request_timestamp: datetime, completed_timestamp: datetime, ): + """Initialize Job.""" self.id = id + self.api_root_id = api_root_id self.status = status self.request_timestamp = request_timestamp self.completed_timestamp = completed_timestamp @@ -103,14 +167,24 @@ class JobDetail(Entity): :param str id: id of this job detail :param str job_id: id of the job this detail belongs to :param str stix_id: id of the :class:`STIXObject` this detail tracks - :param str version: the version of this object + :param datetime version: the version of this object :param str message: message indicating more information about the object being created, its pending state, or why the object failed to be created. + :param str status: status of this job """ - def __init__(self, id: str, job_id: str, stix_id: str, version: str, message: str): + def __init__(self, id: str, job_id: str, stix_id: str, version: datetime, message: str, status: str): + """Initialize JobDetail.""" self.id = id self.job_id = job_id self.stix_id = stix_id self.version = version self.message = message + self.status = status + + def as_taxii2_dict(self): + """Turn this object into a taxii2 dict.""" + response = {"id": self.stix_id, "version": taxii2_datetimeformat(self.version)} + if self.message: + response["message"] = self.message + return response diff --git a/opentaxii/taxii2/exceptions.py b/opentaxii/taxii2/exceptions.py index 09944976..a6bc44a6 100644 --- a/opentaxii/taxii2/exceptions.py +++ b/opentaxii/taxii2/exceptions.py @@ -1,4 +1,8 @@ -class ValidationError(Exception): +from marshmallow.exceptions import \ + ValidationError as MarshmallowValidationError + + +class ValidationError(MarshmallowValidationError): """ Exception used when taxii2 envelope doesn't pass validation. """ diff --git a/opentaxii/taxii2/http.py b/opentaxii/taxii2/http.py new file mode 100644 index 00000000..a0408d33 --- /dev/null +++ b/opentaxii/taxii2/http.py @@ -0,0 +1,15 @@ +"""Taxii2 http helper functions.""" +import json +from typing import Dict, Optional + +from flask import Response, make_response + + +def make_taxii2_response(data, status: Optional[int] = 200, extra_headers: Optional[Dict] = None) -> Response: + """Turn input data into valid taxii2 response.""" + if not isinstance(data, str): + data = json.dumps(data) + response = make_response((data, status)) + response.content_type = "application/taxii+json;version=2.1" + response.headers.update(extra_headers or {}) + return response diff --git a/opentaxii/taxii2/utils.py b/opentaxii/taxii2/utils.py new file mode 100644 index 00000000..81a522bb --- /dev/null +++ b/opentaxii/taxii2/utils.py @@ -0,0 +1,44 @@ +"""Utility functions for taxii2.""" +import base64 +import datetime +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from opentaxii.taxii2.entities import STIXObject + +DATETIMEFORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def taxii2_datetimeformat(input_value: datetime.datetime) -> str: + """ + Format datetime according to taxii2 spec. 
+ + :param `datetime.datetime` input_value: The datetime object to format + + :return: The taxii2 string representation of `input_value` + :rtype: string + """ + return input_value.astimezone(datetime.timezone.utc).strftime(DATETIMEFORMAT) + + +def get_next_param(obj: "STIXObject") -> bytes: + """ + Get value for `next` based on :class:`STIXObject` instance. + + :param :class:`STIXObject` obj: The object to base the `next` param on + + :return: The value to use as `next` param + :rtype: str + """ + return base64.b64encode(f"{obj.date_added.isoformat()}|{obj.id}".encode("utf-8")) + + +def parse_next_param(next_param: bytes): + """ + Parse provided `next_param` into kwargs to be used to filter stix objects. + """ + date_added_str, obj_id = base64.b64decode(next_param).decode().split("|") + date_added = datetime.datetime.strptime( + date_added_str.split('+')[0], "%Y-%m-%dT%H:%M:%S.%f" + ).replace(tzinfo=datetime.timezone.utc) + return {"id": obj_id, "date_added": date_added} diff --git a/opentaxii/taxii2/validation.py b/opentaxii/taxii2/validation.py index 02557210..20a71e48 100644 --- a/opentaxii/taxii2/validation.py +++ b/opentaxii/taxii2/validation.py @@ -1,8 +1,13 @@ +"""Taxii2 validation functions.""" +import datetime import json +from marshmallow import Schema, fields from opentaxii.taxii2.exceptions import ValidationError +from opentaxii.taxii2.utils import DATETIMEFORMAT, parse_next_param from stix2 import parse from stix2.exceptions import STIXError +from werkzeug.datastructures import ImmutableMultiDict def validate_envelope(json_data: str, allow_custom: bool = False) -> None: @@ -27,3 +32,111 @@ def validate_envelope(json_data: str, allow_custom: bool = False) -> None: raise ValidationError( f"Invalid stix object: {json.dumps(item)}; {str(e)}" ) from e + + +class Taxii2DateTime(fields.DateTime): + """Taxii2 formatting compliant datetime field.""" + + DEFAULT_FORMAT = DATETIMEFORMAT + + def _deserialize(self, value, attr, data, **kwargs): + value = super()._deserialize(value, attr, data, **kwargs) + return value.replace(tzinfo=datetime.timezone.utc) + + +class Taxii2Next(fields.Field): + """Implemenatation of the taxii2 `next` query param.""" + + def _deserialize(self, value, attr, data, **kwargs): + value = super()._deserialize(value, attr, data, **kwargs) + try: + value = parse_next_param(value) + except: # noqa + raise ValidationError("Not a valid value.") + return value + + +class Taxii2Filter(fields.Field): + """General taxii2 filter implementation.""" + + def _deserialize(self, value, attr, data, **kwargs): + value = super()._deserialize(value, attr, data, **kwargs) + return value.split(',') + + +class Taxii2VersionFilter(Taxii2Filter): + """Taxii2 compliant version filter.""" + + def _deserialize(self, value, attr, data, **kwargs): + values = super()._deserialize(value, attr, data, **kwargs) + new_values = [] + for value in values: + if value not in ["first", "last", "all"]: + try: + value = datetime.datetime.strptime(value, DATETIMEFORMAT).replace(tzinfo=datetime.timezone.utc) + except ValueError: + pass + new_values.append(value) + return new_values + + +class VersionFilterParamsSchema(Schema): + """Schema for the versions endpoint filters.""" + + limit = fields.Int() + added_after = Taxii2DateTime() + next_kwargs = Taxii2Next(data_key="next") + match_spec_version = Taxii2Filter(data_key="match[spec_version]") + + +class ObjectFilterParamsSchema(Schema): + """Schema for the object endpoint filters.""" + + limit = fields.Int() + added_after = Taxii2DateTime() + next_kwargs = 
Taxii2Next(data_key="next") + match_spec_version = Taxii2Filter(data_key="match[spec_version]") + match_version = Taxii2VersionFilter(data_key="match[version]") + + +class ListFilterParamsSchema(Schema): + """Schema for the object list endpoint filters.""" + + limit = fields.Int() + added_after = Taxii2DateTime() + next_kwargs = Taxii2Next(data_key="next") + match_spec_version = Taxii2Filter(data_key="match[spec_version]") + match_version = Taxii2VersionFilter(data_key="match[version]") + match_id = Taxii2Filter(data_key="match[id]") + match_type = Taxii2Filter(data_key="match[type]") + + +class DeleteFilterParamsSchema(Schema): + """Schema for the object delete endpoint filters.""" + + match_version = Taxii2VersionFilter(data_key="match[version]") + match_spec_version = Taxii2Filter(data_key="match[spec_version]") + + +def validate_object_filter_params(filter_params: ImmutableMultiDict) -> dict: + """Validate and load filter params for the object endpoint.""" + parsed_params = ObjectFilterParamsSchema().load(filter_params) + return parsed_params + + +def validate_list_filter_params(filter_params: ImmutableMultiDict) -> dict: + """Validate and load filter params for the list endpoint.""" + parsed_params = ListFilterParamsSchema().load(filter_params) + return parsed_params + + +def validate_versions_filter_params(filter_params: ImmutableMultiDict) -> dict: + """Validate and load filter params for the versions endpoint.""" + parsed_params = ListFilterParamsSchema().load(filter_params) + return parsed_params + + +def validate_delete_filter_params(filter_params: ImmutableMultiDict) -> dict: + """Validate and load filter params for the delete endpoint.""" + parsed_params = DeleteFilterParamsSchema().load(filter_params) + return parsed_params diff --git a/opentaxii/utils.py b/opentaxii/utils.py index 3f06126b..626d1b79 100644 --- a/opentaxii/utils.py +++ b/opentaxii/utils.py @@ -1,8 +1,11 @@ import base64 import binascii +import functools import importlib import logging +import re import sys +from typing import Optional, Tuple import structlog from six.moves import urllib @@ -23,15 +26,15 @@ def get_path_and_address(domain, address): def import_class(module_class_name): - module_name, _, class_name = module_class_name.rpartition('.') + module_name, _, class_name = module_class_name.rpartition(".") module = importlib.import_module(module_name) return getattr(module, class_name) def initialize_api(api_config): - class_name = api_config['class'] + class_name = api_config["class"] cls = import_class(class_name) - params = api_config.get('parameters', None) + params = api_config.get("parameters", None) if params: instance = cls(**params) @@ -50,52 +53,48 @@ def parse_basic_auth_token(token): raise InvalidAuthHeader("Can't decode Basic Auth header value") try: - value = value.decode('utf-8') - username, password = value.split(':', 1) + value = value.decode("utf-8") + username, password = value.split(":", 1) return (username, password) except ValueError: raise InvalidAuthHeader("Invalid Basic Auth header value") class PlainRenderer: - def __call__(self, logger, name, event_dict): details = event_dict.copy() - timestamp = details.pop('timestamp') - logger = details.pop('logger') - level = details.pop('level') - event = details.pop('event') - pairs = ', '.join(['%s=%s' % (k, v) for k, v in details.items()]) - return ( - '{timestamp} [{logger}] {level}: {event} {pairs}' - .format( - timestamp=timestamp, - logger=logger, - level=level, - event=event, - pairs=('{{{}}}'.format(pairs) if pairs else ""))) + 
timestamp = details.pop("timestamp") + logger = details.pop("logger") + level = details.pop("level") + event = details.pop("event") + pairs = ", ".join(["%s=%s" % (k, v) for k, v in details.items()]) + return "{timestamp} [{logger}] {level}: {event} {pairs}".format( + timestamp=timestamp, + logger=logger, + level=level, + event=event, + pairs=("{{{}}}".format(pairs) if pairs else ""), + ) def configure_logging(logging_levels, plain=False, stream=sys.stderr): - renderer = ( - PlainRenderer() if plain else - structlog.processors.JSONRenderer()) + renderer = PlainRenderer() if plain else structlog.processors.JSONRenderer() attr_processors = [ structlog.stdlib.add_logger_name, structlog.stdlib.add_log_level, - structlog.processors.TimeStamper(fmt='iso') + structlog.processors.TimeStamper(fmt="iso"), ] structlog.configure_once( processors=( - [structlog.stdlib.filter_by_level] + - attr_processors + - [ + [structlog.stdlib.filter_by_level] + + attr_processors + + [ structlog.stdlib.PositionalArgumentsFormatter(), structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, - structlog.stdlib.ProcessorFormatter.wrap_for_formatter + structlog.stdlib.ProcessorFormatter.wrap_for_formatter, ] ), context_class=dict, @@ -105,7 +104,8 @@ def configure_logging(logging_levels, plain=False, stream=sys.stderr): ) formatter = structlog.stdlib.ProcessorFormatter( - processor=renderer, foreign_pre_chain=attr_processors) + processor=renderer, foreign_pre_chain=attr_processors + ) handler = AtomicStreamHandler(stream) handler.setFormatter(formatter) @@ -119,8 +119,8 @@ def configure_logging(logging_levels, plain=False, stream=sys.stderr): for logger, level in logging_levels.items(): - if logger.lower() == 'root': - logger = '' + if logger.lower() == "root": + logger = "" logging.getLogger(logger).setLevel(level.upper()) @@ -154,37 +154,38 @@ class AtomicStreamHandler(logging.StreamHandler): def emit(self, record): try: msg = self.format(record) - self.stream.write('%s\n' % msg) + self.stream.write("%s\n" % msg) self.flush() except Exception: self.handleError(record) def sync_conf_dict_into_db(server, config, force_collection_deletion=False): - services = config.get('services', []) + services = config.get("services", []) sync_services(server.servers.taxii1, services) - collections = config.get('collections', []) + collections = config.get("collections", []) sync_collections( - server.servers.taxii1, collections, force_deletion=force_collection_deletion) - accounts = config.get('accounts', []) + server.servers.taxii1, collections, force_deletion=force_collection_deletion + ) + accounts = config.get("accounts", []) sync_accounts(server, accounts) def sync_services(server, services): manager = server.persistence - defined_by_id = {s['id']: s for s in services} + defined_by_id = {s["id"]: s for s in services} existing_by_id = {s.id: s for s in manager.get_services()} created_counter = 0 updated_counter = 0 for service in services: - existing = existing_by_id.get(service['id']) + existing = existing_by_id.get(service["id"]) if existing: properties = service.copy() - properties.pop('id') - existing.type = properties.pop('type') + properties.pop("id") + existing.type = properties.pop("type") existing.properties = properties existing = manager.update_service(existing) log.info("sync_services.updated", id=existing.id) @@ -206,40 +207,39 @@ def sync_services(server, services): "sync_services.stats", updated=updated_counter, created=created_counter, - deleted=deleted_counter) + deleted=deleted_counter, + 
) def sync_collections(server, collections, force_deletion=False): manager = server.persistence - defined_by_name = {c['name']: c for c in collections} + defined_by_name = {c["name"]: c for c in collections} existing_by_name = {c.name: c for c in manager.get_collections()} created_counter = 0 updated_counter = 0 for collection in collections: - existing = existing_by_name.get(collection['name']) + existing = existing_by_name.get(collection["name"]) collection_data = collection.copy() - service_ids = collection_data.pop('service_ids') + service_ids = collection_data.pop("service_ids") if existing: - collection_data.pop('id', None) + collection_data.pop("id", None) bindings = deserialize_content_bindings( - collection_data.pop('supported_content', [])) + collection_data.pop("supported_content", []) + ) for k, v in collection_data.items(): setattr(existing, k, v) existing.supported_content = bindings cobj = manager.update_collection(existing) manager.set_collection_services(cobj.id, service_ids) - log.info( - "sync_collections.updated", name=cobj.name, id=cobj.id) + log.info("sync_collections.updated", name=cobj.name, id=cobj.id) updated_counter += 1 else: - cobj = manager.create_collection( - CollectionEntity(**collection_data)) + cobj = manager.create_collection(CollectionEntity(**collection_data)) manager.set_collection_services(cobj.id, service_ids) - log.info( - "sync_collections.created", name=cobj.name, id=cobj.id) + log.info("sync_collections.created", name=cobj.name, id=cobj.id) created_counter += 1 disabled_counter = 0 @@ -261,39 +261,42 @@ def sync_collections(server, collections, force_deletion=False): updated=updated_counter, created=created_counter, disabled=disabled_counter, - deleted=deleted_counter) + deleted=deleted_counter, + ) def sync_accounts(server, accounts): manager = server.auth - defined_by_username = {a['username']: a for a in accounts} + defined_by_username = {a["username"]: a for a in accounts} existing_by_username = {a.username: a for a in manager.get_accounts()} created_counter = 0 updated_counter = 0 for account in accounts: - existing = existing_by_username.get(account['username']) + existing = existing_by_username.get(account["username"]) if existing: properties = account.copy() - password = properties.pop('password') - existing.permissions = properties.get('permissions', {}) - existing.is_admin = properties.get('is_admin', False) + password = properties.pop("password") + existing.permissions = properties.get("permissions", {}) + existing.is_admin = properties.get("is_admin", False) existing = manager.update_account(existing, password) log.info("sync_accounts.updated", username=existing.username) updated_counter += 1 else: obj = Account( id=None, - username=account['username'], - permissions=account.get('permissions', {}), - is_admin=account.get('is_admin', False)) - obj = manager.update_account(obj, account['password']) + username=account["username"], + permissions=account.get("permissions", {}), + is_admin=account.get("is_admin", False), + ) + obj = manager.update_account(obj, account["password"]) log.info("sync_accounts.created", username=obj.username) created_counter += 1 deleted_counter = 0 - missing_usernames = ( - set(existing_by_username.keys()) - set(defined_by_username.keys())) + missing_usernames = set(existing_by_username.keys()) - set( + defined_by_username.keys() + ) for username in missing_usernames: manager.delete_account(username) deleted_counter += 1 @@ -303,4 +306,44 @@ def sync_accounts(server, accounts): "sync_accounts.stats", 
updated=updated_counter, created=created_counter, - deleted=deleted_counter) + deleted=deleted_counter, + ) + + +def register_handler( + url_re: str, + valid_methods: Optional[Tuple[str]] = None, + valid_accept_mimetypes: Optional[Tuple[str]] = None, + valid_content_types: Optional[Tuple[str]] = None, +): + """ + Register decorated method as handler function for `url_re`. + + :param str url_re: The regex to trigger the handler on + :param list valid_methods: The list of methods to accept for this handler, defaults to ("GET",) + :param list valid_accept_mimetypes: + The list of accepted mimetypes to accept for this handler, defaults to + ("application/taxii+json;version=2.1",) + :param list valid_content_types: + The list of content types to accept for this handler, defaults to + ("application/json",) + """ + if valid_methods is None: + valid_methods = ("GET",) + if valid_accept_mimetypes is None: + valid_accept_mimetypes = ("application/taxii+json;version=2.1",) + if valid_content_types is None: + valid_content_types = ("application/json",) + + def inner_decorator(method): + @functools.wraps(method) + def inner(*args, **kwargs): + return method(*args, **kwargs) + + inner.registered_url_re = re.compile(url_re) + inner.registered_valid_methods = valid_methods + inner.registered_valid_accept_mimetypes = valid_accept_mimetypes + inner.registered_valid_content_types = valid_content_types + return inner + + return inner_decorator diff --git a/pytest.ini b/pytest.ini index b042c542..f3bbb73e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,8 +1,10 @@ [pytest] addopts = --verbose - --showlocals + --showlocals --cov=opentaxii --cov-config .coveragerc --cov-report term-missing --cov-report html +markers = + truncate: marks tests as needing database truncate (no transactions) diff --git a/requirements-dev.txt b/requirements-dev.txt index d6c260ab..16a1ba0f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,4 +4,5 @@ pytest>=4.6 pytest-pythonpath flake8 ipdb +factory-boy>=3.2.1 -r requirements-interrogate.txt diff --git a/requirements.txt b/requirements.txt index d579d74c..df2ddbdd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pyjwt>=1.4.0 six>=1.10.0 mypy-extensions>=0.4.3 stix2>=3.0.1 +marshmallow>=3.14.1 diff --git a/setup.py b/setup.py index 3dea2fde..edad9d65 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,9 @@ def get_file_contents(filename): 'opentaxii-update-account = opentaxii.cli.auth:update_account', 'opentaxii-sync-data = opentaxii.cli.persistence:sync_data_configuration', 'opentaxii-delete-blocks = opentaxii.cli.persistence:delete_content_blocks', + 'opentaxii-add-api-root = opentaxii.cli.persistence:add_api_root', + 'opentaxii-add-collection = opentaxii.cli.persistence:add_collection', + 'opentaxii-job-cleanup = opentaxii.cli.persistence:job_cleanup', ] }, install_requires=install_requires, diff --git a/tests/conftest.py b/tests/conftest.py index 72e1d44a..ee7bffaf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,45 @@ +import base64 import os from tempfile import mkstemp +from unittest.mock import patch import pytest +from flask.testing import FlaskClient from opentaxii.config import ServerConfig from opentaxii.local import context, release_context from opentaxii.middleware import create_app +from opentaxii.persistence.sqldb.taxii2models import (ApiRoot, Collection, Job, + JobDetail, STIXObject) from opentaxii.server import TAXIIServer from opentaxii.taxii.converters import dict_to_service_entity +from opentaxii.taxii.http 
import HTTP_AUTHORIZATION from opentaxii.utils import configure_logging -from fixtures import DOMAIN, SERVICES +from tests.fixtures import (ACCOUNT, COLLECTIONS_B, DOMAIN, PASSWORD, SERVICES, + USERNAME, VALID_TOKEN) +from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, JOB_DETAILS, JOBS, + STIX_OBJECTS) + + +class CustomClient(FlaskClient): + def __init__(self, *args, **kwargs): + self.headers = kwargs.pop("headers", {}) + super().__init__(*args, **kwargs) + + def open(self, *args, **kwargs): + headers = kwargs.pop("headers", {}) + new_headers = {**self.headers, **headers} + if new_headers: + kwargs["headers"] = new_headers + return super().open(*args, **kwargs) + DBTYPE = os.getenv("DBTYPE", "sqlite") if DBTYPE == "sqlite": @pytest.fixture(scope="session") def dbconn(): - filehandle, filename = mkstemp(suffix='.db') + filehandle, filename = mkstemp(suffix=".db") os.close(filehandle) try: yield f"sqlite:///{filename}" @@ -26,6 +49,7 @@ def dbconn(): except FileNotFoundError: pass + elif DBTYPE in ("mysql", "mariadb"): import MySQLdb @@ -151,17 +175,56 @@ def clean_db(dbconn): ) +@pytest.fixture(scope="session") +def session_taxiiserver(dbconn): + clean_db(dbconn) + yield TAXIIServer(prepare_test_config(dbconn)) + + @pytest.fixture() -def app(dbconn): +def app(request, dbconn, session_taxiiserver): + truncate = request.node.get_closest_marker("truncate") or DBTYPE == "sqlite" + if truncate: + yield from truncate_app(dbconn) + else: + yield from transaction_app(dbconn, session_taxiiserver) + + +def transaction_app(dbconn, taxiiserver): + context.server = taxiiserver + app = create_app(context.server) + app.config["TESTING"] = True + managers = [taxiiserver.auth] + [subserver.persistence for subserver in taxiiserver.servers] + transactions = [] + connections = [] + sessions = [] + for manager in managers: + connection = manager.api.db.engine.connect() + transaction = connection.begin() + manager.api.db.session_options["bind"] = connection + transactions.append(transaction) + connections.append(connection) + sessions.append(manager.api.db.session) + yield app + for (transaction, connection, session, manager) in zip(transactions, connections, sessions, managers): + transaction.rollback() + connection.close() + session.remove() + manager.api.db._session = None + + +def truncate_app(dbconn): clean_db(dbconn) - server = TAXIIServer(prepare_test_config(dbconn)) - context.server = server + taxiiserver = TAXIIServer(prepare_test_config(dbconn)) + context.server = taxiiserver app = create_app(context.server) app.config["TESTING"] = True yield app - for part in [server.auth] + [subserver.persistence for subserver in server.servers]: - part.api.db.session.commit() - part.api.db.engine.dispose() + + +@pytest.fixture() +def taxii2_sqldb_api(app): + yield app.taxii_server.servers.taxii2.persistence.api @pytest.fixture() @@ -172,12 +235,112 @@ def server(app, anonymous_user): @pytest.fixture() def client(app): + app.test_client_class = CustomClient return app.test_client() +def basic_auth_token(username, password): + return base64.b64encode("{}:{}".format(username, password).encode("utf-8")) + + +def MOCK_AUTHENTICATE(username, password): + if username == USERNAME and password == PASSWORD: + return VALID_TOKEN + return None + + +def MOCK_GET_ACCOUNT(token): + if token == VALID_TOKEN: + return ACCOUNT + return None + + +@pytest.fixture() +def authenticated_client(client): + basic_auth_header = "Basic {}".format( + basic_auth_token(USERNAME, PASSWORD).decode("utf-8") + ) + headers = { + 
HTTP_AUTHORIZATION: basic_auth_header, + } + client.headers = headers + client.account = ACCOUNT + with patch.object( + client.application.taxii_server.auth.api, + "authenticate", + side_effect=MOCK_AUTHENTICATE, + ), patch.object( + client.application.taxii_server.auth.api, + "get_account", + side_effect=MOCK_GET_ACCOUNT, + ): + yield client + + @pytest.fixture() def services(server): for service in SERVICES: server.servers.taxii1.persistence.update_service( dict_to_service_entity(service) ) + + +@pytest.fixture() +def collections(server): + for collection in COLLECTIONS_B: + server.servers.taxii1.persistence.create_collection(collection) + + +@pytest.fixture() +def account(server): + server.auth.api.create_account(ACCOUNT.username, "mypass") + + +@pytest.fixture(scope="function") +def db_api_roots(request, taxii2_sqldb_api): + try: + api_roots = request.param + except AttributeError: + api_roots = API_ROOTS + for api_root in api_roots: + taxii2_sqldb_api.db.session.add(ApiRoot.from_entity(api_root)) + taxii2_sqldb_api.db.session.commit() + yield api_roots + + +@pytest.fixture(scope="function") +def db_jobs(request, taxii2_sqldb_api, db_api_roots): + try: + (jobs, job_details) = request.param + except AttributeError: + (jobs, job_details) = (JOBS, JOB_DETAILS) + for job in jobs: + taxii2_sqldb_api.db.session.add(Job.from_entity(job)) + for job_detail in job_details: + taxii2_sqldb_api.db.session.add(JobDetail.from_entity(job_detail)) + taxii2_sqldb_api.db.session.commit() + yield (jobs, job_details) + + +@pytest.fixture(scope="function") +def db_collections(request, taxii2_sqldb_api, db_api_roots): + try: + collections = request.param + except AttributeError: + collections = COLLECTIONS + for collection in collections: + taxii2_sqldb_api.db.session.add(Collection.from_entity(collection)) + taxii2_sqldb_api.db.session.commit() + yield collections + + +@pytest.fixture(scope="function") +def db_stix_objects(request, taxii2_sqldb_api, db_collections): + try: + stix_objects = request.param + except AttributeError: + stix_objects = STIX_OBJECTS + for stix_object in stix_objects: + taxii2_sqldb_api.db.session.add(STIXObject.from_entity(stix_object)) + taxii2_sqldb_api.db.session.commit() + yield stix_objects diff --git a/tests/fixtures.py b/tests/fixtures.py index d8d72169..6fae04bd 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,7 +1,8 @@ -from libtaxii.constants import ( - VID_TAXII_HTTP_10, VID_TAXII_HTTPS_10, - CB_STIX_XML_111) +from uuid import uuid4 +from libtaxii.constants import (CB_STIX_XML_111, VID_TAXII_HTTP_10, + VID_TAXII_HTTPS_10) +from opentaxii.entities import Account from opentaxii.taxii import entities PROTOCOL_BINDINGS = [VID_TAXII_HTTP_10, VID_TAXII_HTTPS_10] @@ -128,3 +129,8 @@ 'available': False }] ] + +USERNAME = "some-username" +PASSWORD = "some-password" +ACCOUNT = Account(str(uuid4()), USERNAME, {}) +VALID_TOKEN = "valid-token" diff --git a/tests/services/test_inbox.py b/tests/services/test_inbox.py index 419f23c9..da61850e 100644 --- a/tests/services/test_inbox.py +++ b/tests/services/test_inbox.py @@ -10,35 +10,33 @@ from utils import as_tm, prepare_headers -def make_content(version, content_binding=CUSTOM_CONTENT_BINDING, - content=CONTENT, subtype=None): +def make_content( + version, content_binding=CUSTOM_CONTENT_BINDING, content=CONTENT, subtype=None +): if version == 10: return tm10.ContentBlock(content_binding, content) elif version == 11: - content_block = tm11.ContentBlock( - tm11.ContentBinding(content_binding), content) + content_block = 
tm11.ContentBlock(tm11.ContentBinding(content_binding), content) if subtype: content_block.content_binding.subtype_ids.append(subtype) return content_block else: - raise ValueError('Unknown TAXII message version: %s' % version) + raise ValueError("Unknown TAXII message version: %s" % version) def make_inbox_message(version, blocks=None, dest_collection=None): if version == 10: - inbox_message = tm10.InboxMessage( - message_id=MESSAGE_ID, content_blocks=blocks) + inbox_message = tm10.InboxMessage(message_id=MESSAGE_ID, content_blocks=blocks) elif version == 11: - inbox_message = tm11.InboxMessage( - message_id=MESSAGE_ID, content_blocks=blocks) + inbox_message = tm11.InboxMessage(message_id=MESSAGE_ID, content_blocks=blocks) if dest_collection: inbox_message.destination_collection_names.append(dest_collection) else: - raise ValueError('Unknown TAXII message version: %s' % version) + raise ValueError("Unknown TAXII message version: %s" % version) return inbox_message @@ -47,9 +45,7 @@ def make_inbox_message(version, blocks=None, dest_collection=None): def prepare_server(server, services): from opentaxii.persistence.sqldb.models import DataCollection - coll_mapping = { - 'inbox-A': COLLECTIONS_A, - 'inbox-B': COLLECTIONS_B} + coll_mapping = {"inbox-A": COLLECTIONS_A, "inbox-B": COLLECTIONS_B} names = set() for service, collections in coll_mapping.items(): @@ -59,18 +55,25 @@ def prepare_server(server, services): names.add(coll.name) service_ids = [service] else: - coll = DataCollection.query.filter_by(name=coll.name).one() + coll = ( + server.servers.taxii1.persistence.api.db.session.query( + DataCollection + ) + .filter_by(name=coll.name) + .one() + ) service_ids = {s.id for s in coll.services} | {service} server.servers.taxii1.persistence.set_collection_services( - coll.id, service_ids=service_ids) + coll.id, service_ids=service_ids + ) @pytest.mark.parametrize("https", [True, False]) @pytest.mark.parametrize("version", [11, 10]) def test_inbox_request_all_content(server, version, https): - inbox_a = server.servers.taxii1.get_service('inbox-A') + inbox_a = server.servers.taxii1.get_service("inbox-A") headers = prepare_headers(version, https) @@ -78,10 +81,9 @@ def test_inbox_request_all_content(server, version, https): make_content( version, content_binding=CUSTOM_CONTENT_BINDING, - subtype=CONTENT_BINDING_SUBTYPE), - make_content( - version, - content_binding=INVALID_CONTENT_BINDING) + subtype=CONTENT_BINDING_SUBTYPE, + ), + make_content(version, content_binding=INVALID_CONTENT_BINDING), ] inbox_message = make_inbox_message(version, blocks=blocks) @@ -102,17 +104,18 @@ def test_inbox_request_destination_collection(server, https): version = 11 inbox_message = make_inbox_message( - version, blocks=[make_content(version)], dest_collection=None) + version, blocks=[make_content(version)], dest_collection=None + ) headers = prepare_headers(version, https) - inbox = server.servers.taxii1.get_service('inbox-A') + inbox = server.servers.taxii1.get_service("inbox-A") # destination collection is not required for inbox-A response = inbox.process(headers, inbox_message) assert isinstance(response, as_tm(version).StatusMessage) assert response.status_type == ST_SUCCESS - inbox = server.servers.taxii1.get_service('inbox-B') + inbox = server.servers.taxii1.get_service("inbox-B") # destination collection is required for inbox-B with pytest.raises(exceptions.StatusMessageException): response = inbox.process(headers, inbox_message) @@ -122,20 +125,20 @@ def test_inbox_request_destination_collection(server, 
https): @pytest.mark.parametrize("version", [11, 10]) def test_inbox_request_inbox_valid_content_binding(server, version, https): - inbox = server.servers.taxii1.get_service('inbox-B') + inbox = server.servers.taxii1.get_service("inbox-B") blocks = [ make_content( version, content_binding=CUSTOM_CONTENT_BINDING, - subtype=CONTENT_BINDING_SUBTYPE), - make_content( - version, - content_binding=CB_STIX_XML_111) + subtype=CONTENT_BINDING_SUBTYPE, + ), + make_content(version, content_binding=CB_STIX_XML_111), ] inbox_message = make_inbox_message( - version, dest_collection=COLLECTION_OPEN, blocks=blocks) + version, dest_collection=COLLECTION_OPEN, blocks=blocks + ) headers = prepare_headers(version, https) response = inbox.process(headers, inbox_message) @@ -153,11 +156,12 @@ def test_inbox_request_inbox_valid_content_binding(server, version, https): @pytest.mark.parametrize("version", [11, 10]) def test_inbox_req_inbox_invalid_inbox_content_binding(server, version, https): - inbox = server.servers.taxii1.get_service('inbox-B') + inbox = server.servers.taxii1.get_service("inbox-B") content = make_content(version, content_binding=INVALID_CONTENT_BINDING) inbox_message = make_inbox_message( - version, dest_collection=COLLECTION_OPEN, blocks=[content]) + version, dest_collection=COLLECTION_OPEN, blocks=[content] + ) headers = prepare_headers(version, https) @@ -177,7 +181,7 @@ def test_inbox_req_inbox_invalid_inbox_content_binding(server, version, https): @pytest.mark.parametrize("version", [11, 10]) def test_inbox_req_coll_content_bindings_filtering(server, version, https): - inbox = server.servers.taxii1.get_service('inbox-B') + inbox = server.servers.taxii1.get_service("inbox-B") headers = prepare_headers(version, https) blocks = [ @@ -186,7 +190,8 @@ def test_inbox_req_coll_content_bindings_filtering(server, version, https): ] inbox_message = make_inbox_message( - version, dest_collection=COLLECTION_ONLY_STIX, blocks=blocks) + version, dest_collection=COLLECTION_ONLY_STIX, blocks=blocks + ) response = inbox.process(headers, inbox_message) diff --git a/tests/taxii2/factories.py b/tests/taxii2/factories.py new file mode 100644 index 00000000..58d88794 --- /dev/null +++ b/tests/taxii2/factories.py @@ -0,0 +1,22 @@ +"""Factories for taxii2 entities.""" +import datetime +from uuid import uuid4 + +import factory +import stix2 +from opentaxii.taxii2.entities import STIXObject + + +class STIXObjectFactory(factory.Factory): + id = factory.LazyAttribute(lambda o: f"{o.type}--{str(uuid4())}") + collection_id = factory.Faker("uuid4") + type = factory.Faker("random_element", elements=tuple(stix2.v21.OBJ_MAP.keys())) + spec_version = factory.Faker("random_element", elements=("2.0", "2.1")) + date_added = factory.Faker("date_time", tzinfo=datetime.timezone.utc) + version = factory.Faker("date_time", tzinfo=datetime.timezone.utc) + serialized_data = factory.Faker( + "pydict" + ) # TODO replace with valid stix object data generator + + class Meta: + model = STIXObject diff --git a/tests/taxii2/test_taxii2_api_root.py b/tests/taxii2/test_taxii2_api_root.py new file mode 100644 index 00000000..d81e3bc3 --- /dev/null +++ b/tests/taxii2/test_taxii2_api_root.py @@ -0,0 +1,270 @@ +import json +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from opentaxii.persistence.sqldb import taxii2models +from sqlalchemy import literal +from tests.taxii2.utils import (API_ROOTS, API_ROOTS_WITH_DEFAULT, + GET_API_ROOT_MOCK, config_noop, + config_override, server_mapping_noop, + 
server_mapping_remove_fields) + + +@pytest.mark.parametrize( + [ + "method", + "headers", + "api_root_id", + "config_override_func", + "server_mapping_override_func", + "expected_status", + "expected_headers", + "expected_content", + ], + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + config_noop, + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": API_ROOTS[0].title, + "description": API_ROOTS[0].description, + "versions": ["application/taxii+json;version=2.1"], + "max_content_length": 1024, + }, + id="good, first", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + config_noop, + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": API_ROOTS[1].title, + "description": API_ROOTS[1].description, + "versions": ["application/taxii+json;version=2.1"], + "max_content_length": 1024, + }, + id="good, second", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[2].id, + config_noop, + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": API_ROOTS[2].title, + "versions": ["application/taxii+json;version=2.1"], + "max_content_length": 1024, + }, + id="good, no description", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + config_override({"max_content_length": 1024000}), + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": API_ROOTS[0].title, + "description": API_ROOTS[0].description, + "versions": ["application/taxii+json;version=2.1"], + "max_content_length": 1024000, + }, + id="good, first, max_content_length override", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + config_noop, + server_mapping_remove_fields("taxii1"), + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root, taxii2 only config", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + config_noop, + server_mapping_noop, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root, taxii1/2 config", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + config_noop, + server_mapping_noop, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." 
+ ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + config_noop, + server_mapping_noop, + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + ], +) +def test_api_root( + authenticated_client, + method, + api_root_id, + headers, + config_override_func, + server_mapping_override_func, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + authenticated_client.application.taxii_server.servers.taxii2.config + ), + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), patch.object( + authenticated_client.application.taxii_server, + "servers", + server_mapping_override_func( + authenticated_client.application.taxii_server.servers + ), + ): + func = getattr(authenticated_client, method) + response = func(f"/{api_root_id}/", headers=headers) + assert response.status_code == expected_status + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ): + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_api_root_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func(f"/{API_ROOTS[0].id}/") + assert response.status_code == 401 + + +@pytest.mark.parametrize( + ["title", "description", "default", "db_api_roots"], + [ + pytest.param( + "my new api root", # title + None, # description + False, # default + [], # db_api_roots + id="title only", + ), + pytest.param( + "my new api root", # title + "my description", # description + False, # default + [], # db_api_roots + id="title, description", + ), + pytest.param( + "my new api root", # title + None, # description + True, # default + [], # db_api_roots + id="title, default", + ), + pytest.param( + "my new api root", # title + "my description", # description + True, # default + API_ROOTS_WITH_DEFAULT, # db_api_roots + id="title, description, default, existing", + ), + ], + indirect=["db_api_roots"], +) +def test_add_api_root(app, title, description, default, db_api_roots): + api_root = app.taxii_server.servers.taxii2.persistence.api.add_api_root( + title, description, default + ) + assert api_root.id is not None + assert api_root.title == title + assert api_root.description == description + assert api_root.default == default + db_api_root = ( + app.taxii_server.servers.taxii2.persistence.api.db.session.query( + taxii2models.ApiRoot + ) + .filter(taxii2models.ApiRoot.id == api_root.id) + .one() + ) + assert db_api_root.title == title + assert db_api_root.description == description + assert db_api_root.default == default + if default: + assert ( + app.taxii_server.servers.taxii2.persistence.api.db.session.query( + taxii2models.ApiRoot + ) + .filter(taxii2models.ApiRoot.default == literal(True)) + .count() + ) == 1 diff --git a/tests/taxii2/test_taxii2_collection.py b/tests/taxii2/test_taxii2_collection.py new file mode 100644 index 00000000..bbdcf187 --- 
/dev/null +++ b/tests/taxii2/test_taxii2_collection.py @@ -0,0 +1,267 @@ +import json +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from opentaxii.persistence.sqldb import taxii2models +from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_API_ROOT_MOCK, + GET_COLLECTION_MOCK) + + +@pytest.mark.parametrize( + "method,headers,api_root_id,collection_id,expected_status,expected_headers,expected_content", + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[0].id, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "id": COLLECTIONS[0].id, + "title": "0Read only", + "description": "Read only description", + "can_read": True, + "can_write": False, + "media_types": ["application/stix+json;version=2.1"], + }, + id="good", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[4].id, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "id": COLLECTIONS[4].id, + "title": "4No description", + "can_read": True, + "can_write": True, + "media_types": ["application/stix+json;version=2.1"], + }, + id="good, no description", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "id": COLLECTIONS[5].id, + "title": "5With alias", + "description": "With alias description", + "alias": "this-is-an-alias", + "can_read": False, + "can_write": True, + "media_types": ["application/stix+json;version=2.1"], + }, + id="good, with description", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].alias, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "id": COLLECTIONS[5].id, + "title": "5With alias", + "description": "With alias description", + "alias": "this-is-an-alias", + "can_read": False, + "can_write": True, + "media_types": ["application/stix+json;version=2.1"], + }, + id="good, by alias", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + COLLECTIONS[0].id, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="wrong api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + COLLECTIONS[0].id, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + str(uuid4()), + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. 
If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown collection", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + COLLECTIONS[0].id, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." + ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[0].id, + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + ], +) +def test_collection( + authenticated_client, + method, + api_root_id, + collection_id, + headers, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write"], + }, + ): + func = getattr(authenticated_client, method) + response = func(f"/{api_root_id}/collections/{collection_id}/", headers=headers) + assert response.status_code == expected_status + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ): + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_collection_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func(f"/{API_ROOTS[0].id}/collections/{COLLECTIONS[0].id}/") + assert response.status_code == 401 + + +@pytest.mark.parametrize( + ["api_root_id", "title", "description", "alias"], + [ + pytest.param( + API_ROOTS[0].id, # api_root_id + "my new collection", # title + None, # description + None, # alias + id="api_root_id, title", + ), + pytest.param( + API_ROOTS[0].id, # api_root_id + "my new collection", # title + "my description", # description + None, # alias + id="api_root_id, title, description", + ), + pytest.param( + API_ROOTS[0].id, # api_root_id + "my new collection", # title + "my description", # description + "my-alias", # alias + id="api_root_id, title, description, alias", + ), + ], +) +def test_add_collection( + app, api_root_id, title, description, alias, db_api_roots, db_collections +): + collection = app.taxii_server.servers.taxii2.persistence.api.add_collection( + api_root_id=api_root_id, + title=title, + description=description, + alias=alias, + ) + assert collection.id is not None + assert str(collection.api_root_id) == api_root_id + assert collection.title == title + assert collection.description == description + assert collection.alias == alias + 
db_collection = ( + app.taxii_server.servers.taxii2.persistence.api.db.session.query( + taxii2models.Collection + ) + .filter(taxii2models.Collection.id == collection.id) + .one() + ) + assert str(db_collection.api_root_id) == api_root_id + assert db_collection.title == title + assert db_collection.description == description + assert db_collection.alias == alias diff --git a/tests/taxii2/test_taxii2_collections.py b/tests/taxii2/test_taxii2_collections.py new file mode 100644 index 00000000..ff01c95c --- /dev/null +++ b/tests/taxii2/test_taxii2_collections.py @@ -0,0 +1,230 @@ +import json +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_API_ROOT_MOCK, + GET_COLLECTIONS_MOCK, config_noop, + server_mapping_noop, + server_mapping_remove_fields) + + +@pytest.mark.parametrize( + [ + "method", + "headers", + "api_root_id", + "config_override_func", + "server_mapping_override_func", + "expected_status", + "expected_headers", + "expected_content", + ], + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + config_noop, + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "collections": [ + { + "id": COLLECTIONS[0].id, + "title": "0Read only", + "description": "Read only description", + "can_read": True, + "can_write": False, + "media_types": ["application/stix+json;version=2.1"], + }, + { + "id": COLLECTIONS[1].id, + "title": "1Write only", + "description": "Write only description", + "can_read": False, + "can_write": True, + "media_types": ["application/stix+json;version=2.1"], + }, + { + "id": COLLECTIONS[2].id, + "title": "2Read/Write", + "description": "Read/Write description", + "can_read": True, + "can_write": True, + "media_types": ["application/stix+json;version=2.1"], + }, + { + "id": COLLECTIONS[3].id, + "title": "3No permissions", + "description": "No permissions description", + "can_read": False, + "can_write": False, + "media_types": ["application/stix+json;version=2.1"], + }, + { + "id": COLLECTIONS[4].id, + "title": "4No description", + "can_read": True, + "can_write": True, + "media_types": ["application/stix+json;version=2.1"], + }, + { + "id": COLLECTIONS[5].id, + "title": "5With alias", + "description": "With alias description", + "alias": "this-is-an-alias", + "can_read": False, + "can_write": True, + "media_types": ["application/stix+json;version=2.1"], + }, + ] + }, + id="good, first", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + config_noop, + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + {}, + id="good, second (no collections)", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + config_noop, + server_mapping_remove_fields("taxii1"), + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root, taxii2 only config", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + config_noop, + server_mapping_noop, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. 
If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root, taxii1/2 config", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + config_noop, + server_mapping_noop, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." + ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + config_noop, + server_mapping_noop, + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + ], +) +def test_collections( + authenticated_client, + method, + api_root_id, + headers, + config_override_func, + server_mapping_override_func, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + authenticated_client.application.taxii_server.servers.taxii2.config + ), + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_root", + side_effect=GET_API_ROOT_MOCK, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collections", + side_effect=GET_COLLECTIONS_MOCK, + ), patch.object( + authenticated_client.application.taxii_server, + "servers", + server_mapping_override_func( + authenticated_client.application.taxii_server.servers + ), + ), patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write"], + }, + ): + func = getattr(authenticated_client, method) + response = func(f"/{api_root_id}/collections/", headers=headers) + assert response.status_code == expected_status + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ): + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_collections_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func(f"/{API_ROOTS[0].id}/collections/") + assert response.status_code == 401 diff --git a/tests/taxii2/test_taxii2_discovery.py b/tests/taxii2/test_taxii2_discovery.py new file mode 100644 index 00000000..52837b61 --- /dev/null +++ b/tests/taxii2/test_taxii2_discovery.py @@ -0,0 +1,159 @@ +import json +from unittest.mock import patch + +import pytest +from tests.taxii2.utils import (API_ROOTS_WITH_DEFAULT, + API_ROOTS_WITHOUT_DEFAULT, config_noop, + config_remove_fields) + + +@pytest.mark.parametrize( + [ + "method", + "headers", + "config_override_func", + "api_roots", + "expected_status", + "expected_headers", + "expected_content", + ], + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + config_noop, + 
API_ROOTS_WITH_DEFAULT, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": "Some TAXII Server", + "description": "This TAXII Server contains a listing of...", + "contact": "string containing contact information", + "default": f"/{API_ROOTS_WITH_DEFAULT[0].id}/", + "api_roots": [f"/{item.id}/" for item in API_ROOTS_WITH_DEFAULT], + }, + id="good, with default api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + config_noop, + API_ROOTS_WITHOUT_DEFAULT, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": "Some TAXII Server", + "description": "This TAXII Server contains a listing of...", + "contact": "string containing contact information", + "api_roots": [f"/{item.id}/" for item in API_ROOTS_WITHOUT_DEFAULT], + }, + id="good, without default api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + config_noop, + [], + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": "Some TAXII Server", + "description": "This TAXII Server contains a listing of...", + "contact": "string containing contact information", + "api_roots": [], + }, + id="good, no api roots configured", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + config_remove_fields("description", "contact"), + [], + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "title": "Some TAXII Server", + "api_roots": [], + }, + id="good, no api roots and no description or contact", + ), + pytest.param( + "get", + {"Accept": "xml"}, + config_noop, + [], + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." 
+ ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + config_noop, + [], + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + ], +) +def test_discovery( + authenticated_client, + method, + headers, + config_override_func, + api_roots, + expected_status, + expected_headers, + expected_content, +): + config_defaults = { + "title": "Some TAXII Server", + "description": "This TAXII Server contains a listing of...", + "contact": "string containing contact information", + } + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + { + **authenticated_client.application.taxii_server.servers.taxii2.config, + **config_defaults, + } + ), + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_roots", + return_value=api_roots, + ): + func = getattr(authenticated_client, method) + response = func("/taxii2/", headers=headers) + assert response.status_code == expected_status + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + assert json.loads(response.data) == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_discovery_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func("/taxii2/") + assert response.status_code == 401 diff --git a/tests/taxii2/test_taxii2_manifest.py b/tests/taxii2/test_taxii2_manifest.py new file mode 100644 index 00000000..fcb76538 --- /dev/null +++ b/tests/taxii2/test_taxii2_manifest.py @@ -0,0 +1,751 @@ +import datetime +import json +from unittest.mock import patch +from urllib.parse import urlencode +from uuid import uuid4 + +import pytest +from opentaxii.taxii2.utils import get_next_param, taxii2_datetimeformat +from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_COLLECTION_MOCK, + GET_MANIFEST_MOCK, NOW, STIX_OBJECTS) + + +@pytest.mark.parametrize( + "method,headers,api_root_id,collection_id,filter_kwargs,expected_status,expected_headers,expected_content", + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].alias, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + 
}, + id="good, by alias", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"added_after": taxii2_datetimeformat(NOW)}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[1:2] + ], + }, + id="good, added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"added_after": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=3))}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + {}, + id="good, no results", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"added_after": taxii2_datetimeformat(NOW).replace("Z", "+00:00")}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"added_after": ["Not a valid datetime."]}, + "name": "validation error", + }, + id="broken added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": 1}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": True, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": 2}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, limit exact", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": 999}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, limit high", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": "a"}, + 
400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"limit": ["Not a valid integer."]}, + "name": "validation error", + }, + id="broken limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"next": get_next_param(STIX_OBJECTS[0])}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[1:2] + ], + }, + id="good, next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"next": "a"}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"next": ["Not a valid value."]}, + "name": "validation error", + }, + id="broken next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[id]": STIX_OBJECTS[0].id}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good, id", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[id]": ",".join([obj.id for obj in STIX_OBJECTS[:3]])}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, ids", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[type]": STIX_OBJECTS[0].type}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good, type", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[type]": ",".join([obj.type for obj in STIX_OBJECTS[:3]])}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": 
taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, types", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": taxii2_datetimeformat(STIX_OBJECTS[0].version)}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": "last"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, version last", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": "first"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=2)), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=3)), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[1:3] + ], + }, + id="good, version first", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": "all"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:3] + ], + }, + id="good, version all", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + { + "match[version]": ",".join( + [taxii2_datetimeformat(obj.version) for obj in STIX_OBJECTS[:3]] + ) + }, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": 
taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:3] + ], + }, + id="good, versions", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[spec_version]": STIX_OBJECTS[0].spec_version}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, spec_version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + { + "match[spec_version]": ",".join( + [obj.spec_version for obj in STIX_OBJECTS[:3]] + ) + }, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "date_added": taxii2_datetimeformat(obj.date_added), + "version": taxii2_datetimeformat(obj.version), + "media_type": f"application/stix+json;version={obj.spec_version}", + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, spec_versions", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[1].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="write-only collection", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + COLLECTIONS[5].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="wrong api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + COLLECTIONS[5].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + str(uuid4()), + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. 
If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown collection", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." + ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + ], +) +def test_manifest( + authenticated_client, + method, + api_root_id, + collection_id, + filter_kwargs, + headers, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_manifest", + side_effect=GET_MANIFEST_MOCK, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ): + func = getattr(authenticated_client, method) + if filter_kwargs: + querystring = f"?{urlencode(filter_kwargs)}" + else: + querystring = "" + response = func( + f"/{api_root_id}/collections/{collection_id}/manifest/{querystring}", + headers=headers, + ) + assert response.status_code == expected_status + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ): + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize( + "method", + ["get", "post", "delete"] +) +def test_manifest_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func(f"/{API_ROOTS[0].id}/collections/{COLLECTIONS[5].id}/manifest/") + assert response.status_code == 401 diff --git a/tests/taxii2/test_taxii2_object.py b/tests/taxii2/test_taxii2_object.py new file mode 100644 index 00000000..4c066b99 --- /dev/null +++ b/tests/taxii2/test_taxii2_object.py @@ -0,0 +1,866 @@ +import datetime +import json +from unittest.mock import patch +from urllib.parse import urlencode +from uuid import uuid4 + +import pytest +from opentaxii.taxii2.utils import (DATETIMEFORMAT, get_next_param, + taxii2_datetimeformat) +from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, DELETE_OBJECT_MOCK, + GET_COLLECTION_MOCK, GET_OBJECT_MOCK, NOW, + STIX_OBJECTS) + + +@pytest.mark.parametrize( + [ + "method", + "headers", + "api_root_id", + "collection_id", + "object_id", + "filter_kwargs", + "expected_status", + "expected_headers", + "expected_content", + ], + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + 
{}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].alias, + STIX_OBJECTS[0].id, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good, by alias", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + { + "added_after": taxii2_datetimeformat(NOW), + "match[version]": "all", + }, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[2]] + ], + }, + id="good, added_after, all", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"added_after": taxii2_datetimeformat(NOW)}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + {}, + id="good, added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"added_after": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=3))}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + {}, + id="good, added_after, no results", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"added_after": taxii2_datetimeformat(NOW).replace("Z", "+00:00")}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"added_after": ["Not a valid datetime."]}, + "name": "validation error", + }, + id="broken added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"limit": 1}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + { + "limit": 1, + "match[version]": "all", + }, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + 
"X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": True, + "next": get_next_param(STIX_OBJECTS[0]).decode(), + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, limit, all", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + { + "limit": 2, + "match[version]": "all", + }, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0], STIX_OBJECTS[2]] + ], + }, + id="good, limit exact, all", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"limit": 999}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good, limit high", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"limit": "a"}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"limit": ["Not a valid integer."]}, + "name": "validation error", + }, + id="broken limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"next": get_next_param(STIX_OBJECTS[0])}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + {}, + id="good, next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + { + "next": get_next_param(STIX_OBJECTS[0]), + "match[version]": "all", + }, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[2]] + ], + }, + id="good, next, all", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"next": "a"}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"next": ["Not a valid value."]}, + "name": "validation error", + }, + id="broken next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[version]": taxii2_datetimeformat(STIX_OBJECTS[0].version)}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + 
"X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[version]": "last"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good, version last", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[version]": "first"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[2]] + ], + }, + id="good, version first", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[version]": "all"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0], STIX_OBJECTS[2]] + ], + }, + id="good, version all", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[version]": "a"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + {}, + id="good, unknown version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[spec_version]": STIX_OBJECTS[0].spec_version}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, spec_version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[spec_version]": "a"}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + {}, + id="good, unknown spec_version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[1].id, + STIX_OBJECTS[2].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. 
If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="write-only collection", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="wrong api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + str(uuid4()), + STIX_OBJECTS[0].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown collection", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + str(uuid4()), + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown object", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." 
+ ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + pytest.param( + "delete", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[version]": taxii2_datetimeformat(STIX_OBJECTS[0].version)}, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + b"", + id="delete, version filter", + ), + pytest.param( + "delete", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[spec_version]": STIX_OBJECTS[0].spec_version}, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + b"", + id="delete, spec_version filter", + ), + pytest.param( + "delete", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + b"", + id="delete, good", + ), + pytest.param( + "delete", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + str(uuid4()), + {}, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + b"", + id="delete, unknown object", + ), + pytest.param( + "delete", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[3].id, + str(uuid4()), + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="delete, no read, no write", + ), + pytest.param( + "delete", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[0].id, + str(uuid4()), + {}, + 403, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 403, + "description": ( + "You don't have the permission to access the requested " + "resource. It is either read-protected or not readable by the " + "server." + ), + "name": "Forbidden", + }, + id="delete, read, no write", + ), + pytest.param( + "delete", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[1].id, + str(uuid4()), + {}, + 403, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 403, + "description": ( + "You don't have the permission to access the requested " + "resource. It is either read-protected or not readable by the " + "server." + ), + "name": "Forbidden", + }, + id="delete, no read, write", + ), + pytest.param( + "delete", + {"Accept": "xml"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + str(uuid4()), + {}, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." 
+ ), + }, + id="delete, wrong accept header", + ), + ], +) +def test_object( + authenticated_client, + method, + api_root_id, + collection_id, + object_id, + filter_kwargs, + headers, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_object", + side_effect=GET_OBJECT_MOCK, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "delete_object", + side_effect=DELETE_OBJECT_MOCK, + ) as delete_object_mock: + func = getattr(authenticated_client, method) + if filter_kwargs: + querystring = f"?{urlencode(filter_kwargs)}" + else: + querystring = "" + kwargs = {"headers": headers} + response = func( + f"/{api_root_id}/collections/{collection_id}/objects/{object_id}/{querystring}", + **kwargs, + ) + assert response.status_code == expected_status + if method == "delete" and expected_status == 200: + expected_kwargs = { + "match_version": [ + datetime.datetime.strptime( + filter_kwargs["match[version]"], DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc) + ] + if "match[version]" in filter_kwargs + else None, + "match_spec_version": [filter_kwargs["match[spec_version]"]] + if "match[spec_version]" in filter_kwargs + else None, + } + delete_object_mock.assert_called_once_with( + collection_id=COLLECTIONS[5].id, object_id=object_id, **expected_kwargs + ) + else: + delete_object_mock.assert_not_called() + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ) and response.data != b"": + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_object_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func( + f"/{API_ROOTS[0].id}/collections/{COLLECTIONS[5].id}/objects/{STIX_OBJECTS[0].id}/" + ) + assert response.status_code == 401 diff --git a/tests/taxii2/test_taxii2_objects.py b/tests/taxii2/test_taxii2_objects.py new file mode 100644 index 00000000..57a13610 --- /dev/null +++ b/tests/taxii2/test_taxii2_objects.py @@ -0,0 +1,1088 @@ +import datetime +import json +from unittest.mock import patch +from urllib.parse import urlencode +from uuid import uuid4 + +import pytest +from opentaxii.taxii2.utils import get_next_param, taxii2_datetimeformat +from tests.taxii2.utils import (ADD_OBJECTS_MOCK, API_ROOTS, COLLECTIONS, + GET_COLLECTION_MOCK, GET_JOB_AND_DETAILS_MOCK, + GET_OBJECTS_MOCK, JOBS, NOW, STIX_OBJECTS) + + +@pytest.mark.parametrize( + [ + "method", + "headers", + "api_root_id", + "collection_id", + "filter_kwargs", + "post_data", + "expected_status", + "expected_headers", + "expected_content", + ], + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", 
+ "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].alias, + {}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, by alias", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"added_after": taxii2_datetimeformat(NOW)}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[1:2] + ], + }, + id="good, added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"added_after": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=3))}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + {}, + id="good, no results", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"added_after": taxii2_datetimeformat(NOW).replace("Z", "+00:00")}, + {}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"added_after": ["Not a valid datetime."]}, + "name": "validation error", + }, + id="broken added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": 1}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": True, + "next": get_next_param(STIX_OBJECTS[0]).decode(), + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": 2}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, limit exact", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": 999}, + {}, + 
200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, limit high", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"limit": "a"}, + {}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"limit": ["Not a valid integer."]}, + "name": "validation error", + }, + id="broken limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"next": get_next_param(STIX_OBJECTS[0])}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[1:2] + ], + }, + id="good, next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"next": "a"}, + {}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"next": ["Not a valid value."]}, + "name": "validation error", + }, + id="broken next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[id]": STIX_OBJECTS[0].id}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good, id", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[id]": ",".join([obj.id for obj in STIX_OBJECTS[:3]])}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, ids", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[type]": STIX_OBJECTS[0].type}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in [STIX_OBJECTS[0]] + ], + }, + id="good, type", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + 
COLLECTIONS[5].id, + {"match[type]": ",".join([obj.type for obj in STIX_OBJECTS[:3]])}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, types", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": taxii2_datetimeformat(STIX_OBJECTS[0].version)}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": "last"}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, version last", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": "first"}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[1:3] + ], + }, + id="good, version first", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[version]": "all"}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:3] + ], + }, + id="good, version all", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + { + "match[version]": ",".join( + [taxii2_datetimeformat(obj.version) for obj in STIX_OBJECTS[:3]] + ) + }, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:3] + ], + }, + id="good, versions", + ), + 
pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {"match[spec_version]": STIX_OBJECTS[0].spec_version}, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:1] + ], + }, + id="good, spec_version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + { + "match[spec_version]": ",".join( + [obj.spec_version for obj in STIX_OBJECTS[:3]] + ) + }, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=2) + ), + }, + { + "more": False, + "objects": [ + { + "id": obj.id, + "type": obj.type, + "spec_version": obj.type, + **obj.serialized_data, + } + for obj in STIX_OBJECTS[:2] + ], + }, + id="good, spec_versions", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[1].id, + {}, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="write-only collection", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + COLLECTIONS[5].id, + {}, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="wrong api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + COLLECTIONS[5].id, + {}, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + str(uuid4()), + {}, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown collection", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + {}, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." 
+ ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + { + "Accept": "application/taxii+json;version=2.1", + "Content-Type": "application/taxii+json;version=2.1", + }, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + { + "objects": [ + { + "type": "indicator", + "spec_version": "2.1", + "id": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": "2016-04-06T20:03:48.000Z", + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + } + ] + }, + 202, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "id": JOBS[0].id, + "status": JOBS[0].status, + "request_timestamp": taxii2_datetimeformat(JOBS[0].request_timestamp), + "total_count": 4, + "success_count": 1, + "successes": [ + { + "id": "indicator--c410e480-e42b-47d1-9476-85307c12bcbf", + "version": "2018-05-27T12:02:41.312000Z", + } + ], + "failure_count": 1, + "failures": [ + { + "id": "malware--664fa29d-bf65-4f28-a667-bdb76f29ec98", + "version": "2018-05-28T14:03:42.543000Z", + "message": "Unable to process object", + } + ], + "pending_count": 2, + "pendings": [ + { + "id": "indicator--252c7c11-daf2-42bd-843b-be65edca9f61", + "version": "2018-05-18T20:16:21.148000Z", + }, + { + "id": "relationship--045585ad-a22f-4333-af33-bfd503a683b5", + "version": "2018-05-15T10:13:32.579000Z", + }, + ], + }, + id="post, good", + ), + pytest.param( + "post", + { + "Accept": "application/taxii+json;version=2.1", + "Content-Type": "application/taxii+json;version=2.1", + }, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + {}, + 400, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 400, + "description": ["No objects"], + "name": "validation error", + }, + id="post, missing data", + ), + pytest.param( + "post", + { + "Accept": "application/taxii+json;version=2.1", + "Content-Type": "application/taxii+json;version=2.1", + }, + API_ROOTS[0].id, + COLLECTIONS[0].id, + {}, + { + "objects": [ + { + "type": "indicator", + "spec_version": "2.1", + "id": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": "2016-04-06T20:03:48.000Z", + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + } + ] + }, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. 
If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="post, read-only collection", + ), + pytest.param( + "post", + { + "Accept": "xml", + "Content-Type": "application/taxii+json;version=2.1", + }, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + {}, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." + ), + }, + id="post, wrong accept header", + ), + pytest.param( + "post", + { + "Accept": "application/taxii+json;version=2.1", + }, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + { + "objects": [ + { + "type": "indicator", + "spec_version": "2.1", + "id": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": "2016-04-06T20:03:48.000Z", + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + } + ] + }, + 415, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 415, + "description": ( + "The server does not support the media type transmitted in the " + "request." + ), + "name": "Unsupported Media Type", + }, + id="post, missing content-type header", + ), + pytest.param( + "post", + { + "Accept": "application/taxii+json;version=2.1", + "Content-Type": "xml", + }, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + { + "objects": [ + { + "type": "indicator", + "spec_version": "2.1", + "id": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": "2016-04-06T20:03:48.000Z", + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + } + ] + }, + 415, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 415, + "description": ( + "The server does not support the media type transmitted in the " + "request." 
+ ), + "name": "Unsupported Media Type", + }, + id="post, wrong content-type header", + ), + pytest.param( + "post", + { + "Accept": "application/taxii+json;version=2.1", + "Content-Type": "application/taxii+json;version=2.1", + }, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + { + "objects": [ + { + "type": "indicator", + "spec_version": "2.1", + "id": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": "2016-04-06T20:03:48.000Z", + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy" * 33, + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + } + ] + }, + 413, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 413, + "description": "The data value transmitted exceeds the capacity limit.", + "name": "Request Entity Too Large", + }, + id="post, too big", + ), + pytest.param( + "post", + { + "Accept": "application/taxii+json;version=2.1", + "Content-Type": "application/taxii+json;version=2.1", + "Content-Length": 1, + }, + API_ROOTS[0].id, + COLLECTIONS[5].id, + {}, + { + "objects": [ + { + "type": "indicator", + "spec_version": "2.1", + "id": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": "2016-04-06T20:03:48.000Z", + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy" * 33, + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + } + ] + }, + 413, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 413, + "description": "The data value transmitted exceeds the capacity limit.", + "name": "Request Entity Too Large", + }, + id="post, too big, fake content-length", + ), + ], +) +def test_objects( + authenticated_client, + method, + api_root_id, + collection_id, + filter_kwargs, + post_data, + headers, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_objects", + side_effect=GET_OBJECTS_MOCK, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "add_objects", + side_effect=ADD_OBJECTS_MOCK, + ) as add_objects_mock, patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_job_and_details", + side_effect=GET_JOB_AND_DETAILS_MOCK, + ): + func = getattr(authenticated_client, method) + if filter_kwargs: + querystring = f"?{urlencode(filter_kwargs)}" + else: + querystring = "" + kwargs = {"headers": headers} + if method == "post": + kwargs["json"] = post_data + response = func( + 
f"/{api_root_id}/collections/{collection_id}/objects/{querystring}", + **kwargs, + ) + assert response.status_code == expected_status + if method == "post" and expected_status == 202: + add_objects_mock.assert_called_once_with( + api_root_id=API_ROOTS[0].id, collection_id=COLLECTIONS[5].id, objects=post_data["objects"] + ) + else: + add_objects_mock.assert_not_called() + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ): + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_objects_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func(f"/{API_ROOTS[0].id}/collections/{COLLECTIONS[5].id}/objects/") + assert response.status_code == 401 diff --git a/tests/taxii2/test_taxii2_sqldb.py b/tests/taxii2/test_taxii2_sqldb.py new file mode 100644 index 00000000..4081b959 --- /dev/null +++ b/tests/taxii2/test_taxii2_sqldb.py @@ -0,0 +1,1365 @@ +import datetime +from uuid import uuid4 + +import pytest +from opentaxii.persistence.sqldb.taxii2models import Job, JobDetail, STIXObject +from opentaxii.taxii2 import entities +from opentaxii.taxii2.utils import (DATETIMEFORMAT, get_next_param, + parse_next_param) +from tests.taxii2.utils import (API_ROOTS, API_ROOTS_WITH_DEFAULT, + API_ROOTS_WITHOUT_DEFAULT, COLLECTIONS, + GET_API_ROOT_MOCK, GET_COLLECTION_MOCK, + GET_COLLECTIONS_MOCK, GET_JOB_AND_DETAILS_MOCK, + GET_MANIFEST_MOCK, GET_OBJECT_MOCK, + GET_OBJECTS_MOCK, GET_VERSIONS_MOCK, JOBS, NOW, + STIX_OBJECTS) + + +@pytest.mark.parametrize( + ["db_api_roots"], + [ + pytest.param( + API_ROOTS_WITHOUT_DEFAULT, # db_api_roots + id="without default", + ), + pytest.param( + API_ROOTS_WITH_DEFAULT, # db_api_roots + id="with default", + ), + ], + indirect=["db_api_roots"], +) +def test_get_api_roots(taxii2_sqldb_api, db_api_roots): + response = taxii2_sqldb_api.get_api_roots() + assert response == [api_root for api_root in db_api_roots] + + +@pytest.mark.parametrize( + ["api_root_id"], + [ + pytest.param( + API_ROOTS[0].id, # api_root_id + id="first", + ), + pytest.param( + API_ROOTS[1].id, # api_root_id + id="second", + ), + pytest.param( + str(uuid4()), # api_root_id + id="unknown", + ), + ], +) +def test_get_api_root(taxii2_sqldb_api, db_api_roots, api_root_id): + response = taxii2_sqldb_api.get_api_root(api_root_id) + assert response == GET_API_ROOT_MOCK(api_root_id) + + +@pytest.mark.parametrize( + ["api_root_id", "job_id"], + [ + pytest.param( + API_ROOTS[0].id, # api_root_id + JOBS[0].id, # job_id + id="first", + ), + pytest.param( + API_ROOTS[0].id, # api_root_id + JOBS[1].id, # job_id + id="second", + ), + pytest.param( + API_ROOTS[0].id, # api_root_id + JOBS[2].id, # job_id + id="wrong api root", + ), + pytest.param( + str(uuid4()), # api_root_id + JOBS[0].id, # job_id + id="unknown api root", + ), + pytest.param( + API_ROOTS[0].id, # api_root_id + str(uuid4()), # job_id + id="unknown job id", + ), + ], +) +def test_get_job_and_details(taxii2_sqldb_api, db_jobs, api_root_id, job_id): + response = taxii2_sqldb_api.get_job_and_details(api_root_id, job_id) + assert response == GET_JOB_AND_DETAILS_MOCK(api_root_id, job_id) + + +@pytest.mark.parametrize( + ["api_root_id"], + [ + pytest.param( + API_ROOTS[0].id, + id="first", + ), + pytest.param( + API_ROOTS[1].id, + 
id="second", + ), + pytest.param( + str(uuid4()), + id="unknown", + ), + ], +) +def test_get_collections(taxii2_sqldb_api, db_collections, api_root_id): + response = taxii2_sqldb_api.get_collections(api_root_id) + assert response == GET_COLLECTIONS_MOCK(api_root_id) + + +@pytest.mark.parametrize( + ["api_root_id", "collection_id_or_alias"], + [ + pytest.param( + API_ROOTS[0].id, + COLLECTIONS[0].id, + id="first", + ), + pytest.param( + API_ROOTS[0].id, + COLLECTIONS[1].id, + id="second", + ), + pytest.param( + API_ROOTS[0].id, + COLLECTIONS[5].alias, + id="alias", + ), + pytest.param( + API_ROOTS[1].id, + COLLECTIONS[0].id, + id="wrong api root", + ), + pytest.param( + str(uuid4()), + COLLECTIONS[0].id, + id="unknown api root", + ), + pytest.param( + API_ROOTS[0].id, + str(uuid4()), + id="unknown collection id", + ), + ], +) +def test_get_collection(taxii2_sqldb_api, db_collections, api_root_id, collection_id_or_alias): + response = taxii2_sqldb_api.get_collection(api_root_id, collection_id_or_alias) + assert response == GET_COLLECTION_MOCK(api_root_id, collection_id_or_alias) + + +@pytest.mark.parametrize( + [ + "collection_id", + "limit", + "added_after", + "next_kwargs", + "match_id", + "match_type", + "match_version", + "match_spec_version", + ], + [ + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="default", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + 1, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="limit low", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + 2, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="limit exact", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + 999, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="limit high", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + NOW, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="added_after", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + parse_next_param(get_next_param(STIX_OBJECTS[0])), # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="next_kwargs", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[0].id], # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="match_id", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + [obj.id for obj in STIX_OBJECTS[:3]], # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="match_id multiple", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + [STIX_OBJECTS[0].type], # match_type + None, # match_version + None, # match_spec_version + id="match_type", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # 
next_kwargs + None, # match_id + [obj.type for obj in STIX_OBJECTS[:3]], # match_type + None, # match_version + None, # match_spec_version + id="match_type multiple", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[0].version], # match_version + None, # match_spec_version + id="version [0]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[1].version], # match_version + None, # match_spec_version + id="version [1]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[2].version], # match_version + None, # match_spec_version + id="version [2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[0].version, STIX_OBJECTS[2].version], # match_version + None, # match_spec_version + id="version [0, 2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["first"], # match_version + None, # match_spec_version + id="version first", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["last"], # match_version + None, # match_spec_version + id="version last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["all"], # match_version + None, # match_spec_version + id="version all", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["first", "last"], # match_version + None, # match_spec_version + id="version first, last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[2].version, "last"], # match_version + None, # match_spec_version + id="version [2], last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + [STIX_OBJECTS[0].spec_version], # match_spec_version + id="spec_version [0]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + [STIX_OBJECTS[1].spec_version], # match_spec_version + id="spec_version [1]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + [ + STIX_OBJECTS[0].spec_version, + STIX_OBJECTS[1].spec_version, + ], # match_spec_version + id="spec_version [0, 1]", + ), + ], +) +def test_get_manifest( + taxii2_sqldb_api, + db_stix_objects, + collection_id, + limit, + added_after, + next_kwargs, + match_id, + match_type, + match_version, + match_spec_version, +): + response = taxii2_sqldb_api.get_manifest( + collection_id=collection_id, + limit=limit, + 
added_after=added_after, + next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + assert response == GET_MANIFEST_MOCK( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + + +@pytest.mark.parametrize( + [ + "collection_id", + "limit", + "added_after", + "next_kwargs", + "match_id", + "match_type", + "match_version", + "match_spec_version", + ], + [ + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="default", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + 1, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="limit low", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + 2, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="limit exact", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + 999, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="limit high", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + NOW, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="added_after", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + parse_next_param(get_next_param(STIX_OBJECTS[0])), # next_kwargs + None, # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="next_kwargs", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[0].id], # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="match_id", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + [obj.id for obj in STIX_OBJECTS[:3]], # match_id + None, # match_type + None, # match_version + None, # match_spec_version + id="match_id multiple", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + [STIX_OBJECTS[0].type], # match_type + None, # match_version + None, # match_spec_version + id="match_type", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + [obj.type for obj in STIX_OBJECTS[:3]], # match_type + None, # match_version + None, # match_spec_version + id="match_type multiple", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[0].version], # match_version + None, # match_spec_version + id="version [0]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[1].version], # match_version + None, # match_spec_version + 
id="version [1]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[2].version], # match_version + None, # match_spec_version + id="version [2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[0].version, STIX_OBJECTS[2].version], # match_version + None, # match_spec_version + id="version [0, 2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["first"], # match_version + None, # match_spec_version + id="version first", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["last"], # match_version + None, # match_spec_version + id="version last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["all"], # match_version + None, # match_spec_version + id="version all", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + ["first", "last"], # match_version + None, # match_spec_version + id="version first, last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + [STIX_OBJECTS[2].version, "last"], # match_version + None, # match_spec_version + id="version [2], last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + [STIX_OBJECTS[0].spec_version], # match_spec_version + id="spec_version [0]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + [STIX_OBJECTS[1].spec_version], # match_spec_version + id="spec_version [1]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_id + None, # match_type + None, # match_version + [ + STIX_OBJECTS[0].spec_version, + STIX_OBJECTS[1].spec_version, + ], # match_spec_version + id="spec_version [0, 1]", + ), + ], +) +def test_get_objects( + taxii2_sqldb_api, + db_stix_objects, + collection_id, + limit, + added_after, + next_kwargs, + match_id, + match_type, + match_version, + match_spec_version, +): + response = taxii2_sqldb_api.get_objects( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + assert response == GET_OBJECTS_MOCK( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + + +@pytest.mark.parametrize( + [ + "api_root_id", + "collection_id", + "objects", + ], + [ + pytest.param( + API_ROOTS[0].id, # api_root_id + COLLECTIONS[5].id, # collection_id + [ + { + "type": "indicator", + 
"spec_version": "2.1", + "id": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": "2016-04-06T20:03:48.000Z", + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + } + ], # objects + id="single object", + ), + pytest.param( + API_ROOTS[0].id, # api_root_id + COLLECTIONS[5].id, # collection_id + [ + { + "type": "relationship", + "spec_version": "2.1", + "id": "relationship--44298a74-ba52-4f0c-87a3-1824e67d7fad", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:06:37.000Z", + "modified": "2016-04-06T20:06:37.000Z", + "relationship_type": "indicates", + "source_ref": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "target_ref": "malware--31b940d4-6f7f-459a-80ea-9c1f17b5891b", + }, + { + "type": "malware", + "spec_version": "2.1", + "id": "malware--31b940d4-6f7f-459a-80ea-9c1f17b5891b", + "is_family": True, + "created": "2016-04-06T20:07:09.000Z", + "modified": "2016-04-06T20:07:09.000Z", + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "name": "Poison Ivy", + "malware_types": ["trojan"], + }, + ], # objects + id="multiple objects", + ), + pytest.param( + API_ROOTS[0].id, # api_root_id + COLLECTIONS[5].id, # collection_id + [ + { + "id": STIX_OBJECTS[0].id, + "type": STIX_OBJECTS[0].type, + "spec_version": STIX_OBJECTS[0].spec_version, + **STIX_OBJECTS[0].serialized_data, + } + ], # objects + id="existing object", + ), + ], +) +def test_add_objects( + taxii2_sqldb_api, + db_stix_objects, + api_root_id, + collection_id, + objects, +): + job, job_details = taxii2_sqldb_api.add_objects( + api_root_id=api_root_id, + collection_id=collection_id, + objects=objects, + ) + # Check response entities + assert job == entities.Job( + id=job.id, + api_root_id=api_root_id, + status="complete", + request_timestamp=job.request_timestamp, + completed_timestamp=job.completed_timestamp, + ) + assert isinstance(job.request_timestamp, datetime.datetime) + assert isinstance(job.completed_timestamp, datetime.datetime) + assert len(job_details) == len(objects) + for (job_detail, obj) in zip(job_details, objects): + assert job_detail == entities.JobDetail( + id=job_detail.id, + job_id=job.id, + stix_id=obj["id"], + version=datetime.datetime.strptime(obj["modified"], DATETIMEFORMAT).replace( + tzinfo=datetime.timezone.utc + ), + message="", + status="success", + ) + # Check database state + db_job = taxii2_sqldb_api.db.session.query(Job).one() + assert str(db_job.api_root_id) == api_root_id + assert db_job.status == "complete" + assert isinstance(db_job.request_timestamp, datetime.datetime) + assert isinstance(db_job.completed_timestamp, datetime.datetime) + for obj in objects: + db_obj = ( + taxii2_sqldb_api.db.session.query(STIXObject) + .filter( + STIXObject.id == obj["id"], + STIXObject.version + == datetime.datetime.strptime(obj["modified"], DATETIMEFORMAT).replace( + tzinfo=datetime.timezone.utc + ), + ) + .one() + ) + assert db_obj.id == obj["id"] + assert str(db_obj.collection_id) == collection_id + assert db_obj.type == obj["type"] + assert db_obj.spec_version == obj["spec_version"] + assert isinstance(db_obj.date_added, datetime.datetime) + assert 
db_obj.version == datetime.datetime.strptime( + obj["modified"], DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc) + assert db_obj.serialized_data == { + key: value + for (key, value) in obj.items() + if key not in ["id", "type", "spec_version"] + } + db_job_detail = ( + taxii2_sqldb_api.db.session.query(JobDetail) + .filter(JobDetail.stix_id == obj["id"]) + .one() + ) + assert db_job_detail.job_id == db_job.id + assert db_job_detail.stix_id == obj["id"] + assert db_job_detail.version == datetime.datetime.strptime( + obj["modified"], DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc) + assert db_job_detail.message == "" + assert db_job_detail.status == "success" + + +@pytest.mark.parametrize( + [ + "collection_id", + "object_id", + "limit", + "added_after", + "next_kwargs", + "match_version", + "match_spec_version", + ], + [ + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_version + None, # match_spec_version + id="default", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + str(uuid4()), # object_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_version + None, # match_spec_version + id="unknown object", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + NOW, # added_after + None, # next_kwargs + ["all"], # match_version + None, # match_spec_version + id="added_after, all", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + NOW, # added_after + None, # next_kwargs + None, # match_version + None, # match_spec_version + id="added_after", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + NOW + datetime.timedelta(seconds=3), # added_after + None, # next_kwargs + None, # match_version + None, # match_spec_version + id="added_after, no results", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + 1, # limit + None, # added_after + None, # next_kwargs + None, # match_version + None, # match_spec_version + id="limit", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + 1, # limit + None, # added_after + None, # next_kwargs + ["all"], # match_version + None, # match_spec_version + id="limit, all", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + 2, # limit + None, # added_after + None, # next_kwargs + ["all"], # match_version + None, # match_spec_version + id="limit exact, all", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + 999, # limit + None, # added_after + None, # next_kwargs + None, # match_version + None, # match_spec_version + id="limit high", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + parse_next_param(get_next_param(STIX_OBJECTS[0])), # next_kwargs + None, # match_version + None, # match_spec_version + id="next", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + parse_next_param(get_next_param(STIX_OBJECTS[0])), # next_kwargs + ["all"], # match_version + None, # match_spec_version + id="next, all", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, 
# added_after + None, # next_kwargs + [STIX_OBJECTS[0].version], # match_version + None, # match_spec_version + id="version [0]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[1].version], # match_version + None, # match_spec_version + id="version [1]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[2].version], # match_version + None, # match_spec_version + id="version [2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[0].version, STIX_OBJECTS[2].version], # match_version + None, # match_spec_version + id="version [0, 2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + ["first"], # match_version + None, # match_spec_version + id="version first", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + ["last"], # match_version + None, # match_spec_version + id="version last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + ["all"], # match_version + None, # match_spec_version + id="version all", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + ["first", "last"], # match_version + None, # match_spec_version + id="version first, last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[2].version, "last"], # match_version + None, # match_spec_version + id="version [2], last", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_version + [STIX_OBJECTS[0].spec_version], # match_spec_version + id="spec_version [0]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_version + [STIX_OBJECTS[2].spec_version], # match_spec_version + id="spec_version [2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_version + [ + STIX_OBJECTS[0].spec_version, + STIX_OBJECTS[2].spec_version, + ], # match_spec_version + id="spec_version [0, 2]", + ), + ], +) +def test_get_object( + taxii2_sqldb_api, + db_stix_objects, + collection_id, + object_id, + limit, + added_after, + next_kwargs, + match_version, + match_spec_version, +): + response = taxii2_sqldb_api.get_object( + collection_id=collection_id, + object_id=object_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_version=match_version, + match_spec_version=match_spec_version, + ) + assert response == GET_OBJECT_MOCK( + collection_id=collection_id, + object_id=object_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_version=match_version, + 
match_spec_version=match_spec_version, + ) + + +@pytest.mark.parametrize( + [ + "collection_id", + "object_id", + "match_version", + "match_spec_version", + "expected_objects", + ], + [ + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # match_version + None, # match_spec_version + [STIX_OBJECTS[1], STIX_OBJECTS[3]], # expected_objects + id="default", + ), + pytest.param( + COLLECTIONS[4].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # match_version + None, # match_spec_version + STIX_OBJECTS, # expected_objects + id="wrong collection", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + [STIX_OBJECTS[0].version], # match_version + None, # match_spec_version + STIX_OBJECTS[1:], # expected_objects + id="version", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # match_version + [STIX_OBJECTS[0].spec_version], # match_spec_version + STIX_OBJECTS[1:], # expected_objects + id="version", + ), + ], +) +def test_delete_object( + taxii2_sqldb_api, + db_stix_objects, + collection_id, + object_id, + match_version, + match_spec_version, + expected_objects, +): + taxii2_sqldb_api.delete_object( + collection_id=collection_id, + object_id=object_id, + match_version=match_version, + match_spec_version=match_spec_version, + ) + assert set( + (str(db_obj.collection_id), db_obj.id, db_obj.version) + for db_obj in taxii2_sqldb_api.db.session.query(STIXObject).all() + ) == set((obj.collection_id, obj.id, obj.version) for obj in expected_objects) + + +@pytest.mark.parametrize( + [ + "collection_id", + "object_id", + "limit", + "added_after", + "next_kwargs", + "match_spec_version", + ], + [ + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_spec_version + id="default", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + str(uuid4()), # object_id + None, # limit + None, # added_after + None, # next_kwargs + None, # match_spec_version + id="unknown object", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + NOW, # added_after + None, # next_kwargs + None, # match_spec_version + id="added_after", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + NOW + datetime.timedelta(seconds=3), # added_after + None, # next_kwargs + None, # match_spec_version + id="added_after, no results", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + 1, # limit + None, # added_after + None, # next_kwargs + None, # match_spec_version + id="limit", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + 999, # limit + None, # added_after + None, # next_kwargs + None, # match_spec_version + id="limit high", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + parse_next_param(get_next_param(STIX_OBJECTS[0])), # next_kwargs + None, # match_spec_version + id="next", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[0].spec_version], # match_spec_version + id="spec_version [0]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # 
limit + None, # added_after + None, # next_kwargs + [STIX_OBJECTS[2].spec_version], # match_spec_version + id="spec_version [2]", + ), + pytest.param( + COLLECTIONS[5].id, # collection_id + STIX_OBJECTS[0].id, # object_id + None, # limit + None, # added_after + None, # next_kwargs + [ + STIX_OBJECTS[0].spec_version, + STIX_OBJECTS[2].spec_version, + ], # match_spec_version + id="spec_version [0, 2]", + ), + ], +) +def test_get_versions( + taxii2_sqldb_api, + db_stix_objects, + collection_id, + object_id, + limit, + added_after, + next_kwargs, + match_spec_version, +): + response = taxii2_sqldb_api.get_versions( + collection_id=collection_id, + object_id=object_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_spec_version=match_spec_version, + ) + assert response == GET_VERSIONS_MOCK( + collection_id=collection_id, + object_id=object_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_spec_version=match_spec_version, + ) diff --git a/tests/taxii2/test_taxii2_status.py b/tests/taxii2/test_taxii2_status.py new file mode 100644 index 00000000..d69cbec2 --- /dev/null +++ b/tests/taxii2/test_taxii2_status.py @@ -0,0 +1,282 @@ +import json +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from opentaxii.persistence.sqldb import taxii2models +from opentaxii.taxii2.utils import taxii2_datetimeformat +from tests.taxii2.utils import (API_ROOTS, GET_JOB_AND_DETAILS_MOCK, JOBS, + config_noop, server_mapping_noop, + server_mapping_remove_fields) + + +@pytest.mark.parametrize( + [ + "method", + "headers", + "api_root_id", + "job_id", + "config_override_func", + "server_mapping_override_func", + "expected_status", + "expected_headers", + "expected_content", + ], + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + JOBS[0].id, + config_noop, + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "id": JOBS[0].id, + "status": JOBS[0].status, + "request_timestamp": taxii2_datetimeformat(JOBS[0].request_timestamp), + "total_count": 4, + "success_count": 1, + "successes": [ + { + "id": "indicator--c410e480-e42b-47d1-9476-85307c12bcbf", + "version": "2018-05-27T12:02:41.312000Z", + } + ], + "failure_count": 1, + "failures": [ + { + "id": "malware--664fa29d-bf65-4f28-a667-bdb76f29ec98", + "version": "2018-05-28T14:03:42.543000Z", + "message": "Unable to process object", + } + ], + "pending_count": 2, + "pendings": [ + { + "id": "indicator--252c7c11-daf2-42bd-843b-be65edca9f61", + "version": "2018-05-18T20:16:21.148000Z", + }, + { + "id": "relationship--045585ad-a22f-4333-af33-bfd503a683b5", + "version": "2018-05-15T10:13:32.579000Z", + }, + ], + }, + id="good, first, first", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + JOBS[3].id, + config_noop, + server_mapping_noop, + 200, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "id": JOBS[3].id, + "status": JOBS[3].status, + "request_timestamp": taxii2_datetimeformat(JOBS[3].request_timestamp), + "total_count": 0, + "success_count": 0, + "successes": [], + "failure_count": 0, + "failures": [], + "pending_count": 0, + "pendings": [], + }, + id="good, second, second", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + JOBS[3].id, + config_noop, + server_mapping_noop, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The 
requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="wrong api id", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + str(uuid4()), + config_noop, + server_mapping_remove_fields("taxii1"), + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown job id, taxii2 only config", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + str(uuid4()), + config_noop, + server_mapping_noop, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown job id, taxii1/2 config", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + JOBS[0].id, + config_noop, + server_mapping_remove_fields("taxii1"), + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root, taxii2 only config", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + JOBS[0].id, + config_noop, + server_mapping_noop, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root, taxii1/2 config", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + JOBS[0].id, + config_noop, + server_mapping_noop, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." 
+ ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + JOBS[0].id, + config_noop, + server_mapping_noop, + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + ], +) +def test_status( + authenticated_client, + method, + api_root_id, + job_id, + headers, + config_override_func, + server_mapping_override_func, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2, + "config", + config_override_func( + authenticated_client.application.taxii_server.servers.taxii2.config + ), + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_api_roots", + return_value=API_ROOTS, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_job_and_details", + side_effect=GET_JOB_AND_DETAILS_MOCK, + ), patch.object( + authenticated_client.application.taxii_server, + "servers", + server_mapping_override_func( + authenticated_client.application.taxii_server.servers + ), + ): + func = getattr(authenticated_client, method) + response = func(f"/{api_root_id}/status/{job_id}/", headers=headers) + assert response.status_code == expected_status + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ): + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_status_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func(f"/{API_ROOTS[0].id}/status/{JOBS[0].id}/") + assert response.status_code == 401 + + +def test_job_cleanup(app, db_jobs): + number_removed = app.taxii_server.servers.taxii2.persistence.api.job_cleanup() + assert number_removed == 3 + assert ( + app.taxii_server.servers.taxii2.persistence.api.db.session.query( + taxii2models.Job + ).count() + == 3 + ) diff --git a/tests/taxii2/test_taxii2_utils.py b/tests/taxii2/test_taxii2_utils.py new file mode 100644 index 00000000..4e477be5 --- /dev/null +++ b/tests/taxii2/test_taxii2_utils.py @@ -0,0 +1,102 @@ +import datetime + +import pytest +from opentaxii.taxii2.utils import (get_next_param, parse_next_param, + taxii2_datetimeformat) +from tests.taxii2.factories import STIXObjectFactory + + +@pytest.mark.parametrize( + ["input_value", "expected"], + [ + pytest.param( + datetime.datetime(2022, 1, 1, 12, 0, 0, 0, tzinfo=datetime.timezone.utc), + "2022-01-01T12:00:00.000000Z", + id="utc", + ), + pytest.param( + datetime.datetime( + 2022, + 1, + 1, + 12, + 0, + 0, + 0, + tzinfo=datetime.timezone(datetime.timedelta(hours=1)), + ), + "2022-01-01T11:00:00.000000Z", + id="utc+1", + ), + pytest.param( + datetime.datetime( + 2022, + 1, + 1, + 12, + 0, + 0, + 0, + tzinfo=datetime.timezone(datetime.timedelta(hours=1, minutes=30)), + ), + "2022-01-01T10:30:00.000000Z", + id="utc+1.5", + ), + pytest.param( + datetime.datetime( + 2022, + 1, + 1, + 12, + 0, + 0, + 0, + tzinfo=datetime.timezone(datetime.timedelta(hours=-1)), + ), + "2022-01-01T13:00:00.000000Z", + id="utc+1", + ), + pytest.param( + datetime.datetime( + 
2022, + 1, + 1, + 12, + 0, + 0, + 0, + tzinfo=datetime.timezone(datetime.timedelta(hours=-1, minutes=-30)), + ), + "2022-01-01T13:30:00.000000Z", + id="utc+1.5", + ), + ], +) +def test_taxii2_datetimeformat(input_value, expected): + assert taxii2_datetimeformat(input_value) == expected + + +@pytest.mark.parametrize( + "stix_id, date_added, next_param", + [ + pytest.param( + "indicator--fa641b92-94d7-42dd-aa0e-63cfe1ee148a", + datetime.datetime( + 2022, 2, 4, 18, 40, 6, 297204, tzinfo=datetime.timezone.utc + ), + ( + b"MjAyMi0wMi0wNFQxODo0MDowNi4yOTcyMDQrMDA6MDB8aW5kaWNhdG9yLS1mYTY0M" + b"WI5Mi05NGQ3LTQyZGQtYWEwZS02M2NmZTFlZTE0OGE=" + ), + id="simple", + ), + ], +) +def test_next_param(stix_id, date_added, next_param): + stix_object = STIXObjectFactory.build( + id=stix_id, + type=stix_id.split("--")[0], + date_added=date_added, + ) + assert get_next_param(stix_object) == next_param + assert parse_next_param(next_param) == {"id": stix_id, "date_added": date_added} diff --git a/tests/taxii2/test_taxii2_versions.py b/tests/taxii2/test_taxii2_versions.py new file mode 100644 index 00000000..31dc715c --- /dev/null +++ b/tests/taxii2/test_taxii2_versions.py @@ -0,0 +1,431 @@ +import datetime +import json +from unittest.mock import patch +from urllib.parse import urlencode +from uuid import uuid4 + +import pytest +from opentaxii.taxii2.utils import get_next_param, taxii2_datetimeformat +from tests.taxii2.utils import (API_ROOTS, COLLECTIONS, GET_COLLECTION_MOCK, + GET_VERSIONS_MOCK, NOW, STIX_OBJECTS) + + +@pytest.mark.parametrize( + [ + "method", + "headers", + "api_root_id", + "collection_id", + "object_id", + "filter_kwargs", + "expected_status", + "expected_headers", + "expected_content", + ], + [ + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "versions": [ + taxii2_datetimeformat(STIX_OBJECTS[0].version), + taxii2_datetimeformat(STIX_OBJECTS[2].version), + ], + }, + id="good", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"added_after": taxii2_datetimeformat(NOW)}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "versions": [ + taxii2_datetimeformat(STIX_OBJECTS[2].version), + ], + }, + id="good, added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"added_after": taxii2_datetimeformat(NOW).replace("Z", "+00:00")}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"added_after": ["Not a valid datetime."]}, + "name": "validation error", + }, + id="broken added_after", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"limit": 1}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + 
"X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": True, + "versions": [ + taxii2_datetimeformat(STIX_OBJECTS[0].version), + ], + }, + id="good, limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"limit": 2}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "versions": [ + taxii2_datetimeformat(STIX_OBJECTS[0].version), + taxii2_datetimeformat(STIX_OBJECTS[2].version), + ], + }, + id="good, limit exact", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"limit": 999}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "versions": [ + taxii2_datetimeformat(STIX_OBJECTS[0].version), + taxii2_datetimeformat(STIX_OBJECTS[2].version), + ], + }, + id="good, limit high", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"limit": "a"}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"limit": ["Not a valid integer."]}, + "name": "validation error", + }, + id="broken limit", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"next": get_next_param(STIX_OBJECTS[0])}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat( + NOW + datetime.timedelta(seconds=3) + ), + }, + { + "more": False, + "versions": [ + taxii2_datetimeformat(STIX_OBJECTS[2].version), + ], + }, + id="good, next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"next": "a"}, + 400, + { + "Content-Type": "application/taxii+json;version=2.1", + }, + { + "code": 400, + "description": {"next": ["Not a valid value."]}, + "name": "validation error", + }, + id="broken next", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {"match[spec_version]": STIX_OBJECTS[0].spec_version}, + 200, + { + "Content-Type": "application/taxii+json;version=2.1", + "X-TAXII-Date-Added-First": taxii2_datetimeformat(NOW), + "X-TAXII-Date-Added-Last": taxii2_datetimeformat(NOW), + }, + { + "more": False, + "versions": [ + taxii2_datetimeformat(STIX_OBJECTS[0].version), + ], + }, + id="good, spec_version", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[1].id, + STIX_OBJECTS[2].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. 
If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="write-only collection", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[1].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="wrong api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + str(uuid4()), + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown api root", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + str(uuid4()), + STIX_OBJECTS[0].id, + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown collection", + ), + pytest.param( + "get", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + str(uuid4()), + {}, + 404, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 404, + "description": "The requested URL was not found on the server. If you entered " + "the URL manually please check your spelling and try again.", + "name": "Not Found", + }, + id="unknown object", + ), + pytest.param( + "get", + {"Accept": "xml"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 406, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 406, + "name": "Not Acceptable", + "description": ( + "The resource identified by the request is only capable of generating response entities which" + " have content characteristics not acceptable according to the accept headers sent in the" + " request." 
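+                    # an Accept header the server cannot satisfy (here "xml") is rejected with 406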
+ ), + }, + id="wrong accept header", + ), + pytest.param( + "post", + {"Accept": "application/taxii+json;version=2.1"}, + API_ROOTS[0].id, + COLLECTIONS[5].id, + STIX_OBJECTS[0].id, + {}, + 405, + {"Content-Type": "application/taxii+json;version=2.1"}, + { + "code": 405, + "description": "The method is not allowed for the requested URL.", + "name": "Method Not Allowed", + }, + id="wrong method", + ), + ], +) +def test_versions( + authenticated_client, + method, + api_root_id, + collection_id, + object_id, + filter_kwargs, + headers, + expected_status, + expected_headers, + expected_content, +): + with patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_versions", + side_effect=GET_VERSIONS_MOCK, + ), patch.object( + authenticated_client.application.taxii_server.servers.taxii2.persistence.api, + "get_collection", + side_effect=GET_COLLECTION_MOCK, + ), patch.object( + authenticated_client.account, + "permissions", + { + COLLECTIONS[0].id: ["read"], + COLLECTIONS[1].id: ["write"], + COLLECTIONS[2].id: ["read", "write"], + COLLECTIONS[4].id: ["read", "write"], + COLLECTIONS[5].id: ["write", "read"], + }, + ): + func = getattr(authenticated_client, method) + if filter_kwargs: + querystring = f"?{urlencode(filter_kwargs)}" + else: + querystring = "" + kwargs = {"headers": headers} + response = func( + f"/{api_root_id}/collections/{collection_id}/objects/{object_id}/versions/{querystring}", + **kwargs, + ) + assert response.status_code == expected_status + assert { + key: response.headers.get(key) for key in expected_headers + } == expected_headers + if ( + response.headers.get("Content-Type", "application/taxii+json;version=2.1") + == "application/taxii+json;version=2.1" + ) and response.data != b"": + content = json.loads(response.data) + else: + content = response.data + assert content == expected_content + + +@pytest.mark.parametrize("method", ["get", "post", "delete"]) +def test_versions_unauthenticated( + client, + method, +): + func = getattr(client, method) + response = func( + f"/{API_ROOTS[0].id}/collections/{COLLECTIONS[5].id}/objects/{STIX_OBJECTS[0].id}/versions/" + ) + assert response.status_code == 401 diff --git a/tests/taxii2/utils.py b/tests/taxii2/utils.py new file mode 100644 index 00000000..bb1d719d --- /dev/null +++ b/tests/taxii2/utils.py @@ -0,0 +1,432 @@ +import datetime +from typing import Dict, List, Optional +from uuid import uuid4 + +from opentaxii.server import ServerMapping +from opentaxii.taxii2.entities import (ApiRoot, Collection, Job, JobDetail, + ManifestRecord, STIXObject, + VersionRecord) +from opentaxii.taxii2.utils import DATETIMEFORMAT, taxii2_datetimeformat + +API_ROOTS_WITH_DEFAULT = ( + ApiRoot(str(uuid4()), True, "first title", "first description"), + ApiRoot(str(uuid4()), False, "second title", "second description"), +) +API_ROOTS_WITHOUT_DEFAULT = ( + ApiRoot(str(uuid4()), False, "first title", "first description"), + ApiRoot(str(uuid4()), False, "second title", "second description"), + ApiRoot(str(uuid4()), False, "third title", None), +) +API_ROOTS = API_ROOTS_WITHOUT_DEFAULT +NOW = datetime.datetime.now(datetime.timezone.utc) +JOBS = tuple() +for api_root in API_ROOTS: + JOBS = JOBS + ( + Job(str(uuid4()), api_root.id, "complete", NOW, NOW - datetime.timedelta(hours=24, minutes=1)), + Job(str(uuid4()), api_root.id, "pending", NOW, None), + ) + +JOB_DETAILS_KWARGS = sorted(( + { + "stix_id": "indicator--c410e480-e42b-47d1-9476-85307c12bcbf", + "version": datetime.datetime.strptime( + 
"2018-05-27T12:02:41.312Z", DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc), + "message": "", + "status": "success", + }, + { + "stix_id": "malware--664fa29d-bf65-4f28-a667-bdb76f29ec98", + "version": datetime.datetime.strptime( + "2018-05-28T14:03:42.543Z", DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc), + "message": "Unable to process object", + "status": "failure", + }, + { + "stix_id": "indicator--252c7c11-daf2-42bd-843b-be65edca9f61", + "version": datetime.datetime.strptime( + "2018-05-18T20:16:21.148Z", DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc), + "message": "", + "status": "pending", + }, + { + "stix_id": "relationship--045585ad-a22f-4333-af33-bfd503a683b5", + "version": datetime.datetime.strptime( + "2018-05-15T10:13:32.579Z", DATETIMEFORMAT + ).replace(tzinfo=datetime.timezone.utc), + "message": "", + "status": "pending", + }, +), key=lambda item: item["stix_id"]) +JOB_DETAILS = tuple( + JobDetail(id=str(uuid4()), job_id=JOBS[0].id, **kwargs) + for kwargs in JOB_DETAILS_KWARGS +) +COLLECTIONS = ( + Collection( + str(uuid4()), API_ROOTS[0].id, "0Read only", "Read only description", None + ), + Collection( + str(uuid4()), API_ROOTS[0].id, "1Write only", "Write only description", None + ), + Collection( + str(uuid4()), API_ROOTS[0].id, "2Read/Write", "Read/Write description", None + ), + Collection( + str(uuid4()), + API_ROOTS[0].id, + "3No permissions", + "No permissions description", + None, + ), + Collection(str(uuid4()), API_ROOTS[0].id, "4No description", "", None), + Collection( + str(uuid4()), + API_ROOTS[0].id, + "5With alias", + "With alias description", + "this-is-an-alias", + ), +) +STIX_OBJECTS = ( + STIXObject( + f"indicator--{str(uuid4())}", + COLLECTIONS[5].id, + "indicator", + "2.0", + NOW, + NOW + datetime.timedelta(seconds=1), + { + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=1)), + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + }, + ), + STIXObject( + f"relationship--{str(uuid4())}", + COLLECTIONS[5].id, + "relationship", + "2.1", + NOW + datetime.timedelta(seconds=2), + NOW + datetime.timedelta(seconds=3), + { + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:06:37.000Z", + "modified": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=3)), + "relationship_type": "indicates", + "source_ref": "indicator--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + "target_ref": "malware--31b940d4-6f7f-459a-80ea-9c1f17b5891b", + }, + ), +) +STIX_OBJECTS = STIX_OBJECTS + ( + STIXObject( + STIX_OBJECTS[0].id, + COLLECTIONS[5].id, + "indicator", + "2.1", + NOW + datetime.timedelta(seconds=3), + NOW + datetime.timedelta(seconds=-1), + { + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=-1)), + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": 
"2016-01-01T00:00:00Z", + }, + ), + STIXObject( + f"indicator--{str(uuid4())}", + COLLECTIONS[1].id, + "indicator", + "2.1", + NOW, + NOW + datetime.timedelta(seconds=1), + { + "created_by_ref": "identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + "created": "2016-04-06T20:03:48.000Z", + "modified": taxii2_datetimeformat(NOW + datetime.timedelta(seconds=1)), + "indicator_types": ["malicious-activity"], + "name": "Poison Ivy Malware", + "description": "This file is part of Poison Ivy", + "pattern": "[ file:hashes.'SHA-256' = " + "'4bac27393bdd9777ce02453256c5577cd02275510b2227f473d03f533924f877' ]", + "pattern_type": "stix", + "valid_from": "2016-01-01T00:00:00Z", + }, + ), +) + + +def process_match_version(match_version): + if match_version is None: + match_version = ["last"] + versions_per_id = {} + for stix_obj in STIX_OBJECTS: + if stix_obj.id not in versions_per_id: + versions_per_id[stix_obj.id] = [] + versions_per_id[stix_obj.id].append(stix_obj.version) + id_version_combos = [] + for value in match_version: + if value == "last": + for obj_id, versions in versions_per_id.items(): + id_version_combos.append((obj_id, max(versions))) + elif value == "first": + for obj_id, versions in versions_per_id.items(): + id_version_combos.append((obj_id, min(versions))) + elif value == "all": + for obj_id, versions in versions_per_id.items(): + for version in versions: + id_version_combos.append((obj_id, version)) + else: + for obj_id in versions_per_id: + id_version_combos.append((obj_id, value)) + return id_version_combos + + +def GET_API_ROOT_MOCK(api_root_id): + for api_root in API_ROOTS: + if api_root.id == api_root_id: + return api_root + return None + + +def GET_JOB_AND_DETAILS_MOCK(api_root_id, job_id): + job_response = None + details_response = [] + for job in JOBS: + if job.api_root_id == api_root_id and job.id == job_id: + job_response = job + break + if job_response is None: + return None, [] + for job_detail in JOB_DETAILS: + if job_detail.job_id == job_id: + details_response.append(job_detail) + return job_response, details_response + + +def GET_COLLECTIONS_MOCK(api_root_id): + response = [] + for collection in COLLECTIONS: + if collection.api_root_id == api_root_id: + response.append(collection) + return response + + +def GET_COLLECTION_MOCK(api_root_id, collection_id_or_alias): + for collection in COLLECTIONS: + if collection.api_root_id == api_root_id and ( + collection.id == collection_id_or_alias + or collection.alias == collection_id_or_alias + ): + return collection + return None + + +def STIX_OBJECT_FROM_MANIFEST(stix_id): + for obj in STIX_OBJECTS: + if obj.id == stix_id: + return obj + + +def GET_MANIFEST_MOCK( + collection_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, +): + stix_objects, more = GET_OBJECTS_MOCK( + collection_id=collection_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_id=match_id, + match_type=match_type, + match_version=match_version, + match_spec_version=match_spec_version, + ) + response = [ + ManifestRecord(obj.id, obj.date_added, obj.version, obj.spec_version) + for obj in stix_objects + ] + return response, more + + +def GET_OBJECTS_MOCK( + collection_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: 
Optional[Dict] = None, + match_id: Optional[List[str]] = None, + match_type: Optional[List[str]] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, +): + id_version_combos = process_match_version(match_version) + response = [] + more = False + for stix_object in STIX_OBJECTS: + if ( + stix_object.collection_id == collection_id + and (stix_object.id, stix_object.version) in id_version_combos + ): + if limit is not None and limit == len(response): + more = True + break + if added_after is not None: + if stix_object.date_added <= added_after: + continue + if next_kwargs is not None: + if stix_object.date_added < next_kwargs["date_added"] or ( + stix_object.date_added == next_kwargs["date_added"] + and stix_object.id == next_kwargs["id"] + ): + continue + if match_id is not None and stix_object.id not in match_id: + continue + if match_type is not None and stix_object.type not in match_type: + continue + if ( + match_spec_version is not None + and stix_object.spec_version not in match_spec_version + ): + continue + response.append(stix_object) + return response, more + + +def GET_OBJECT_MOCK( + collection_id: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, +): + id_version_combos = process_match_version(match_version) + response = [] + more = False + at_least_one = False + for stix_object in STIX_OBJECTS: + if ( + stix_object.collection_id == collection_id + and stix_object.id == object_id + ): + at_least_one = True + if ( + stix_object.id, + stix_object.version, + ) not in id_version_combos: + continue + if limit is not None and limit == len(response): + more = True + break + if added_after is not None: + if stix_object.date_added <= added_after: + continue + if next_kwargs is not None: + if stix_object.date_added < next_kwargs["date_added"] or ( + stix_object.date_added == next_kwargs["date_added"] + and stix_object.id == next_kwargs["id"] + ): + continue + if ( + match_spec_version is not None + and stix_object.spec_version not in match_spec_version + ): + continue + response.append(stix_object) + if not at_least_one: + response = None + return response, more + + +def ADD_OBJECTS_MOCK(api_root_id: str, collection_id: str, objects: List[Dict]): + return (JOBS[0], JOB_DETAILS) + + +def DELETE_OBJECT_MOCK( + collection_id: str, + object_id: str, + match_version: Optional[List[str]] = None, + match_spec_version: Optional[List[str]] = None, +): + return + + +def GET_VERSIONS_MOCK( + collection_id: str, + object_id: str, + limit: Optional[int] = None, + added_after: Optional[datetime.datetime] = None, + next_kwargs: Optional[Dict] = None, + match_spec_version: Optional[List[str]] = None, +): + versions, more = GET_OBJECT_MOCK( + collection_id=collection_id, + object_id=object_id, + limit=limit, + added_after=added_after, + next_kwargs=next_kwargs, + match_spec_version=match_spec_version, + match_version=["all"], + ) + return ( + [VersionRecord(obj.date_added, obj.version) for obj in versions] + if versions is not None + else None, + more, + ) + + +def config_noop(config): + return config + + +def config_remove_fields(*fields): + def inner(config): + for field in fields: + del config[field] + return config + + return inner + + +def config_override(override): + def inner(config): + return {**config, **override} + + return inner + + +def 
server_mapping_remove_fields(*fields): + def inner(original): + override = {field: None for field in fields} + kwargs = {**dict(original._asdict()), **override} + return ServerMapping(**kwargs) + + return inner + + +def server_mapping_noop(original): + return original diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..7686afa9 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,479 @@ +import os +from unittest import mock + +import pytest +from opentaxii.cli.auth import create_account, update_account +from opentaxii.cli.persistence import (add_api_root, add_collection, + delete_content_blocks, job_cleanup, + sync_data_configuration) + +from tests.fixtures import ACCOUNT, COLLECTION_OPEN +from tests.taxii2.utils import API_ROOTS +from tests.utils import assert_str_equal_no_formatting, conditional_raises + + +@pytest.mark.parametrize( + ["argv", "raises", "message", "stdout", "stderr"], + [ + pytest.param( + [os.path.join("examples", "data-configuration.yml")], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + id="default", + ), + pytest.param( + [os.path.join("examples", "data-configuration.yml"), "-f"], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + id="default, -f", + ), + pytest.param( + [ + os.path.join("examples", "data-configuration.yml"), + "--force-delete", + ], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + id="default, --force-delete", + ), + pytest.param( + [], # argv + SystemExit, # raises + "2", # message + "", # stdout + "".join( # stderr + [ + "usage:", + "[-h]", + "[-f]", + "config", + ":error: the following arguments are required: config", + ] + ), + id="no args", + ), + pytest.param( + ["this_file_does_not_exist"], # argv + FileNotFoundError, # raises + "[Errno 2] No such file or directory: 'this_file_does_not_exist'", # message + "", # stdout + "", # stderr + id="missing config", + ), + pytest.param( + ["this_file_does_not_exist", "-f"], # argv + FileNotFoundError, # raises + "[Errno 2] No such file or directory: 'this_file_does_not_exist'", # message + "", # stdout + "", # stderr + id="missing config, with -f", + ), + ], +) +def test_sync_data_configuration(app, capsys, argv, raises, message, stdout, stderr): + with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( + "sys.argv", [""] + argv + ): + with conditional_raises(raises) as exception: + sync_data_configuration() + if raises: + assert str(exception.value) == message + captured = capsys.readouterr() + assert_str_equal_no_formatting(captured.out, stdout) + assert_str_equal_no_formatting(captured.err, stderr) + + +@pytest.mark.parametrize( + ["argv", "raises", "message", "stdout", "stderr"], + [ + pytest.param( + ["-c", COLLECTION_OPEN, "--begin", "2000-01-01"], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + id="good", + ), + pytest.param( + ["-c", "collectiondoesnotexist", "--begin", "2000-01-01"], # argv + ValueError, # raises + "Collection with name 'collectiondoesnotexist' does not exist", # message + "", # stdout + "", # stderr + id="collection does not exist", + ), + pytest.param( + [], # argv + SystemExit, # raises + "2", # message + "", # stdout + "".join( # stderr + [ + "usage:", + "[-h]", + "-c COLLECTION", + "[-m]", + "--begin BEGIN", + "[--end END]", + ": error: the following arguments are required: -c/--collection, --begin", + ] + ), + id="no args", + ), + ], +) +def test_delete_content_blocks( + app, collections, capsys, argv, raises, message, stdout, 
stderr +): + with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( + "sys.argv", [""] + argv + ): + with conditional_raises(raises) as exception: + delete_content_blocks() + if raises: + assert str(exception.value) == message + captured = capsys.readouterr() + assert_str_equal_no_formatting(captured.out, stdout) + assert_str_equal_no_formatting(captured.err, stderr) + + +@pytest.mark.parametrize( + ["argv", "raises", "message", "stdout", "stderr"], + [ + pytest.param( + ["-u", "myuser", "-p", "mypass"], # argv + False, # raises + None, # message + "token: JWT_TOKEN", # stdout + "", # stderr + id="good", + ), + pytest.param( + [], # argv + SystemExit, # raises + "2", # message + "", # stdout + "".join( # stderr + [ + "usage:", + "[-h]", + "-u USERNAME", + "-p PASSWORD", + "[-a]", + ": error: the following arguments are required: -u/--username, -p/--password", + ] + ), + id="no args", + ), + ], +) +def test_create_account(app, capsys, argv, raises, message, stdout, stderr): + with mock.patch("opentaxii.cli.auth.app", app), mock.patch("sys.argv", [""] + argv): + with conditional_raises(raises) as exception: + create_account() + if raises: + assert str(exception.value) == message + captured = capsys.readouterr() + assert_str_equal_no_formatting(captured.out, stdout) + assert_str_equal_no_formatting(captured.err, stderr) + + +@pytest.mark.parametrize( + ["argv", "raises", "message", "stdout", "stderr"], + [ + pytest.param( + ["-u", ACCOUNT.username, "-f", "admin", "-v", "y"], # argv + False, # raises + None, # message + "now user is admin", # stdout + "", # stderr + id="make admin", + ), + pytest.param( + ["-u", ACCOUNT.username, "-f", "admin", "-v", "n"], # argv + False, # raises + None, # message + "now user is mortal", # stdout + "", # stderr + id="drop admin", + ), + pytest.param( + ["-u", ACCOUNT.username, "-f", "password", "-v", "newpass"], # argv + False, # raises + None, # message + "password has been changed", # stdout + "", # stderr + id="change password", + ), + pytest.param( + [], # argv + SystemExit, # raises + "2", # message + "", # stdout + "".join( # stderr + [ + "usage:", + "[-h]", + "-u USERNAME", + "-f {password,admin}", + "-v VALUE", + ": error: the following arguments are required: -u/--username, -f/--field, -v/--value", + ] + ), + id="no args", + ), + ], +) +def test_update_account(app, account, capsys, argv, raises, message, stdout, stderr): + with mock.patch("opentaxii.cli.auth.app", app), mock.patch("sys.argv", [""] + argv): + with conditional_raises(raises) as exception: + update_account() + if raises: + assert str(exception.value) == message + captured = capsys.readouterr() + assert_str_equal_no_formatting(captured.out, stdout) + assert_str_equal_no_formatting(captured.err, stderr) + + +@pytest.mark.parametrize( + ["argv", "raises", "message", "stdout", "stderr", "expected_call"], + [ + pytest.param( + ["-t", "my new api root"], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + { + "title": "my new api root", + "description": None, + "default": False, + }, # expected_call + id="title only", + ), + pytest.param( + ["-t", "my new api root", "-d", "my description"], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + { + "title": "my new api root", + "description": "my description", + "default": False, + }, # expected_call + id="title, description", + ), + pytest.param( + ["-t", "my new api root", "--default"], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + { + "title": "my new api 
root", + "description": None, + "default": True, + }, # expected_call + id="title, default", + ), + pytest.param( + ["-t", "my new api root", "-d", "my description", "--default"], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + { + "title": "my new api root", + "description": "my description", + "default": True, + }, # expected_call + id="title, description, default", + ), + pytest.param( + [], # argv + SystemExit, # raises + "2", # message + "", # stdout + "".join( # stderr + [ + "usage:", + "[-h]", + "-t TITLE", + "[-d DESCRIPTION]", + "[--default]", + ": error: the following arguments are required: -t/--title", + ] + ), + None, # expected_call + id="no args", + ), + ], +) +def test_add_api_root( + app, capsys, argv, raises, message, stdout, stderr, expected_call +): + with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( + "sys.argv", [""] + argv + ), mock.patch.object( + app.taxii_server.servers.taxii2.persistence.api, "add_api_root" + ) as mock_add_api_root: + with conditional_raises(raises) as exception: + add_api_root() + if raises: + assert str(exception.value) == message + captured = capsys.readouterr() + assert_str_equal_no_formatting(captured.out, stdout) + assert_str_equal_no_formatting(captured.err, stderr) + if expected_call is None: + mock_add_api_root.assert_not_called() + else: + mock_add_api_root.assert_called_once_with(**expected_call) + + +@pytest.mark.parametrize( + ["argv", "raises", "message", "stdout", "stderr", "expected_call"], + [ + pytest.param( + ["-r", API_ROOTS[0].id, "-t", "my new collection"], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + { + "api_root_id": API_ROOTS[0].id, + "title": "my new collection", + "description": None, + "alias": None, + }, # expected_call + id="rootid, title only", + ), + pytest.param( + [ + "-r", + API_ROOTS[0].id, + "-t", + "my new collection", + "-d", + "my description", + ], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + { + "api_root_id": API_ROOTS[0].id, + "title": "my new collection", + "description": "my description", + "alias": None, + }, # expected_call + id="rootid, title, description", + ), + pytest.param( + [ + "-r", + API_ROOTS[0].id, + "-t", + "my new collection", + "-d", + "my description", + "-a", + "my-alias", + ], # argv + False, # raises + None, # message + "", # stdout + "", # stderr + { + "api_root_id": API_ROOTS[0].id, + "title": "my new collection", + "description": "my description", + "alias": "my-alias", + }, # expected_call + id="rootid, title, description, alias", + ), + pytest.param( + ["-r", "fake-uuid", "-t", "my new collection"], # argv + SystemExit, # raises + "2", # message + "", # stdout + "".join( # stderr + [ + "usage:", + "[-h]", + "-r {ROOTIDS}", + "-t TITLE", + "[-d DESCRIPTION]", + "[-a ALIAS]", + ": error: argument -r/--rootid: invalid choice: 'fake-uuid'", + "(choose from WRAPPED_ROOTIDS)", + ] + ), + None, # expected_call + id="unknown api root", + ), + pytest.param( + [], # argv + SystemExit, # raises + "2", # message + "", # stdout + "".join( # stderr + [ + "usage:", + "[-h]", + "-r {ROOTIDS}", + "-t TITLE", + "[-d DESCRIPTION]", + "[-a ALIAS]", + ": error: the following arguments are required: -r/--rootid, -t/--title", + ] + ), + None, # expected_call + id="no args", + ), + ], +) +def test_add_collection( + app, db_api_roots, capsys, argv, raises, message, stdout, stderr, expected_call +): + stderr = stderr.replace( + "WRAPPED_ROOTIDS", + ",".join([f"'{api_root.id}'" for api_root in 
db_api_roots]), + ) + stderr = stderr.replace( + "ROOTIDS", + ",".join([api_root.id for api_root in db_api_roots]), + ) + with mock.patch("opentaxii.cli.persistence.app", app), mock.patch( + "sys.argv", [""] + argv + ), mock.patch.object( + app.taxii_server.servers.taxii2.persistence.api, "add_collection" + ) as mock_add_collection: + with conditional_raises(raises) as exception: + add_collection() + if raises: + assert str(exception.value) == message + captured = capsys.readouterr() + assert_str_equal_no_formatting(captured.out, stdout) + assert_str_equal_no_formatting(captured.err, stderr) + if expected_call is None: + mock_add_collection.assert_not_called() + else: + mock_add_collection.assert_called_once_with(**expected_call) + + +def test_job_cleanup(app, capsys): + with mock.patch("opentaxii.cli.persistence.app", app), mock.patch.object( + app.taxii_server.servers.taxii2.persistence.api, "job_cleanup", return_value=2 + ) as mock_cleanup: + job_cleanup() + mock_cleanup.assert_called_once_with() + captured = capsys.readouterr() + assert captured.out == "2 removed\n" + assert captured.err == "" diff --git a/tests/test_config.py b/tests/test_config.py index 5de141e2..b075e23f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -205,7 +205,10 @@ def test_custom_config_file(config_file_name_expected_value): }, "expected": { "taxii1": {"persistence_api": {"class": "something.Else", "other": 1}}, - "taxii2": {"persistence_api": {"class": "something.Else2", "other": 2}, "max_content_length": 1024}, + "taxii2": { + "persistence_api": {"class": "something.Else2", "other": 2}, + "max_content_length": 1024, + }, }, } TAXII2_ENVVARS = { @@ -215,7 +218,10 @@ def test_custom_config_file(config_file_name_expected_value): "OPENTAXII__TAXII2__MAX_CONTENT_LENGTH": "1024", }, "expected": { - "taxii2": {"persistence_api": {"class": "something.Else2", "other": 2}, "max_content_length": 1024}, + "taxii2": { + "persistence_api": {"class": "something.Else2", "other": 2}, + "max_content_length": 1024, + }, }, } diff --git a/tests/test_server.py b/tests/test_server.py index a7a1d60f..9fc12806 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -42,13 +42,13 @@ SERVICES = INTERNAL_SERVICES + [DISCOVERY_EXTERNAL] -@pytest.fixture(autouse=True) +@pytest.fixture() def local_services(server): for service in SERVICES: server.servers.taxii1.persistence.update_service(dict_to_service_entity(service)) -def test_services_configured(server): +def test_services_configured(server, local_services): assert len(server.servers.taxii1.get_services()) == len(SERVICES) with_paths = [ @@ -77,7 +77,8 @@ def test_taxii2_configured(server): ) -def test_multithreaded_access(server): +@pytest.mark.truncate() +def test_multithreaded_access(server, local_services): def testfunc(): server.servers.taxii1.get_services() diff --git a/tests/utils.py b/tests/utils.py index 257bc189..763c7a57 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,18 +1,19 @@ +import re + +import pytest from libtaxii import messages_10 as tm10 from libtaxii import messages_11 as tm11 - from opentaxii.taxii import entities -from opentaxii.taxii.http import ( - TAXII_10_HTTPS_HEADERS, TAXII_10_HTTP_HEADERS, TAXII_11_HTTPS_HEADERS, - TAXII_11_HTTP_HEADERS -) -from opentaxii.taxii.http import ( - HTTP_ACCEPT, HTTP_CONTENT_XML) - +from opentaxii.taxii.http import (HTTP_ACCEPT, HTTP_CONTENT_XML, + TAXII_10_HTTP_HEADERS, + TAXII_10_HTTPS_HEADERS, + TAXII_11_HTTP_HEADERS, + TAXII_11_HTTPS_HEADERS) from opentaxii.taxii.utils import get_utc_now 
-from fixtures import ( - CB_STIX_XML_111, CONTENT, MESSAGE, MESSAGE_ID) +from fixtures import CB_STIX_XML_111, CONTENT, MESSAGE, MESSAGE_ID + +JWT_RE = re.compile(r'[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*') def as_tm(version): @@ -21,7 +22,7 @@ def as_tm(version): elif version == 11: return tm11 else: - raise ValueError('Unknown TAXII message version: %s' % version) + raise ValueError("Unknown TAXII message version: %s" % version) def prepare_headers(version, https): @@ -37,38 +38,45 @@ def prepare_headers(version, https): else: headers.update(TAXII_11_HTTP_HEADERS) else: - raise ValueError('Unknown TAXII message version: %s' % version) + raise ValueError("Unknown TAXII message version: %s" % version) headers[HTTP_ACCEPT] = HTTP_CONTENT_XML return headers -def persist_content(manager, collection_name, service_id, timestamp=None, - binding=CB_STIX_XML_111, subtypes=[]): +def persist_content( + manager, + collection_name, + service_id, + timestamp=None, + binding=CB_STIX_XML_111, + subtypes=[], +): timestamp = timestamp or get_utc_now() - content_binding = entities.ContentBindingEntity( - binding=binding, - subtypes=subtypes - ) + content_binding = entities.ContentBindingEntity(binding=binding, subtypes=subtypes) content = entities.ContentBlockEntity( - content=CONTENT, timestamp_label=timestamp, - message=MESSAGE, content_binding=content_binding) + content=CONTENT, + timestamp_label=timestamp, + message=MESSAGE, + content_binding=content_binding, + ) collection = manager.get_collection(collection_name, service_id) if not collection: - raise ValueError('No collection with name {}'.format(collection_name)) + raise ValueError("No collection with name {}".format(collection_name)) content = manager.create_content(content, collections=[collection]) return content -def prepare_subscription_request(collection, action, version, - subscription_id=None, params=None): +def prepare_subscription_request( + collection, action, version, subscription_id=None, params=None +): data = dict( action=action, @@ -80,14 +88,17 @@ def prepare_subscription_request(collection, action, version, if version == 11: cls = mod.ManageCollectionSubscriptionRequest - data.update(dict( - collection_name=collection, - subscription_parameters=( - mod.SubscriptionParameters(**params) if params else None) - )) + data.update( + dict( + collection_name=collection, + subscription_parameters=( + mod.SubscriptionParameters(**params) if params else None + ), + ) + ) else: cls = mod.ManageFeedSubscriptionRequest - data['feed_name'] = collection + data["feed_name"] = collection return cls(**data) @@ -108,7 +119,7 @@ def is_headers_valid(headers, version, https): else: return includes(headers, TAXII_11_HTTP_HEADERS) else: - raise ValueError('Unknown TAXII message version: %s' % version) + raise ValueError("Unknown TAXII message version: %s" % version) class conditional: @@ -127,3 +138,25 @@ def __enter__(self): def __exit__(self, *args): if self.condition: return self.contextmanager.__exit__(*args) + + +class conditional_raises(conditional): + """ + Assert if wrapped code raises, but only when given an exception class + """ + + def __init__(self, condition): + if condition: + contextmanager = pytest.raises(condition) + else: + contextmanager = None + super().__init__(condition, contextmanager) + + +def assert_str_equal_no_formatting(str1, str2): + if "JWT_TOKEN" in str2: + jwt_token = JWT_RE.findall(str1)[0] + str2 = str2.replace("JWT_TOKEN", jwt_token) + assert "".join([part.strip() for part in str1.split()]) == "".join( + 
[part.strip() for part in str2.split()] + )
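
Illustrative note on the testing pattern used throughout tests/test_cli.py above: each test patches sys.argv (and the module-level opentaxii app object), invokes the CLI entry point directly, and leans on the conditional_raises and assert_str_equal_no_formatting helpers added to tests/utils.py so a single parametrized test can cover both the success path and the argparse SystemExit path. A minimal sketch of that pattern, assuming the helpers from this patch and a hypothetical my_cli entry point standing in for commands such as add_api_root:

import argparse
from unittest import mock

import pytest

# Helpers introduced by this patch in tests/utils.py.
from tests.utils import assert_str_equal_no_formatting, conditional_raises


def my_cli():
    # Hypothetical CLI entry point; the real commands live in opentaxii/cli/*.py.
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--title", required=True)
    args = parser.parse_args()
    print(args.title)


@pytest.mark.parametrize(
    ["argv", "raises"],
    [
        pytest.param(["-t", "my title"], False, id="good"),
        pytest.param([], SystemExit, id="missing required argument"),
    ],
)
def test_my_cli(capsys, argv, raises):
    # argv[0] is the program name, so prepend an empty string exactly as the tests above do.
    with mock.patch("sys.argv", [""] + argv):
        with conditional_raises(raises) as exception:
            my_cli()
    if raises:
        # argparse exits with status code 2 on argument errors.
        assert str(exception.value) == "2"
    else:
        # Whitespace-insensitive comparison, as used for the stdout/stderr assertions above.
        assert_str_equal_no_formatting(capsys.readouterr().out, "my title")

The same structure scales to the real commands: patch the module-level app, patch sys.argv, call the entry point, and compare captured output with assert_str_equal_no_formatting so expected usage strings can be written as compact fragments rather than exact argparse output.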