diff --git a/ivy/ivy_to_cpp.py b/ivy/ivy_to_cpp.py index 3c34ab0..f0626b4 100755 --- a/ivy/ivy_to_cpp.py +++ b/ivy/ivy_to_cpp.py @@ -74,20 +74,26 @@ def indent_code(header,code): for line in code.split('\n'): header.append((indent_level * 4 + get_indent(line) - indent) * ' ' + line.strip() + '\n') -def sym_decl(sym,c_type = None,skip_params=0,classname=None): +def sym_decl(sym,c_type = None,skip_params=0,classname=None,isref=False,ival=None): name, sort = sym.name,sym.sort dims = [] the_c_type,dims = ctype_function(sort,skip_params=skip_params,classname=classname) res = (c_type or the_c_type) + ' ' + if isref: + res += '(&' res += memname(sym) if skip_params else varname(sym.name) + if isref: + res += ')' for d in dims: res += '[' + str(d) + ']' + if ival is not None: + res += ' = '+ival; return res -def declare_symbol(header,sym,c_type = None,skip_params=0,classname=None): +def declare_symbol(header,sym,c_type = None,skip_params=0,classname=None,isref=False,ival=None): if slv.solver_name(sym) == None: return # skip interpreted symbols - header.append(' '+sym_decl(sym,c_type,skip_params,classname=classname)+';\n') + header.append(' '+sym_decl(sym,c_type,skip_params,classname=classname,isref=isref,ival=ival)+';\n') special_names = { '<' : '__lt', @@ -396,6 +402,7 @@ def gather_referenced_symbols(expr,res,ignore=[]): ldf = is_derived[sym] gather_referenced_symbols(ldf.formula.args[1],res,ldf.formula.args[0].args) +skip_z3 = False def make_thunk(impl,vs,expr): global the_classname @@ -429,7 +436,7 @@ def make_thunk(impl,vs,expr): close_scope(impl) if target.get() in ["gen","test"]: open_scope(impl,line = 'z3::expr to_z3(gen &g, const z3::expr &v)') - if False and isinstance(expr,HavocSymbol): + if False and isinstance(expr,HavocSymbol) or skip_z3: code_line(impl,'return g.ctx.bool_val(true)') else: if lu.free_variables(expr): @@ -802,6 +809,74 @@ def fix_definition(df): subst = dict((s,il.Variable('X__{}'.format(idx),s.sort)) for idx,s in enumerate(df.args[0].args) if not il.is_variable(s)) return ilu.substitute_constants_ast(df,subst) +# An action input x is 'defined' by the action's precondition if the precondition is +# of the form P & x = expr, where x does not occur in P. Here, we remove the defined inputs from +# the precondition, to improve the performance of the solver. We also return a list of +# the definitions, so the values of the defined inputs can be computed. + +def extract_defined_parameters(pre_clauses,inputs): + change = True + inputset = set(inputs) + defmap = {} + iu.dbg('inputs') + for fmla in pre_clauses.fmlas: + iu.dbg('fmla') + if il.is_eq(fmla) and fmla.args[0] in inputset: + defmap[fmla.args[0]] = fmla + iu.dbg('defmap') + inpdefs = [] + *** make sure var does not occur in defs *** + while change: + change = False + for input,fmla in list(defmap.iteritems()): + if all(input not in ilu.used_variables_ast(f) or f == fmla for f in pre_clauses.fmlas): + pre_clauses = ilu.Clauses([f for f in pre_clauses.fmlas if f != fmla],pre_clauses.defs) + del defmap[input] + inpdefs.append(fmla) + change = True + **** trim the defs *** + inpdefs.reverse() + pre_clauses = ilu.trim_clauses(pre_clauses) + return pre_clauses,inpdefs + +def collect_used_definitions(pre,inpdefs,ssyms): + defmap = dict((d.defines(),d) for d in pre.defs) + used = set() + res = [] + usyms = [] + def recur(d): + for sym in ilu.used_symbols_ast(d.args[1]): + if sym not in used: + used.add(sym) + if sym in defmap: + d = defmap[sym] + recur(d) + res.append(d) + elif sym in ssyms: + usyms.append(sym) + for inpdef in inpdefs: + recur(inpdef) + return res,usyms + + +def emit_defined_inputs(pre,inpdefs,code,classname,ssyms): + udefs,usyms = collect_used_definitions(pre,inpdefs,ssyms) + iu.dbg('usyms') + for sym in usyms: + declare_symbol(code,sym,classname=classname,isref=True,ival='obj.'+code_eval(code,sym)) + iu.dbg('udefs') + global skip_z3 + skip_z3 = True + for dfn in udefs: + sym = dfn.defines() + declare_symbol(code,sym,classname=classname) + emit_assign(dfn,code) + skip_z3 = False + for param_def in inpdefs: + code_line(code,code_eval(code,param_def.args[0]) + ' = ' + code_eval(code,param_def.args[1])) + + + def emit_action_gen(header,impl,name,action,classname): global indent_level global global_classname @@ -829,22 +904,31 @@ def emit_action_gen(header,impl,name,action,classname): with ia.UnrollContext(card): upd = action.update(im.module,None) pre = tr.reverse_image(ilu.true_clauses(),ilu.true_clauses(),upd) + orig_pre = pre pre_clauses = ilu.trim_clauses(pre) + inputs = [x for x in ilu.used_symbols_clauses(pre_clauses) if is_local_sym(x) and not x.is_numeral()] + pre_clauses, param_defs = extract_defined_parameters(pre_clauses,inputs) + iu.dbg('pre_clauses') + iu.dbg('param_defs') rdefs = im.relevant_definitions(ilu.symbols_clauses(pre_clauses)) pre_clauses = ilu.and_clauses(pre_clauses,ilu.Clauses([fix_definition(ldf.formula).to_constraint() for ldf in rdefs])) pre_clauses = ilu.and_clauses(pre_clauses,ilu.Clauses(im.module.variant_axioms())) pre = pre_clauses.to_formula() used = set(ilu.used_symbols_ast(pre)) + iu.dbg('used') used_names = set(varname(s) for s in used) + defed_params = set(f.args[0] for f in param_defs) for p in action.formal_params: - if varname(p) not in used_names: + p = p.prefix('__') + if varname(p) not in used_names and p not in defed_params: used.add(p) for x in used: if x.is_numeral() and il.is_uninterpreted_sort(x.sort): raise iu.IvyError(None,'Cannot compile numeral {} of uninterpreted sort {}'.format(x,x.sort)) syms = [x for x in used if is_local_sym(x) and not x.is_numeral()] + iu.dbg('syms') header.append("class " + caname + "_gen : public gen {\n public:\n") - for sym in syms: + for sym in syms + list(defed_params): if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>': declare_symbol(header,sym,classname=classname) header.append(" {}_gen();\n".format(caname)) @@ -884,6 +968,11 @@ def emit_action_gen(header,impl,name,action,classname): for sym in syms: if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>': emit_eval(impl,sym,classname=classname) + ssyms = set() + for sym in all_state_symbols(): + if sym_is_member(sym): + ssyms.add(sym) + emit_defined_inputs(orig_pre,param_defs,impl,classname,ssyms) indent_level -= 2 impl.append(""" }""") @@ -1201,7 +1290,7 @@ def open_loop(impl,vs,declare=True,bounds=None): ct = ctype(idx.sort) impl.append('for ('+ ((ct + ' ') if declare else '') + vn + ' = (' + ct + ')' + bds[0] + '; (int) ' + vn + ' < ' + bds[1] + '; ' + vn + ' = (' + ct + ')(((int)' + vn + ') + 1)) {\n') else: - impl.append('for ('+ (ct if declare else '') + vn + ' = ' + bds[0] + '; ' + vn + ' < ' + bds[1] + '; ' + vn + '++) {\n') + impl.append('for ('+ ((ct + ' ') if declare else '') + vn + ' = ' + bds[0] + '; ' + vn + ' < ' + bds[1] + '; ' + vn + '++) {\n') indent_level += 1 def close_loop(impl,vs):