Skip to content

Commit

Permalink
perf: refactor to avoid excessive walking (#877)
Browse files Browse the repository at this point in the history
Store result of vars found in `walk` into an intermediate object,
keyed by rule index and the variable's "context".

Signed-off-by: Anders Eknert <[email protected]>
  • Loading branch information
anderseknert committed Jul 1, 2024
1 parent 0dad5f0 commit 300eef2
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 43 deletions.
6 changes: 1 addition & 5 deletions bundle/regal/ast/ast.rego
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,7 @@ all_rules_refs contains value if {
# scope: document
all_refs contains value if some value in all_rules_refs

all_refs contains value if {
walk(input.imports, [_, value])

is_ref(value)
}
all_refs contains imported.path if some imported in input.imports

# METADATA
# title: ref_to_string
Expand Down
18 changes: 9 additions & 9 deletions bundle/regal/ast/ast_test.rego
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ test_find_vars_in_local_scope if {
"e": {"col": 4, "row": 14},
}

var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.a)) == set()
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.b)) == {"a"}
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.c)) == {"a", "b", "c"}
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.d)) == {"a", "b", "c", "d"}
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.e)) == {"a", "b", "c", "d", "e"}
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.a)) with input as module == set()
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.b)) with input as module == {"a"}
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.c)) with input as module == {"a", "b", "c"}
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.d)) with input as module == {"a", "b", "c", "d"}
var_names(ast.find_vars_in_local_scope(allow_rule, var_locations.e)) with input as module == {"a", "b", "c", "d", "e"}
}

