зеркало из https://github.com/microsoft/ivy.git
in the middle of adding dependent inputs in test
This commit is contained in:
Родитель
e29c33891e
Коммит
8e2960f57b
|
@ -74,20 +74,26 @@ def indent_code(header,code):
|
|||
for line in code.split('\n'):
|
||||
header.append((indent_level * 4 + get_indent(line) - indent) * ' ' + line.strip() + '\n')
|
||||
|
||||
def sym_decl(sym,c_type = None,skip_params=0,classname=None):
|
||||
def sym_decl(sym,c_type = None,skip_params=0,classname=None,isref=False,ival=None):
|
||||
name, sort = sym.name,sym.sort
|
||||
dims = []
|
||||
the_c_type,dims = ctype_function(sort,skip_params=skip_params,classname=classname)
|
||||
res = (c_type or the_c_type) + ' '
|
||||
if isref:
|
||||
res += '(&'
|
||||
res += memname(sym) if skip_params else varname(sym.name)
|
||||
if isref:
|
||||
res += ')'
|
||||
for d in dims:
|
||||
res += '[' + str(d) + ']'
|
||||
if ival is not None:
|
||||
res += ' = '+ival;
|
||||
return res
|
||||
|
||||
def declare_symbol(header,sym,c_type = None,skip_params=0,classname=None):
|
||||
def declare_symbol(header,sym,c_type = None,skip_params=0,classname=None,isref=False,ival=None):
|
||||
if slv.solver_name(sym) == None:
|
||||
return # skip interpreted symbols
|
||||
header.append(' '+sym_decl(sym,c_type,skip_params,classname=classname)+';\n')
|
||||
header.append(' '+sym_decl(sym,c_type,skip_params,classname=classname,isref=isref,ival=ival)+';\n')
|
||||
|
||||
special_names = {
|
||||
'<' : '__lt',
|
||||
|
@ -396,6 +402,7 @@ def gather_referenced_symbols(expr,res,ignore=[]):
|
|||
ldf = is_derived[sym]
|
||||
gather_referenced_symbols(ldf.formula.args[1],res,ldf.formula.args[0].args)
|
||||
|
||||
skip_z3 = False
|
||||
|
||||
def make_thunk(impl,vs,expr):
|
||||
global the_classname
|
||||
|
@ -429,7 +436,7 @@ def make_thunk(impl,vs,expr):
|
|||
close_scope(impl)
|
||||
if target.get() in ["gen","test"]:
|
||||
open_scope(impl,line = 'z3::expr to_z3(gen &g, const z3::expr &v)')
|
||||
if False and isinstance(expr,HavocSymbol):
|
||||
if False and isinstance(expr,HavocSymbol) or skip_z3:
|
||||
code_line(impl,'return g.ctx.bool_val(true)')
|
||||
else:
|
||||
if lu.free_variables(expr):
|
||||
|
@ -802,6 +809,74 @@ def fix_definition(df):
|
|||
subst = dict((s,il.Variable('X__{}'.format(idx),s.sort)) for idx,s in enumerate(df.args[0].args) if not il.is_variable(s))
|
||||
return ilu.substitute_constants_ast(df,subst)
|
||||
|
||||
# An action input x is 'defined' by the action's precondition if the precondition is
|
||||
# of the form P & x = expr, where x does not occur in P. Here, we remove the defined inputs from
|
||||
# the precondition, to improve the performance of the solver. We also return a list of
|
||||
# the definitions, so the values of the defined inputs can be computed.
|
||||
|
||||
def extract_defined_parameters(pre_clauses,inputs):
|
||||
change = True
|
||||
inputset = set(inputs)
|
||||
defmap = {}
|
||||
iu.dbg('inputs')
|
||||
for fmla in pre_clauses.fmlas:
|
||||
iu.dbg('fmla')
|
||||
if il.is_eq(fmla) and fmla.args[0] in inputset:
|
||||
defmap[fmla.args[0]] = fmla
|
||||
iu.dbg('defmap')
|
||||
inpdefs = []
|
||||
*** make sure var does not occur in defs ***
|
||||
while change:
|
||||
change = False
|
||||
for input,fmla in list(defmap.iteritems()):
|
||||
if all(input not in ilu.used_variables_ast(f) or f == fmla for f in pre_clauses.fmlas):
|
||||
pre_clauses = ilu.Clauses([f for f in pre_clauses.fmlas if f != fmla],pre_clauses.defs)
|
||||
del defmap[input]
|
||||
inpdefs.append(fmla)
|
||||
change = True
|
||||
**** trim the defs ***
|
||||
inpdefs.reverse()
|
||||
pre_clauses = ilu.trim_clauses(pre_clauses)
|
||||
return pre_clauses,inpdefs
|
||||
|
||||
def collect_used_definitions(pre,inpdefs,ssyms):
|
||||
defmap = dict((d.defines(),d) for d in pre.defs)
|
||||
used = set()
|
||||
res = []
|
||||
usyms = []
|
||||
def recur(d):
|
||||
for sym in ilu.used_symbols_ast(d.args[1]):
|
||||
if sym not in used:
|
||||
used.add(sym)
|
||||
if sym in defmap:
|
||||
d = defmap[sym]
|
||||
recur(d)
|
||||
res.append(d)
|
||||
elif sym in ssyms:
|
||||
usyms.append(sym)
|
||||
for inpdef in inpdefs:
|
||||
recur(inpdef)
|
||||
return res,usyms
|
||||
|
||||
|
||||
def emit_defined_inputs(pre,inpdefs,code,classname,ssyms):
|
||||
udefs,usyms = collect_used_definitions(pre,inpdefs,ssyms)
|
||||
iu.dbg('usyms')
|
||||
for sym in usyms:
|
||||
declare_symbol(code,sym,classname=classname,isref=True,ival='obj.'+code_eval(code,sym))
|
||||
iu.dbg('udefs')
|
||||
global skip_z3
|
||||
skip_z3 = True
|
||||
for dfn in udefs:
|
||||
sym = dfn.defines()
|
||||
declare_symbol(code,sym,classname=classname)
|
||||
emit_assign(dfn,code)
|
||||
skip_z3 = False
|
||||
for param_def in inpdefs:
|
||||
code_line(code,code_eval(code,param_def.args[0]) + ' = ' + code_eval(code,param_def.args[1]))
|
||||
|
||||
|
||||
|
||||
def emit_action_gen(header,impl,name,action,classname):
|
||||
global indent_level
|
||||
global global_classname
|
||||
|
@ -829,22 +904,31 @@ def emit_action_gen(header,impl,name,action,classname):
|
|||
with ia.UnrollContext(card):
|
||||
upd = action.update(im.module,None)
|
||||
pre = tr.reverse_image(ilu.true_clauses(),ilu.true_clauses(),upd)
|
||||
orig_pre = pre
|
||||
pre_clauses = ilu.trim_clauses(pre)
|
||||
inputs = [x for x in ilu.used_symbols_clauses(pre_clauses) if is_local_sym(x) and not x.is_numeral()]
|
||||
pre_clauses, param_defs = extract_defined_parameters(pre_clauses,inputs)
|
||||
iu.dbg('pre_clauses')
|
||||
iu.dbg('param_defs')
|
||||
rdefs = im.relevant_definitions(ilu.symbols_clauses(pre_clauses))
|
||||
pre_clauses = ilu.and_clauses(pre_clauses,ilu.Clauses([fix_definition(ldf.formula).to_constraint() for ldf in rdefs]))
|
||||
pre_clauses = ilu.and_clauses(pre_clauses,ilu.Clauses(im.module.variant_axioms()))
|
||||
pre = pre_clauses.to_formula()
|
||||
used = set(ilu.used_symbols_ast(pre))
|
||||
iu.dbg('used')
|
||||
used_names = set(varname(s) for s in used)
|
||||
defed_params = set(f.args[0] for f in param_defs)
|
||||
for p in action.formal_params:
|
||||
if varname(p) not in used_names:
|
||||
p = p.prefix('__')
|
||||
if varname(p) not in used_names and p not in defed_params:
|
||||
used.add(p)
|
||||
for x in used:
|
||||
if x.is_numeral() and il.is_uninterpreted_sort(x.sort):
|
||||
raise iu.IvyError(None,'Cannot compile numeral {} of uninterpreted sort {}'.format(x,x.sort))
|
||||
syms = [x for x in used if is_local_sym(x) and not x.is_numeral()]
|
||||
iu.dbg('syms')
|
||||
header.append("class " + caname + "_gen : public gen {\n public:\n")
|
||||
for sym in syms:
|
||||
for sym in syms + list(defed_params):
|
||||
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
|
||||
declare_symbol(header,sym,classname=classname)
|
||||
header.append(" {}_gen();\n".format(caname))
|
||||
|
@ -884,6 +968,11 @@ def emit_action_gen(header,impl,name,action,classname):
|
|||
for sym in syms:
|
||||
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
|
||||
emit_eval(impl,sym,classname=classname)
|
||||
ssyms = set()
|
||||
for sym in all_state_symbols():
|
||||
if sym_is_member(sym):
|
||||
ssyms.add(sym)
|
||||
emit_defined_inputs(orig_pre,param_defs,impl,classname,ssyms)
|
||||
indent_level -= 2
|
||||
impl.append("""
|
||||
}""")
|
||||
|
@ -1201,7 +1290,7 @@ def open_loop(impl,vs,declare=True,bounds=None):
|
|||
ct = ctype(idx.sort)
|
||||
impl.append('for ('+ ((ct + ' ') if declare else '') + vn + ' = (' + ct + ')' + bds[0] + '; (int) ' + vn + ' < ' + bds[1] + '; ' + vn + ' = (' + ct + ')(((int)' + vn + ') + 1)) {\n')
|
||||
else:
|
||||
impl.append('for ('+ (ct if declare else '') + vn + ' = ' + bds[0] + '; ' + vn + ' < ' + bds[1] + '; ' + vn + '++) {\n')
|
||||
impl.append('for ('+ ((ct + ' ') if declare else '') + vn + ' = ' + bds[0] + '; ' + vn + ' < ' + bds[1] + '; ' + vn + '++) {\n')
|
||||
indent_level += 1
|
||||
|
||||
def close_loop(impl,vs):
|
||||
|
|
Загрузка…
Ссылка в новой задаче