From 2f32e5d26b7f1ed5e9ecdc09013ba5c84b2816ac Mon Sep 17 00:00:00 2001 From: arobbins Date: Wed, 19 Jun 2024 14:33:11 -0400 Subject: [PATCH] lint code and fix unit tests --- src/attack_flow/cli.py | 2 +- src/attack_flow/graphviz.py | 41 +++++++++++++++++++++++-------------- tests/test_cli.py | 5 +++-- tests/test_graphviz.py | 2 +- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/attack_flow/cli.py b/src/attack_flow/cli.py index dfffad7d..09109759 100644 --- a/src/attack_flow/cli.py +++ b/src/attack_flow/cli.py @@ -97,7 +97,7 @@ def graphviz(args): converted = attack_flow.graphviz.convert_attack_tree(flow_bundle) else: converted = attack_flow.graphviz.convert_attack_flow(flow_bundle) - + with open(args.output, "w") as out: out.write(converted) return 0 diff --git a/src/attack_flow/graphviz.py b/src/attack_flow/graphviz.py index 238cef61..08901a3d 100644 --- a/src/attack_flow/graphviz.py +++ b/src/attack_flow/graphviz.py @@ -18,6 +18,7 @@ def label_escape(text): return graphviz.escape(html.escape(text)) + def convert_attack_flow(bundle): """ Convert an Attack Flow STIX bundle into Graphviz format. @@ -69,6 +70,7 @@ def convert_attack_flow(bundle): return gv.source + def convert_attack_tree(bundle): """ Convert an Attack Flow STIX bundle into Graphviz format. @@ -77,7 +79,7 @@ def convert_attack_tree(bundle): :rtype: str """ - gv = graphviz.Digraph(graph_attr={'rankdir':'BT'}) + gv = graphviz.Digraph(graph_attr={"rankdir": "BT"}) gv.body = _get_body_label(bundle) ignored_ids = get_viz_ignored_ids(bundle) @@ -85,10 +87,14 @@ def convert_attack_tree(bundle): id_to_remove = [] ids = [] - for i,o in enumerate(objects): + for i, o in enumerate(objects): if o.type == "attack-operator": id_to_remove.append( - {"id": o.id, "prev_id": objects[i-1].id, "next_id":o.effect_refs[0], "type": o.operator + { + "id": o.id, + "prev_id": objects[i - 1].id, + "next_id": o.effect_refs[0], + "type": o.operator, } ) @@ -96,22 +102,23 @@ def convert_attack_tree(bundle): objects = [item for item in objects if item.id not in ids] new_operator_ids = [i["next_id"] for i in id_to_remove] for operator in id_to_remove: - for i,o in enumerate(objects): - if o.type=="relationship" and o.source_ref == operator["id"]: - o.source_ref = operator.prev_id - if o.type=="relationship" and o.target_ref == operator["id"]: - o.target_ref = operator.next_id - if o.get("effect_refs") and operator["id"] in o.effect_refs: - for i,j in enumerate(o.effect_refs): - if j == operator["id"]: - o.effect_refs[i] = operator["next_id"] - + for i, o in enumerate(objects): + if o.type == "relationship" and o.source_ref == operator["id"]: + o.source_ref = operator.prev_id + if o.type == "relationship" and o.target_ref == operator["id"]: + o.target_ref = operator.next_id + if o.get("effect_refs") and operator["id"] in o.effect_refs: + for i, j in enumerate(o.effect_refs): + if j == operator["id"]: + o.effect_refs[i] = operator["next_id"] for o in objects: logger.debug("Processing object id=%s", o.id) if o.type == "attack-action": if o.id in new_operator_ids: - operator_type = [item["type"] for item in id_to_remove if item["next_id"] == o.id][0] + operator_type = [ + item["type"] for item in id_to_remove if item["next_id"] == o.id + ][0] gv.node( o.id, label=_get_operator_label(o, operator_type), @@ -191,6 +198,8 @@ def _get_action_label(action): ">", ] ) + + def _get_attack_tree_action_label(action): """ Generate the GraphViz label for an action node as a table. @@ -217,6 +226,7 @@ def _get_attack_tree_action_label(action): ] ) + def _get_asset_label(asset): """ Generate the GraphViz label for an asset node as a table. @@ -283,6 +293,7 @@ def _get_condition_label(condition): ] ) + def _get_operator_label(action, operator_type): """ Generate the GraphViz label for an action node as a table. @@ -311,4 +322,4 @@ def _get_operator_label(action, operator_type): f'Confidence{confidence}', ">", ] - ) \ No newline at end of file + ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 0fbf9702..83f434f1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -95,12 +95,12 @@ def test_doc_schema(schema_mock, generate_mock, insert_mock, exit_mock): @patch("sys.exit") -@patch("attack_flow.graphviz.convert") +@patch("attack_flow.graphviz.convert_attack_flow") @patch("attack_flow.model.load_attack_flow_bundle") def test_graphviz(load_mock, convert_mock, exit_mock): """ Test that the script parses a JSON file and passes the resulting object - to convert(). + to convert_attack_flow(). """ convert_mock.return_value = dedent( r"""\ @@ -111,6 +111,7 @@ def test_graphviz(load_mock, convert_mock, exit_mock): ) bundle = stix2.Bundle() load_mock.return_value = bundle + print("printing resp bundle ", bundle) with NamedTemporaryFile() as flow, NamedTemporaryFile() as graphviz: sys.argv = ["af", "graphviz", flow.name, graphviz.name] runpy.run_module("attack_flow.cli", run_name="__main__") diff --git a/tests/test_graphviz.py b/tests/test_graphviz.py index 088825cc..26341449 100644 --- a/tests/test_graphviz.py +++ b/tests/test_graphviz.py @@ -9,7 +9,7 @@ def test_convert_attack_flow_to_graphviz(): - output = attack_flow.graphviz.convert(get_flow_bundle()) + output = attack_flow.graphviz.convert_attack_flow(get_flow_bundle()) assert output == dedent( """\ digraph {