simple example of generating structures is working

This commit is contained in:
Ken McMillan 2016-09-28 15:41:13 -07:00
Родитель a34e05a208
Коммит ebf472ef70
2 изменённых файлов: 183 добавлений и 24 удалений

Просмотреть файл

@ -150,20 +150,31 @@ def declare_ctuple(header,dom):
for idx,sort in enumerate(dom):
sym = il.Symbol('arg{}'.format(idx),sort)
declare_symbol(header,sym)
header.append(t+'(){}')
header.append(t+'('+','.join('const '+ctypefull(d)+' &arg'+str(idx) for idx,d in enumerate(dom))
+ ') : '+','.join('arg'+str(idx)+'(arg'+str(idx)+')' for idx,d in enumerate(dom))
+ '{}\n')
header.append(" size_t __hash() const { return "+struct_hash_fun(['arg{}'.format(n) for n in range(len(dom))],dom) + ";}\n")
header.append('};\n')
def ctuple_hash(dom):
if len(dom) == 1:
return 'hash<'+ctypefull(dom[0])+'>'
else:
return 'hash__' + ctuple(dom)
def declare_ctuple_hash(header,dom,classname):
t = ctuple(dom)
the_type = classname+'::'+t
header.append("""
namespace hash_space {
template <>
class hash<the_type> {
class the_hash_type {
public:
size_t operator()(const the_type &s) const {
size_t operator()(const the_type &__s) const {
return the_val;
}
};
}
""".replace('the_type',classname+'::'+t).replace('the_val','+'.join('hash<{}>()(arg{})'.format(ctype(s),i,classname=classname) for i,s in enumerate(dom))))
""".replace('the_hash_type',ctuple_hash(dom)).replace('the_type',the_type).replace('the_val','+'.join('hash_space::hash<{}>()(__s.arg{})'.format(ctype(s),i,classname=classname) for i,s in enumerate(dom))))
def declare_hash_thunk(header):
header.append("""
@ -174,10 +185,10 @@ struct thunk {
return 0;
}
};
template <typename D, typename R>
template <typename D, typename R, class HashFun = hash_space::hash<D> >
struct hash_thunk {
thunk<D,R> *fun;
hash_space::hash_map<D,R> memo;
hash_space::hash_map<D,R,HashFun> memo;
hash_thunk() : fun(0) {}
hash_thunk(thunk<D,R> *fun) : fun(fun) {}
~hash_thunk() {
@ -185,7 +196,7 @@ struct hash_thunk {
// delete fun;
}
R &operator[](const D& arg){
std::pair<typename hash_space::hash_map<D,R>::iterator,bool> foo = memo.insert(std::pair<D,R>(arg,D()));
std::pair<typename hash_space::hash_map<D,R>::iterator,bool> foo = memo.insert(std::pair<D,R>(arg,R()));
R &res = foo.first->second;
if (foo.second)
res = (*fun)(arg);
@ -206,7 +217,7 @@ def all_ctuples():
res = tuple(sym.sort.dom)
if res not in done:
done.add(res)
yield sym
yield res
def declare_all_ctuples(header):
for dom in all_ctuples():
@ -249,14 +260,23 @@ native_expr_full = native_type_full
thunk_counter = 0
def expr_to_z3(expr):
fmla = '(assert ' + slv.formula_to_z3(expr).sexpr().replace('\n',' ') + ')'
return 'z3::expr(g.ctx,Z3_parse_smtlib2_string(ctx, "{}", sort_names.size(), &sort_names[0], &sorts[0], decl_names.size(), &decl_names[0], &decls[0]))'.format(fmla)
def make_thunk(impl,vs,expr):
global the_classname
dom = [v.sort for v in vs]
D = ctuple(dom)
D = the_classname+'::'+ctuple(dom)
R = ctypefull(expr.sort)
global thunk_counter
name = '__thunk__{}'.format(thunk_counter)
thunk_counter += 1
open_scope(impl,line='struct {} : thunk<{},{}>'.format(name,D,R))
thunk_class = 'z3_thunk' if target.get() in ["gen","test"] else 'thunk'
open_scope(impl,line='struct {} : {}<{},{}>'.format(name,thunk_class,D,R))
env = list(ilu.used_symbols_ast(expr))
for sym in env:
declare_symbol(impl,sym)
@ -270,9 +290,32 @@ def make_thunk(impl,vs,expr):
expr = ilu.substitute_ast(expr,subst)
code_line(impl,'return ' + code_eval(impl,expr))
close_scope(impl)
if target.get() in ["gen","test"]:
open_scope(impl,line = 'z3::expr to_z3(gen &g, const z3::expr &v)')
if lu.free_variables(expr):
raise iu.IvyError(None,"cannot compile {}".format(expr))
if all(s.is_numeral() for s in ilu.used_symbols_ast(expr)):
code_line(impl,'z3::expr res = v == g.int_to_z3(g.sort("{}"),(int)({}))'.format(expr.sort.name,code_eval(impl,expr)))
else:
raise iu.IvyError(None,"cannot compile {}".format(expr))
code_line(impl,'return res')
close_scope(impl)
close_scope(impl,semi=True)
return 'hash_thunk<{},{}>(new {}({}))'.format(D,R,name,','.join(envnames))
def struct_hash_fun(field_names,field_sorts):
return '+'.join('hash_space::hash<{}>()({})'.format(ctype(s),varname(f)) for s,f in zip(field_sorts,field_names))
def emit_struct_hash(header,the_type,field_names,field_sorts):
header.append("""
template<> class hash<the_type> {
public:
size_t operator()(const the_type &__s) const {
return the_val;
}
};
""".replace('the_type',the_type).replace('the_val',struct_hash_fun(['__s.'+n for n in field_names],field_sorts)))
def emit_cpp_sorts(header):
for name in im.module.sort_order:
if name in im.module.native_types:
@ -280,9 +323,12 @@ def emit_cpp_sorts(header):
header.append(" typedef " + nt + ' ' + varname(name) + ";\n");
elif name in im.module.sort_destructors:
header.append(" struct " + varname(name) + " {\n");
for destr in im.module.sort_destructors[name]:
destrs = im.module.sort_destructors[name]
for destr in destrs:
declare_symbol(header,destr,skip_params=1)
header.append(" size_t __hash() const { return "+struct_hash_fun(map(varname,destrs),[d.sort.rng for d in destrs]) + ";}\n")
header.append(" };\n");
def emit_sorts(header):
for name,sort in il.sig.sorts.iteritems():
@ -400,6 +446,11 @@ def emit_set(header,symbol):
emit_set_field(header,destr,lhs,rhs,len(vs))
close_loop(header,vs)
return
if is_large_type(sort):
vs = variables(sort.dom)
cvars = ','.join('ctx.constant("{}",sort("{}"))'.format(varname(v),v.sort.name) for v in vs)
code_line(header,'slvr.add(forall({},__to_solver(*this,apply("{}",{}),obj.{})))'.format(cvars,sname,cvars,cname))
return
for idx,dsort in enumerate(domain):
dcard = sort_card(dsort)
indent(header)
@ -480,11 +531,12 @@ public:
if sym in used:
emit_randomize(impl,sym,classname=classname)
else:
if not is_native_sym(sym):
if not is_native_sym(sym) and not is_large_type(sym.sort):
fun = lambda v: (mk_rand(v.sort) if not is_native_sym(v) else None)
assign_array_from_model(impl,sym,'obj.',fun)
indent_level -= 1
impl.append("""
// std::cout << slvr << std::endl;
bool res = solve();
if (res) {
""")
@ -575,6 +627,7 @@ def emit_action_gen(header,impl,name,action,classname):
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx:
emit_randomize(impl,sym,classname=classname)
impl.append("""
// std::cout << slvr << std::endl;
bool res = solve();
if (res) {
""")
@ -885,7 +938,38 @@ def check_member_names(classname):
raise iu.IvyError(None,'Cannot create C++ class {} with member {}.\nUse command line option classname=... to change the class name.'
.format(classname,classname))
def emit_ctuple_to_solver(header,dom,classname):
ct_name = classname + '::' + ctuple(dom)
ch_name = classname + '::' + ctuple_hash(dom)
open_scope(header,line='template<typename R> class to_solver_class<hash_thunk<D,R> >'.replace('D',ct_name).replace('H',ch_name))
code_line(header,'public:')
open_scope(header,line='z3::expr operator()( gen &g, const z3::expr &v, hash_thunk<D,R> &val)'.replace('D',ct_name).replace('H',ch_name))
code_line(header,'z3::expr res = g.ctx.bool_val(true)')
code_line(header,'z3::expr disj = g.ctx.bool_val(false)')
open_scope(header,line='for(typename hash_map<D,R>::iterator it=val.memo.begin(), en = val.memo.end(); it != en; it++)'.replace('D',ct_name).replace('H',ch_name))
code_line(header,'z3::expr cond = '+' && '.join('__to_solver(g,v.arg('+str(n)+'),it->first.arg'+str(n)+')' for n in range(len(dom))))
code_line(header,'res = res && implies(cond,__to_solver(g,v,it->second))')
code_line(header,'disj = disj || cond')
close_scope(header)
code_line(header,'res = res && (disj || dynamic_cast<z3_thunk<D,R> *>(val.fun)->to_z3(g,v))'.replace('D',ct_name))
code_line(header,'return res')
close_scope(header)
close_scope(header,semi=True)
def emit_all_ctuples_to_solver(header,classname):
for dom in all_ctuples():
emit_ctuple_to_solver(header,dom,classname)
def emit_ctuple_equality(header,dom,classname):
t = ctuple(dom)
open_scope(header,line = 'bool operator==(const {}::{} &x, const {}::{} &y)'.format(classname,t,classname,t))
code_line(header,'return '+' && '.join('x.arg{} == y.arg{}'.format(n,n) for n in range(len(dom))))
close_scope(header)
def module_to_cpp_class(classname,basename):
global the_classname
the_classname = classname
check_member_names(classname)
global is_derived
is_derived = set()
@ -951,8 +1035,6 @@ def module_to_cpp_class(classname,basename):
""")
impl.append("typedef {} ivy_class;\n".format(classname))
declare_all_ctuples(header)
native_exprs = []
for n in im.module.natives:
native_exprs.extend(n.args[2:])
@ -1054,6 +1136,25 @@ void __from_solver<bool>( gen &g, const z3::expr &v, bool &res) {
res = g.eval(v);
}
template <class T>
class to_solver_class {
};
template <class T> z3::expr __to_solver( gen &g, const z3::expr &v, T &val) {
return to_solver_class<T>()(g,v,val);
}
template <>
z3::expr __to_solver<int>( gen &g, const z3::expr &v, int &val) {
return v == g.int_to_z3(v.get_sort(),val);
}
template <>
z3::expr __to_solver<bool>( gen &g, const z3::expr &v, bool &val) {
return v == g.int_to_z3(v.get_sort(),val);
}
template <class T> void __randomize( gen &g, const z3::expr &v);
template <>
@ -1065,6 +1166,13 @@ template <>
void __randomize<bool>( gen &g, const z3::expr &v) {
g.randomize(v);
}
template<typename D, typename R>
class z3_thunk : public thunk<D,R> {
public:
virtual z3::expr to_z3(gen &g, const z3::expr &v) = 0;
};
""")
if True or target.get() == "repl":
@ -1083,8 +1191,13 @@ void __randomize<bool>( gen &g, const z3::expr &v) {
impl.append('template <>\n')
impl.append('void __from_solver<' + cfsname + '>( gen &g, const z3::expr &v, ' + cfsname + ' &res);\n')
impl.append('template <>\n')
impl.append('z3::expr __to_solver<' + cfsname + '>( gen &g, const z3::expr &v, ' + cfsname + ' &val);\n')
impl.append('template <>\n')
impl.append('void __randomize<' + cfsname + '>( gen &g, const z3::expr &v);\n')
for dom in all_ctuples():
emit_ctuple_equality(impl,dom,classname)
once_memo = set()
for native in im.module.natives:
tag = native_type(native)
@ -1117,6 +1230,8 @@ void __randomize<bool>( gen &g, const z3::expr &v) {
""")
emit_cpp_sorts(header)
declare_all_ctuples(header)
declare_all_ctuples_hash(header,classname)
for sym in all_state_symbols():
if sym_is_member(sym):
declare_symbol(header,sym)
@ -1149,8 +1264,6 @@ void __randomize<bool>( gen &g, const z3::expr &v) {
emit_tick(header,impl,classname)
header.append('};\n')
declare_all_ctuples_hash(header,classname)
impl.append(classname + '::')
emit_param_decls(impl,classname,im.module.params)
impl.append('{\n')
@ -1295,6 +1408,19 @@ void __randomize<bool>( gen &g, const z3::expr &v) {
close_loop(impl,[v])
close_scope(impl)
impl.append('template <>\n')
open_scope(impl,line='z3::expr __to_solver<' + cfsname + '>( gen &g, const z3::expr &v,' + cfsname + ' &val)')
code_line(impl,'z3::expr res = g.ctx.bool_val(1)')
for idx,sym in enumerate(destrs):
fname = memname(sym)
vs = variables(sym.sort.dom[1:])
for v in vs:
open_loop(impl,[v])
code_line(impl,'res = res && __to_solver(g,g.apply("'+sym.name+'",v'+ ''.join(',g.int_to_z3(g.sort("'+v.sort.name+'"),'+varname(v)+')' for v in vs)+'),val.'+fname+''.join('[{}]'.format(varname(v)) for v in vs) + ')')
for v in vs:
close_loop(impl,[v])
code_line(impl,'return res')
close_scope(impl)
impl.append('template <>\n')
open_scope(impl,line='void __randomize<' + cfsname + '>( gen &g, const z3::expr &v)')
for idx,sym in enumerate(destrs):
fname = memname(sym)
@ -1306,6 +1432,7 @@ void __randomize<bool>( gen &g, const z3::expr &v) {
close_loop(impl,[v])
close_scope(impl)
emit_all_ctuples_to_solver(impl,classname)
emit_repl_boilerplate1a(header,impl,classname)
@ -1570,6 +1697,15 @@ def emit_app(self,header,code):
a.emit(header,code)
first = False
code.append(')')
elif is_large_type(self.rep.sort) and len(self.args[skip_params:]) > 1:
code.append('[' + ctuple(self.rep.sort.dom[skip_params:]) + '(')
first = True
for a in self.args[skip_params:]:
if not first:
code.append(',')
a.emit(header,code)
first = False
code.append(')]')
else:
for a in self.args[skip_params:]:
code.append('[')
@ -1814,7 +1950,7 @@ def emit_assign_large(self,header):
def emit_assign(self,header):
global indent_level
with ivy_ast.ASTContext(self):
if is_large_type(self.args[0].rep.sort):
if is_large_type(self.args[0].rep.sort) and lu.free_variables(self.args[0]):
emit_assign_large(self,header)
return
vs = list(lu.free_variables(self.args[0]))
@ -2416,8 +2552,9 @@ using namespace hash_space;
class gen : public ivy_gen {
protected:
public:
z3::context ctx;
protected:
z3::solver slvr;
z3::model model;
@ -2653,6 +2790,12 @@ public:
slvr.pop();
}
z3::sort sort(const char *name) {
if (std::string("bool") == name)
return ctx.bool_sort();
return enum_sorts.find(name)->second;
}
void mk_enum(const char *sort_name, unsigned num_values, char const * const * value_names) {
z3::func_decl_vector cs(ctx), ts(ctx);
z3::sort sort = ctx.enumeration_sort(sort_name, num_values, value_names, cs, ts);
@ -2809,7 +2952,12 @@ namespace hash_space {
unsigned string_hash(const char * str, unsigned length, unsigned init_value);
template <typename T> class hash {};
template <typename T> class hash {
public:
size_t operator()(const T &s) const {
return s.__hash();
}
};
template <>
class hash<int> {

Просмотреть файл

@ -9,10 +9,21 @@ type s = struct {
y : t
}
individual foo : s
relation sent(X:s,Y:t)
action bar(q:s) = {
assume y(q) = x(foo)
after init {
sent(X,Y) := false
}
individual foo : s
action baz(q:s,r:t) = {
sent(q,r) := true
}
action bar(q:s,r:t) = {
assume sent(q,r)
}
export baz
export bar