Skip to content

Commit

Permalink
fix: Fixed identifier bug in the z3_driver
Browse files Browse the repository at this point in the history
  • Loading branch information
sillydan1 committed Sep 11, 2022
1 parent c6dae37 commit f2e6fbd
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 120 deletions.
14 changes: 10 additions & 4 deletions include/drivers/z3_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
#define EXPR_Z3_DRIVER_H
#include "operations.h"
#include "drivers/driver.h"
#include <c++/z3++.h>

namespace expr {
struct z3_driver : public driver {
struct impl;
z3_driver(const symbol_table_t& known_env, const symbol_table_t& unknown_env);
~z3_driver() override;

Expand All @@ -37,13 +37,19 @@ namespace expr {
void add_tree(const syntax_tree_t& tree) override;
void add_tree(const std::string& identifier, const syntax_tree_t& tree) override;

auto as_symbol_value(const z3::expr& e) -> symbol_value_t;
auto as_z3_expression(const syntax_tree_t& tree) -> z3::expr;
auto as_z3_expression(const identifier_t& ref) -> z3::expr;
auto as_z3_expression(const symbol_value_t& val) -> z3::expr;

symbol_table_t result{};
protected:
std::unique_ptr<impl> pimpl;

z3::context c{};
z3::solver s;
const symbol_table_t& known;
const symbol_table_t& unknown;
void solve();
};

}

#endif //ENABLE_Z3
Expand Down
47 changes: 23 additions & 24 deletions src/drivers/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,44 +80,43 @@ namespace expr {
}

auto interpreter::evaluate(const syntax_tree_t& tree) -> symbol_value_t {
symbol_value_t v{};
std::visit(ya::overload(
return std::visit(ya::overload(
[&](const identifier_t& r){
auto s = find(r.ident);
if(s == end)
throw std::out_of_range("not found: " + r.ident);
v = s->second;
return s->second;
},
[&](const operator_t& o) {
switch (o.operator_type) {
case operator_type_t::minus: v = sub(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::plus: v = add(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::star: v = mul(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::slash: v = div(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::percent: v = mod(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::hat: v = pow(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_and: v = _and(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_or: v = _or(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_xor: v = _xor(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_not: v = _not(evaluate(tree.children[0])); break;
case operator_type_t::_implies: v = _implies(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::gt: v = gt(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::ge: v = ge(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::ne: v = ne(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::ee: v = ee(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::le: v = le(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::lt: v = lt(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::parentheses: v = evaluate(tree.children[0]); break;
case operator_type_t::minus: return sub(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::plus: return add(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::star: return mul(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::slash: return div(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::percent: return mod(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::hat: return pow(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_and: return _and(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_or: return _or(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_xor: return _xor(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::_not: return _not(evaluate(tree.children[0])); break;
case operator_type_t::_implies: return _implies(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::gt: return gt(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::ge: return ge(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::ne: return ne(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::ee: return ee(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::le: return le(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::lt: return lt(evaluate(tree.children[0]), evaluate(tree.children[1])); break;
case operator_type_t::parentheses: return evaluate(tree.children[0]); break;
}
},
[&v](const symbol_value_t& o){ v = o; },
[](const symbol_value_t& o){ return o; },
[&](const root_t& r){
if(!tree.children.empty())
v = evaluate(tree.children[0]);
return evaluate(tree.children[0]);
throw std::logic_error("ROOT has no children");
},
[](auto&&){ throw std::logic_error("operator type not recognized"); }
), static_cast<const underlying_syntax_node_t&>(tree.node));
return v;
}

auto interpreter::evaluate(const compiler::compiled_expr_collection_t& trees) -> symbol_table_t {
Expand Down
140 changes: 51 additions & 89 deletions src/drivers/z3_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,12 @@
* SOFTWARE.
*/
#include "drivers/z3_driver.h"
#include <c++/z3++.h>

namespace expr {
struct z3_driver::impl {
z3::context c{};
z3::solver s;
const symbol_table_t& known;
const symbol_table_t& unknown;
auto as_symbol_value(const z3::expr& e) -> symbol_value_t;
auto as_z3_expression(const syntax_tree_t& tree) -> z3::expr;
auto as_z3_expression(const symbol_reference_t& ref) -> z3::expr;
auto as_z3_expression(const c_symbol_reference_t& ref) -> z3::expr;
auto as_z3_expression(const symbol_value_t& val) -> z3::expr;
void solve();
impl(const symbol_table_t& known, const symbol_table_t& unknown) : c{}, s{c}, known{known}, unknown{unknown} {}
};
z3_driver::z3_driver(const symbol_table_t& known_env, const symbol_table_t& unknown_env)
: driver{{known_env, unknown_env}}, c{}, s{c}, known{known_env}, unknown{unknown_env} {}

z3_driver::z3_driver(const symbol_table_t& known_env,
const symbol_table_t& unknown_env)
: pimpl{std::make_unique<z3_driver::impl>(known_env,unknown_env)}, driver{}
{}

z3_driver::~z3_driver() {}
z3_driver::~z3_driver() = default;

int z3_driver::parse(const std::string &f) {
if (f.empty())
Expand All @@ -67,16 +50,13 @@ namespace expr {
}

auto z3_driver::get_symbol(const std::string &identifier) -> syntax_tree_t {
if (!pimpl->known.contains(identifier)) {
if(!pimpl->unknown.contains(identifier))
throw std::out_of_range(identifier + " not found");
return syntax_tree_t{pimpl->unknown.find(identifier)};
}
return syntax_tree_t{pimpl->known.at(identifier)};
if(!contains(identifier))
throw std::out_of_range(identifier + " not found");
return syntax_tree_t{identifier_t{identifier}};
}

void z3_driver::add_tree(const syntax_tree_t& tree) {
pimpl->s.add(pimpl->as_z3_expression(tree)); // Note: Only accepts boolean expressions (will throw if not)
s.add(as_z3_expression(tree)); // Note: Only accepts boolean expressions (will throw if not)
solve();
}

Expand All @@ -85,22 +65,22 @@ namespace expr {
}

void z3_driver::solve() {
switch (pimpl->s.check()) {
case z3::unsat: pimpl->s.reset(); throw std::domain_error("unsat");
case z3::unknown: pimpl->s.reset(); throw std::logic_error("unknown");
switch (s.check()) {
case z3::unsat: s.reset(); throw std::domain_error("unsat");
case z3::unknown: s.reset(); throw std::logic_error("unknown");
case z3::sat:
auto m = pimpl->s.get_model();
pimpl->s.reset();
auto m = s.get_model();
s.reset();
for(int i = 0; i < m.size(); i++) {
auto xx = m[i];
auto interp = xx.is_const() ? m.get_const_interp(xx) : m.get_func_interp(xx).else_value();
result[xx.name().str()] = pimpl->as_symbol_value(interp);
result[xx.name().str()] = as_symbol_value(interp);
}
break;
}
}

auto z3_driver::impl::as_symbol_value(const z3::expr &e) -> symbol_value_t {
auto z3_driver::as_symbol_value(const z3::expr &e) -> symbol_value_t {
if(e.is_int())
return (int) e.as_int64();
if(e.is_real())
Expand All @@ -112,73 +92,55 @@ namespace expr {
throw std::logic_error("invalid z3::expr value - unable to convert to expr::symbol_value_t");
}

auto z3_driver::impl::as_z3_expression(const symbol_value_t& val) -> z3::expr {
z3::expr v = c.int_val(0); // placeholder value
std::visit(ya::overload(
[&v, this](const int& i) { v = c.int_val(i); },
[&v, this](const float& f) { v = c.real_val(std::to_string(f).c_str()); },
[&v, this](const bool& b) { v = c.bool_val(b); },
[&v, this](const std::string& sv) { v = c.string_val(sv); },
auto z3_driver::as_z3_expression(const symbol_value_t& val) -> z3::expr {
return std::visit(ya::overload(
[this](const int& i) { return c.int_val(i); },
[this](const float& f) { return c.real_val(std::to_string(f).c_str()); },
[this](const bool& b) { return c.bool_val(b); },
[this](const std::string& sv) { return c.string_val(sv); },
[](auto&& x){ throw std::logic_error("unable to convert symbol value to z3::expr"); }
), static_cast<const underlying_symbol_value_t&>(val));
return v;
}

auto z3_driver::impl::as_z3_expression(const symbol_reference_t &ref) -> z3::expr {
z3::expr v = c.int_val(0); // placeholder value
std::visit(ya::overload(
[&v, this, &ref](const int& _) { v = c.int_const(ref->first.c_str()); },
[&v, this, &ref](const float& _) { v = c.real_const(ref->first.c_str()); },
[&v, this, &ref](const bool& _) { v = c.bool_const(ref->first.c_str()); },
[&v, this, &ref](const std::string& _) { v = c.string_const(ref->first.c_str()); },
auto z3_driver::as_z3_expression(const identifier_t& ref) -> z3::expr {
auto it = find(ref.ident);
return std::visit(ya::overload(
[this, &ref](const int& _) { return c.int_const(ref.ident.c_str()); },
[this, &ref](const float& _) { return c.real_const(ref.ident.c_str()); },
[this, &ref](const bool& _) { return c.bool_const(ref.ident.c_str()); },
[this, &ref](const std::string& _) { return c.string_const(ref.ident.c_str()); },
[](auto&& x){ throw std::logic_error("unable to convert symbol reference to z3::expr"); }
), static_cast<const underlying_symbol_value_t&>(ref->second));
return v;
}

auto z3_driver::impl::as_z3_expression(const c_symbol_reference_t &ref) -> z3::expr {
z3::expr v = c.int_val(0); // placeholder value
std::visit(ya::overload(
[&v, this, &ref](const int& _) { v = c.int_const(ref->first.c_str()); },
[&v, this, &ref](const float& _) { v = c.real_const(ref->first.c_str()); },
[&v, this, &ref](const bool& _) { v = c.bool_const(ref->first.c_str()); },
[&v, this, &ref](const std::string& _) { v = c.string_const(ref->first.c_str()); },
[](auto&& x){ throw std::logic_error("unable to convert const symbol reference to z3::expr"); }
), static_cast<const underlying_symbol_value_t&>(ref->second));
return v;
), static_cast<const underlying_symbol_value_t&>(it->second));
}

auto z3_driver::impl::as_z3_expression(const syntax_tree_t &tree) -> z3::expr {
z3::expr v = c.int_val(0); // placeholder value
std::visit(ya::overload(
[&v,this](const symbol_reference_t& r) { v = as_z3_expression(r); },
[&v,this](const c_symbol_reference_t& r) { v = as_z3_expression(r); },
auto z3_driver::as_z3_expression(const syntax_tree_t &tree) -> z3::expr {
return std::visit(ya::overload(
[this](const identifier_t& r) { return as_z3_expression(r); },
[&](const operator_t& o) {
switch (o.operator_type) {
case operator_type_t::minus: v = as_z3_expression(tree.children[0]) - as_z3_expression(tree.children[1]); break;
case operator_type_t::plus: v = as_z3_expression(tree.children[0]) + as_z3_expression(tree.children[1]); break;
case operator_type_t::star: v = as_z3_expression(tree.children[0]) * as_z3_expression(tree.children[1]); break;
case operator_type_t::slash: v = as_z3_expression(tree.children[0]) / as_z3_expression(tree.children[1]); break;
case operator_type_t::percent: v = as_z3_expression(tree.children[0]) % as_z3_expression(tree.children[1]); break;
case operator_type_t::hat: v = z3::pw(as_z3_expression(tree.children[0]), as_z3_expression(tree.children[1])); break;
case operator_type_t::_and: v = as_z3_expression(tree.children[0]) && as_z3_expression(tree.children[1]); break;
case operator_type_t::_or: v = as_z3_expression(tree.children[0]) || as_z3_expression(tree.children[1]); break;
case operator_type_t::_xor: v = as_z3_expression(tree.children[0]) xor as_z3_expression(tree.children[1]); break;
case operator_type_t::_not: v =!as_z3_expression(tree.children[0]); break;
case operator_type_t::_implies: v = implies(as_z3_expression(tree.children[0]),as_z3_expression(tree.children[1])); break;
case operator_type_t::gt: v = (as_z3_expression(tree.children[0]) > as_z3_expression(tree.children[1])); break;
case operator_type_t::ge: v = (as_z3_expression(tree.children[0]) >= as_z3_expression(tree.children[1])); break;
case operator_type_t::ne: v = (as_z3_expression(tree.children[0]) != as_z3_expression(tree.children[1])); break;
case operator_type_t::ee: v = (as_z3_expression(tree.children[0]) == as_z3_expression(tree.children[1])); break;
case operator_type_t::le: v = (as_z3_expression(tree.children[0]) <= as_z3_expression(tree.children[1])); break;
case operator_type_t::lt: v = (as_z3_expression(tree.children[0]) < as_z3_expression(tree.children[1])); break;
case operator_type_t::parentheses: v = (as_z3_expression(tree.children[0])); break;
case operator_type_t::minus: return as_z3_expression(tree.children[0]) - as_z3_expression(tree.children[1]); break;
case operator_type_t::plus: return as_z3_expression(tree.children[0]) + as_z3_expression(tree.children[1]); break;
case operator_type_t::star: return as_z3_expression(tree.children[0]) * as_z3_expression(tree.children[1]); break;
case operator_type_t::slash: return as_z3_expression(tree.children[0]) / as_z3_expression(tree.children[1]); break;
case operator_type_t::percent: return as_z3_expression(tree.children[0]) % as_z3_expression(tree.children[1]); break;
case operator_type_t::hat: return z3::pw(as_z3_expression(tree.children[0]), as_z3_expression(tree.children[1])); break;
case operator_type_t::_and: return as_z3_expression(tree.children[0]) && as_z3_expression(tree.children[1]); break;
case operator_type_t::_or: return as_z3_expression(tree.children[0]) || as_z3_expression(tree.children[1]); break;
case operator_type_t::_xor: return as_z3_expression(tree.children[0]) xor as_z3_expression(tree.children[1]); break;
case operator_type_t::_not: return !as_z3_expression(tree.children[0]); break;
case operator_type_t::_implies: return implies(as_z3_expression(tree.children[0]),as_z3_expression(tree.children[1])); break;
case operator_type_t::gt: return (as_z3_expression(tree.children[0]) > as_z3_expression(tree.children[1])); break;
case operator_type_t::ge: return (as_z3_expression(tree.children[0]) >= as_z3_expression(tree.children[1])); break;
case operator_type_t::ne: return (as_z3_expression(tree.children[0]) != as_z3_expression(tree.children[1])); break;
case operator_type_t::ee: return (as_z3_expression(tree.children[0]) == as_z3_expression(tree.children[1])); break;
case operator_type_t::le: return (as_z3_expression(tree.children[0]) <= as_z3_expression(tree.children[1])); break;
case operator_type_t::lt: return (as_z3_expression(tree.children[0]) < as_z3_expression(tree.children[1])); break;
case operator_type_t::parentheses: return (as_z3_expression(tree.children[0])); break;
}
},
[&v,this](const symbol_value_t& o){ v = as_z3_expression(o); },
[&](const root_t& r){ v = as_z3_expression(tree.children[0]); },
[this](const symbol_value_t& o){ return as_z3_expression(o); },
[&](const root_t& r){ return as_z3_expression(tree.children[0]); },
[](auto&&){ throw std::logic_error("tree node type not recognized"); }
), static_cast<const underlying_syntax_node_t&>(tree.node));
return v;
}
}
Loading

0 comments on commit f2e6fbd

Please sign in to comment.