зеркало из https://github.com/microsoft/ivy.git
working on dependent input fields
This commit is contained in:
Родитель
8e2960f57b
Коммит
2676736c0a
|
@ -580,11 +580,11 @@ def int_to_z3(sort,val):
|
|||
raise iu.IvyError(None,"cannot produce test generator because sort {} is uninterpreted".format(sort))
|
||||
return 'int_to_z3(sort("'+sort.name+'"),'+val+')'
|
||||
|
||||
def emit_eval(header,symbol,obj=None,classname=None):
|
||||
def emit_eval(header,symbol,obj=None,classname=None,lhs=None):
|
||||
global indent_level
|
||||
name = symbol.name
|
||||
sname = slv.solver_name(symbol)
|
||||
cname = varname(name)
|
||||
cname = varname(name) if lhs is None else code_eval(header,lhs)
|
||||
sort = symbol.sort
|
||||
domain = sort_domain(sort)
|
||||
for idx,dsort in enumerate(domain):
|
||||
|
@ -594,7 +594,7 @@ def emit_eval(header,symbol,obj=None,classname=None):
|
|||
indent_level += 1
|
||||
indent(header)
|
||||
if sort.rng.name in im.module.sort_destructors or sort.rng.name in im.module.native_types or sort.rng in sort_to_cpptype:
|
||||
code_line(header,'__from_solver<'+classname+'::'+varname(sort.rng.name)+'>(*this,apply("'+sname+'"'+''.join(','+int_to_z3(s,'X{}'.format(idx)) for idx,s in enumerate(domain))+'),'+varname(symbol)+''.join('[X{}]'.format(idx) for idx in range(len(domain)))+')')
|
||||
code_line(header,'__from_solver<'+classname+'::'+varname(sort.rng.name)+'>(*this,apply("'+sname+'"'+''.join(','+int_to_z3(s,'X{}'.format(idx)) for idx,s in enumerate(domain))+'),'+cname+''.join('[X{}]'.format(idx) for idx in range(len(domain)))+')')
|
||||
else:
|
||||
header.append((obj + '.' if obj else '')
|
||||
+ cname + ''.join("[X{}]".format(idx) for idx in range(len(domain)))
|
||||
|
@ -825,18 +825,17 @@ def extract_defined_parameters(pre_clauses,inputs):
|
|||
defmap[fmla.args[0]] = fmla
|
||||
iu.dbg('defmap')
|
||||
inpdefs = []
|
||||
*** make sure var does not occur in defs ***
|
||||
while change:
|
||||
change = False
|
||||
for input,fmla in list(defmap.iteritems()):
|
||||
if all(input not in ilu.used_variables_ast(f) or f == fmla for f in pre_clauses.fmlas):
|
||||
if (all(input not in ilu.used_symbols_ast(f) or f == fmla for f in pre_clauses.fmlas)
|
||||
and all(input not in ilu.used_symbols_ast(d) for d in pre_clauses.defs)):
|
||||
pre_clauses = ilu.Clauses([f for f in pre_clauses.fmlas if f != fmla],pre_clauses.defs)
|
||||
del defmap[input]
|
||||
inpdefs.append(fmla)
|
||||
change = True
|
||||
**** trim the defs ***
|
||||
pre_clauses = ilu.trim_clauses(pre_clauses)
|
||||
inpdefs.reverse()
|
||||
pre_clauses = ilu.trim_clauses(pre_clauses)
|
||||
return pre_clauses,inpdefs
|
||||
|
||||
def collect_used_definitions(pre,inpdefs,ssyms):
|
||||
|
@ -859,7 +858,7 @@ def collect_used_definitions(pre,inpdefs,ssyms):
|
|||
return res,usyms
|
||||
|
||||
|
||||
def emit_defined_inputs(pre,inpdefs,code,classname,ssyms):
|
||||
def emit_defined_inputs(pre,inpdefs,code,classname,ssyms,fsyms):
|
||||
udefs,usyms = collect_used_definitions(pre,inpdefs,ssyms)
|
||||
iu.dbg('usyms')
|
||||
for sym in usyms:
|
||||
|
@ -873,9 +872,77 @@ def emit_defined_inputs(pre,inpdefs,code,classname,ssyms):
|
|||
emit_assign(dfn,code)
|
||||
skip_z3 = False
|
||||
for param_def in inpdefs:
|
||||
code_line(code,code_eval(code,param_def.args[0]) + ' = ' + code_eval(code,param_def.args[1]))
|
||||
lhs = param_def.args[0]
|
||||
lhs = fsyms.get(lhs,lhs)
|
||||
code_line(code,code_eval(code,lhs) + ' = ' + code_eval(code,param_def.args[1]))
|
||||
|
||||
def minimal_field_references(fmla,inputs):
|
||||
inpset = set(inputs)
|
||||
res = defaultdict(set)
|
||||
def field_ref(f):
|
||||
if il.is_app(f):
|
||||
if f.rep.name in im.module.destructor_sorts and len(f.args) == 1:
|
||||
return field_ref(f.args[0])
|
||||
if f.rep in inpset:
|
||||
return f.rep
|
||||
return None
|
||||
|
||||
def recur(f):
|
||||
if il.is_app(f):
|
||||
if f.rep.name in im.module.destructor_sorts and len(f.args) == 1:
|
||||
inp = field_ref(f.args[0])
|
||||
if inp is not None:
|
||||
res[inp].add(f)
|
||||
return
|
||||
if f.rep in inpset:
|
||||
res[f.rep].add(f.rep)
|
||||
return
|
||||
for x in f.args:
|
||||
recur(x)
|
||||
|
||||
def get_minima(refs):
|
||||
def lt(x,y):
|
||||
return len(y.args) == 1 and (x == y.args[0] or lt(x,y.args[0]))
|
||||
return set(y for y in refs if all(not(lt(x,y)) for x in refs))
|
||||
|
||||
recur(fmla)
|
||||
res = dict((inp,get_minima(refs)) for inp,refs in res.iteritems())
|
||||
iu.dbg('dict((str(inp),list(str(ref) for ref in refs)) for inp,refs in res.iteritems())')
|
||||
return res
|
||||
|
||||
def minimal_field_siblings(inputs,mrefs):
|
||||
res = defaultdict(set)
|
||||
for inp in inputs:
|
||||
if inp in mrefs:
|
||||
for f in mrefs[inp]:
|
||||
if len(f.args) == 1:
|
||||
sort = f.rep.sort.dom[0]
|
||||
destrs = im.module.sort_destructors[sort.name]
|
||||
for d in destrs:
|
||||
res[inp].add(d(f.args[0]))
|
||||
else:
|
||||
res[inp].add(inp)
|
||||
return res
|
||||
|
||||
def extract_input_fields(pre_clauses,inputs):
|
||||
mrefs = minimal_field_references(pre_clauses.to_formula(),inputs)
|
||||
mrefs = minimal_field_siblings(inputs,mrefs)
|
||||
iu.dbg('[(str(x),map(str,y)) for x,y in mrefs.iteritems()]')
|
||||
def field_symbol_name(f):
|
||||
if len(f.args) == 1:
|
||||
return field_symbol_name(f.args[0]) + '__' + f.rep.name
|
||||
return f.rep.name
|
||||
fsyms = dict((il.Symbol(field_symbol_name(y),y.sort),y) for l in mrefs.values() for y in l)
|
||||
rfsyms = dict((y,x) for x,y in fsyms.iteritems())
|
||||
def recur(f):
|
||||
if il.is_app(f):
|
||||
if f.rep in mrefs or f.rep.name in im.module.destructor_sorts and len(f.args) == 1:
|
||||
if f in rfsyms:
|
||||
return rfsyms[f]
|
||||
return f.clone(map(recur,f.args))
|
||||
pre_clauses = ilu.Clauses(map(recur,pre_clauses.fmlas),map(recur,pre_clauses.defs))
|
||||
inputs = list(fsyms.keys())
|
||||
return pre_clauses,inputs,fsyms
|
||||
|
||||
def emit_action_gen(header,impl,name,action,classname):
|
||||
global indent_level
|
||||
|
@ -907,6 +974,13 @@ def emit_action_gen(header,impl,name,action,classname):
|
|||
orig_pre = pre
|
||||
pre_clauses = ilu.trim_clauses(pre)
|
||||
inputs = [x for x in ilu.used_symbols_clauses(pre_clauses) if is_local_sym(x) and not x.is_numeral()]
|
||||
inputset = set(inputs)
|
||||
for p in action.formal_params:
|
||||
p = p.prefix('__')
|
||||
if p not in inputset:
|
||||
inputs.append(p)
|
||||
pre_clauses, inputs, fsyms = extract_input_fields(pre_clauses,inputs)
|
||||
iu.dbg('[str(pre_clauses),[str(x) for x in inputs],str(fsyms)]')
|
||||
pre_clauses, param_defs = extract_defined_parameters(pre_clauses,inputs)
|
||||
iu.dbg('pre_clauses')
|
||||
iu.dbg('param_defs')
|
||||
|
@ -918,19 +992,23 @@ def emit_action_gen(header,impl,name,action,classname):
|
|||
iu.dbg('used')
|
||||
used_names = set(varname(s) for s in used)
|
||||
defed_params = set(f.args[0] for f in param_defs)
|
||||
for p in action.formal_params:
|
||||
p = p.prefix('__')
|
||||
if varname(p) not in used_names and p not in defed_params:
|
||||
used.add(p)
|
||||
for x in used:
|
||||
if x.is_numeral() and il.is_uninterpreted_sort(x.sort):
|
||||
raise iu.IvyError(None,'Cannot compile numeral {} of uninterpreted sort {}'.format(x,x.sort))
|
||||
syms = [x for x in used if is_local_sym(x) and not x.is_numeral()]
|
||||
syms = inputs
|
||||
iu.dbg('syms')
|
||||
header.append("class " + caname + "_gen : public gen {\n public:\n")
|
||||
for sym in syms + list(defed_params):
|
||||
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
|
||||
declare_symbol(header,sym,classname=classname)
|
||||
iu.dbg('[str(s) for s in defed_params]')
|
||||
decld = set()
|
||||
def get_root(f):
|
||||
return get_root(f.args[0]) if len(f.args) == 1 else f
|
||||
for sym in syms:
|
||||
if sym in fsyms:
|
||||
sym = get_root(fsyms[sym])
|
||||
if sym not in decld:
|
||||
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
|
||||
declare_symbol(header,sym,classname=classname)
|
||||
decld.add(sym)
|
||||
header.append(" {}_gen();\n".format(caname))
|
||||
header.append(" bool generate(" + classname + "&);\n");
|
||||
header.append(" void execute(" + classname + "&);\n};\n");
|
||||
|
@ -967,12 +1045,13 @@ def emit_action_gen(header,impl,name,action,classname):
|
|||
indent_level += 1
|
||||
for sym in syms:
|
||||
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
|
||||
emit_eval(impl,sym,classname=classname)
|
||||
if sym not in defed_params:
|
||||
emit_eval(impl,sym,classname=classname,lhs=fsyms.get(sym,sym))
|
||||
ssyms = set()
|
||||
for sym in all_state_symbols():
|
||||
if sym_is_member(sym):
|
||||
ssyms.add(sym)
|
||||
emit_defined_inputs(orig_pre,param_defs,impl,classname,ssyms)
|
||||
emit_defined_inputs(orig_pre,param_defs,impl,classname,ssyms,fsyms)
|
||||
indent_level -= 2
|
||||
impl.append("""
|
||||
}""")
|
||||
|
|
Загрузка…
Ссылка в новой задаче