test_find_vars_in_local_scope_complex_comprehension_term if {
Expand All @@ -210,7 +210,7 @@ test_find_vars_in_local_scope_complex_comprehension_term if {

allow_rule := module.rules[0]

ast.find_vars_in_local_scope(allow_rule, {"col": 10, "row": 10}) == [
ast.find_vars_in_local_scope(allow_rule, {"col": 10, "row": 10}) with input as module == [
{"location": {"col": 3, "row": 7, "text": "YQ=="}, "type": "var", "value": "a"},
{"location": {"col": 15, "row": 7, "text": "Yg=="}, "type": "var", "value": "b"},
{"location": {"col": 20, "row": 7, "text": "Yw=="}, "type": "var", "value": "c"},
Expand Down Expand Up @@ -264,7 +264,7 @@ test_find_some_decl_vars if {

module := regal.parse_module("p.rego", policy)

some_vars := ast.find_some_decl_vars(module.rules[0])
some_vars := ast.find_some_decl_vars(module.rules[0]) with input as module

var_names(some_vars) == {"x", "y", "z"}
}
Expand All @@ -284,8 +284,8 @@ test_find_some_decl_names_in_scope if {

module := regal.parse_module("p.rego", policy)

ast.find_some_decl_names_in_scope(module.rules[0], {"col": 1, "row": 8}) == {"x"}
ast.find_some_decl_names_in_scope(module.rules[0], {"col": 1, "row": 10}) == {"x", "y", "z"}
{"x"} == ast.find_some_decl_names_in_scope(module.rules[0], {"col": 1, "row": 8}) with input as module
{"x", "y", "z"} == ast.find_some_decl_names_in_scope(module.rules[0], {"col": 1, "row": 10}) with input as module
}

var_names(vars) := {var.value | some var in vars}
Expand Down
69 changes: 51 additions & 18 deletions bundle/regal/ast/search.rego
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ _find_object_comprehension_vars(value) := array.concat(key, val) if {
val := [value.value.value | value.value.value.type == "var"]
}

_find_vars(_, value, last) := find_term_vars(function_ret_args(fn_name, value)) if {
_find_vars(_, value, last) := {"term": find_term_vars(function_ret_args(fn_name, value))} if {
last == "terms"
value[0].type == "ref"
value[0].value[0].type == "var"
Expand All @@ -101,7 +101,7 @@ _find_vars(_, value, last) := find_term_vars(function_ret_args(fn_name, value))
function_ret_in_args(fn_name, value)
}

_find_vars(path, value, last) := _find_assign_vars(path, value) if {
_find_vars(path, value, last) := {"assign": _find_assign_vars(path, value)} if {
last == "terms"
value[0].type == "ref"
value[0].value[0].type == "var"
Expand All @@ -112,63 +112,96 @@ _find_vars(path, value, last) := _find_assign_vars(path, value) if {
# left-hand side is equally dubious, but we'll treat `x = 1` as `x := 1` for
# the purpose of this function until we have a more robust way of dealing with
# unification
_find_vars(path, value, last) := _find_assign_vars(path, value) if {
_find_vars(path, value, last) := {"assign": _find_assign_vars(path, value)} if {
last == "terms"
value[0].type == "ref"
value[0].value[0].type == "var"
value[0].value[0].value == "eq"
}

_find_vars(_, value, _) := find_ref_vars(value) if value.type == "ref"
_find_vars(_, value, _) := {"ref": find_ref_vars(value)} if value.type == "ref"

_find_vars(path, value, last) := _find_some_in_decl_vars(path, value) if {
_find_vars(path, value, last) := {"somein": _find_some_in_decl_vars(path, value)} if {
last == "symbols"
value[0].type == "call"
}

_find_vars(path, value, last) := _find_some_decl_vars(path, value) if {
_find_vars(path, value, last) := {"some": _find_some_decl_vars(path, value)} if {
last == "symbols"
value[0].type != "call"
}

_find_vars(path, value, last) := _find_every_vars(path, value) if {
_find_vars(path, value, last) := {"every": _find_every_vars(path, value)} if {
last == "terms"
value.domain
}

_find_vars(_, value, _) := _find_set_or_array_comprehension_vars(value) if {
_find_vars(_, value, _) := {"setorarraycomprehension": _find_set_or_array_comprehension_vars(value)} if {
value.type in {"setcomprehension", "arraycomprehension"}
}

_find_vars(_, value, _) := _find_object_comprehension_vars(value) if value.type == "objectcomprehension"

find_some_decl_vars(rule) := [var |
walk(rule, [path, value])
# object comprehension nodes: the vars found by `_find_object_comprehension_vars`
# are returned under the "objectcomprehension" context key
_find_vars(_, value, _) := {"objectcomprehension": _find_object_comprehension_vars(value)} if {
value.type == "objectcomprehension"
}

regal.last(path) == "symbols"
value[0].type != "call"
# returns the index of `rule` in `_rules`, formatted as a decimal string so that it
# matches the string keys used by `vars`; undefined if no rule in `_rules` equals `rule`
_rule_index(rule) := sprintf("%d", [i]) if {
some i, r in _rules # regal ignore:external-reference
r == rule
}

some var in _find_some_decl_vars(path, value)
]
# returns an array of the vars stored under the "some" context in `vars` for the
# provided rule, or an empty array if no such entry exists
find_some_decl_vars(rule) := [var | some var in vars[_rule_index(rule)]["some"]] # regal ignore:external-reference

# METADATA
# description: |
# traverses all nodes under provided node (using `walk`), and returns an array with
# all variables declared via assignment (:=), `some`, `every` and in comprehensions
# DEPRECATED: use ast.vars instead
find_vars(node) := [var |
walk(node, [path, value])

some var in _find_vars(path, value, regal.last(path))
var := _find_vars(path, value, regal.last(path))[_][_]
]

# hack to work around the different input models of linting vs. the lsp package:
# use `input.rules` when it is set, and otherwise fall back to the rules of the
# module found under `data.workspace.parsed`, keyed by `input.regal.file.uri`.
# we should probably consider something more robust
_rules := input.rules

_rules := data.workspace.parsed[input.regal.file.uri].rules if not input.rules

# METADATA
# description: |
# object containing all variables found in the input AST, keyed first by the index of
# the rule where the variables were found (as a numeric string), and then the context
# of the variable, which will be one of:
# - objectcomprehension
# - setorarraycomprehension
# - term
# - assign
# - every
# - some
# - somein
# - ref
vars[rule_index][context] contains var if {
some i, rule in _rules

# converting to string until https://github.com/open-policy-agent/opa/issues/6736 is fixed
rule_index := sprintf("%d", [i])

# visit every node under the rule, binding the path to the node and the node itself
walk(rule, [path, value])

# `_find_vars` maps a context name to the vars found for the node; note that the
# local `vars` declared here shadows the name of the enclosing rule
some context, vars in _find_vars(path, value, regal.last(path))
some var in vars
}

# METADATA
# description: |
# finds all vars declared in `rule` *before* the `location` provided
# note: this isn't 100% accurate, as it doesn't take into account `=`
# assignments / unification, but it's likely good enough since other rules
# recommend against those
find_vars_in_local_scope(rule, location) := [var |
some var in find_vars(rule)
var := vars[_rule_index(rule)][_][_] # regal ignore:external-reference

not startswith(var.value, "$")
_before_location(rule, var, location)
]
Expand Down
12 changes: 6 additions & 6 deletions bundle/regal/lsp/completion/location/location_test.rego
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ another if {
}
`)

location.find_locals(module.rules, {"row": 6, "col": 1}) == set()
location.find_locals(module.rules, {"row": 6, "col": 10}) == {"x"}
location.find_locals(module.rules, {"row": 10, "col": 1}) == {"a", "b"}
location.find_locals(module.rules, {"row": 10, "col": 6}) == {"a", "b", "c"}
location.find_locals(module.rules, {"row": 15, "col": 1}) == {"x", "y"}
location.find_locals(module.rules, {"row": 16, "col": 1}) == {"x", "y", "z"}
location.find_locals(module.rules, {"row": 6, "col": 1}) with input as module == set()
location.find_locals(module.rules, {"row": 6, "col": 10}) with input as module == {"x"}
location.find_locals(module.rules, {"row": 10, "col": 1}) with input as module == {"a", "b"}
location.find_locals(module.rules, {"row": 10, "col": 6}) with input as module == {"a", "b", "c"}
location.find_locals(module.rules, {"row": 15, "col": 1}) with input as module == {"x", "y"}
location.find_locals(module.rules, {"row": 16, "col": 1}) with input as module == {"x", "y", "z"}
}
2 changes: 1 addition & 1 deletion bundle/regal/rules/custom/naming_convention.rego
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ report contains violation if {

target in {"var", "variable"}

some var in ast.find_vars(input.rules)
var := ast.vars[_][_][_]

not regex.match(convention.pattern, var.value)

Expand Down
2 changes: 1 addition & 1 deletion bundle/regal/rules/style/prefer_snake_case.rego
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ report contains violation if {
}

report contains violation if {
some var in ast.find_vars(input.rules)
var := ast.vars[_][_][_]
not util.is_snake_case(var.value)

violation := result.fail(rego.metadata.chain(), result.location(var))
Expand Down
6 changes: 3 additions & 3 deletions bundle/regal/rules/testing/metasyntactic_variable.rego
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ report contains violation if {
}

report contains violation if {
some rule in input.rules
some var in ast.find_vars(rule)
some i
var := ast.vars[i][_][_]

lower(var.value) in metasyntactic

ast.is_output_var(rule, var, var.location)
ast.is_output_var(input.rules[to_number(i)], var, var.location)

violation := result.fail(rego.metadata.chain(), result.location(var))
}

0 comments on commit 300eef2

Please sign in to comment.