working on dependent input fields

This commit is contained in:
Ken McMillan 2018-08-10 18:41:58 -07:00
Родитель 8e2960f57b
Коммит 2676736c0a
1 изменённых файлов: 98 добавлений и 19 удалений

Просмотреть файл

@ -580,11 +580,11 @@ def int_to_z3(sort,val):
raise iu.IvyError(None,"cannot produce test generator because sort {} is uninterpreted".format(sort))
return 'int_to_z3(sort("'+sort.name+'"),'+val+')'
def emit_eval(header,symbol,obj=None,classname=None):
def emit_eval(header,symbol,obj=None,classname=None,lhs=None):
global indent_level
name = symbol.name
sname = slv.solver_name(symbol)
cname = varname(name)
cname = varname(name) if lhs is None else code_eval(header,lhs)
sort = symbol.sort
domain = sort_domain(sort)
for idx,dsort in enumerate(domain):
@ -594,7 +594,7 @@ def emit_eval(header,symbol,obj=None,classname=None):
indent_level += 1
indent(header)
if sort.rng.name in im.module.sort_destructors or sort.rng.name in im.module.native_types or sort.rng in sort_to_cpptype:
code_line(header,'__from_solver<'+classname+'::'+varname(sort.rng.name)+'>(*this,apply("'+sname+'"'+''.join(','+int_to_z3(s,'X{}'.format(idx)) for idx,s in enumerate(domain))+'),'+varname(symbol)+''.join('[X{}]'.format(idx) for idx in range(len(domain)))+')')
code_line(header,'__from_solver<'+classname+'::'+varname(sort.rng.name)+'>(*this,apply("'+sname+'"'+''.join(','+int_to_z3(s,'X{}'.format(idx)) for idx,s in enumerate(domain))+'),'+cname+''.join('[X{}]'.format(idx) for idx in range(len(domain)))+')')
else:
header.append((obj + '.' if obj else '')
+ cname + ''.join("[X{}]".format(idx) for idx in range(len(domain)))
@ -825,18 +825,17 @@ def extract_defined_parameters(pre_clauses,inputs):
defmap[fmla.args[0]] = fmla
iu.dbg('defmap')
inpdefs = []
*** make sure var does not occur in defs ***
while change:
change = False
for input,fmla in list(defmap.iteritems()):
if all(input not in ilu.used_variables_ast(f) or f == fmla for f in pre_clauses.fmlas):
if (all(input not in ilu.used_symbols_ast(f) or f == fmla for f in pre_clauses.fmlas)
and all(input not in ilu.used_symbols_ast(d) for d in pre_clauses.defs)):
pre_clauses = ilu.Clauses([f for f in pre_clauses.fmlas if f != fmla],pre_clauses.defs)
del defmap[input]
inpdefs.append(fmla)
change = True
**** trim the defs ***
pre_clauses = ilu.trim_clauses(pre_clauses)
inpdefs.reverse()
pre_clauses = ilu.trim_clauses(pre_clauses)
return pre_clauses,inpdefs
def collect_used_definitions(pre,inpdefs,ssyms):
@ -859,7 +858,7 @@ def collect_used_definitions(pre,inpdefs,ssyms):
return res,usyms
def emit_defined_inputs(pre,inpdefs,code,classname,ssyms):
def emit_defined_inputs(pre,inpdefs,code,classname,ssyms,fsyms):
udefs,usyms = collect_used_definitions(pre,inpdefs,ssyms)
iu.dbg('usyms')
for sym in usyms:
@ -873,9 +872,77 @@ def emit_defined_inputs(pre,inpdefs,code,classname,ssyms):
emit_assign(dfn,code)
skip_z3 = False
for param_def in inpdefs:
code_line(code,code_eval(code,param_def.args[0]) + ' = ' + code_eval(code,param_def.args[1]))
lhs = param_def.args[0]
lhs = fsyms.get(lhs,lhs)
code_line(code,code_eval(code,lhs) + ' = ' + code_eval(code,param_def.args[1]))
def minimal_field_references(fmla,inputs):
inpset = set(inputs)
res = defaultdict(set)
def field_ref(f):
if il.is_app(f):
if f.rep.name in im.module.destructor_sorts and len(f.args) == 1:
return field_ref(f.args[0])
if f.rep in inpset:
return f.rep
return None
def recur(f):
if il.is_app(f):
if f.rep.name in im.module.destructor_sorts and len(f.args) == 1:
inp = field_ref(f.args[0])
if inp is not None:
res[inp].add(f)
return
if f.rep in inpset:
res[f.rep].add(f.rep)
return
for x in f.args:
recur(x)
def get_minima(refs):
def lt(x,y):
return len(y.args) == 1 and (x == y.args[0] or lt(x,y.args[0]))
return set(y for y in refs if all(not(lt(x,y)) for x in refs))
recur(fmla)
res = dict((inp,get_minima(refs)) for inp,refs in res.iteritems())
iu.dbg('dict((str(inp),list(str(ref) for ref in refs)) for inp,refs in res.iteritems())')
return res
def minimal_field_siblings(inputs,mrefs):
res = defaultdict(set)
for inp in inputs:
if inp in mrefs:
for f in mrefs[inp]:
if len(f.args) == 1:
sort = f.rep.sort.dom[0]
destrs = im.module.sort_destructors[sort.name]
for d in destrs:
res[inp].add(d(f.args[0]))
else:
res[inp].add(inp)
return res
def extract_input_fields(pre_clauses,inputs):
mrefs = minimal_field_references(pre_clauses.to_formula(),inputs)
mrefs = minimal_field_siblings(inputs,mrefs)
iu.dbg('[(str(x),map(str,y)) for x,y in mrefs.iteritems()]')
def field_symbol_name(f):
if len(f.args) == 1:
return field_symbol_name(f.args[0]) + '__' + f.rep.name
return f.rep.name
fsyms = dict((il.Symbol(field_symbol_name(y),y.sort),y) for l in mrefs.values() for y in l)
rfsyms = dict((y,x) for x,y in fsyms.iteritems())
def recur(f):
if il.is_app(f):
if f.rep in mrefs or f.rep.name in im.module.destructor_sorts and len(f.args) == 1:
if f in rfsyms:
return rfsyms[f]
return f.clone(map(recur,f.args))
pre_clauses = ilu.Clauses(map(recur,pre_clauses.fmlas),map(recur,pre_clauses.defs))
inputs = list(fsyms.keys())
return pre_clauses,inputs,fsyms
def emit_action_gen(header,impl,name,action,classname):
global indent_level
@ -907,6 +974,13 @@ def emit_action_gen(header,impl,name,action,classname):
orig_pre = pre
pre_clauses = ilu.trim_clauses(pre)
inputs = [x for x in ilu.used_symbols_clauses(pre_clauses) if is_local_sym(x) and not x.is_numeral()]
inputset = set(inputs)
for p in action.formal_params:
p = p.prefix('__')
if p not in inputset:
inputs.append(p)
pre_clauses, inputs, fsyms = extract_input_fields(pre_clauses,inputs)
iu.dbg('[str(pre_clauses),[str(x) for x in inputs],str(fsyms)]')
pre_clauses, param_defs = extract_defined_parameters(pre_clauses,inputs)
iu.dbg('pre_clauses')
iu.dbg('param_defs')
@ -918,19 +992,23 @@ def emit_action_gen(header,impl,name,action,classname):
iu.dbg('used')
used_names = set(varname(s) for s in used)
defed_params = set(f.args[0] for f in param_defs)
for p in action.formal_params:
p = p.prefix('__')
if varname(p) not in used_names and p not in defed_params:
used.add(p)
for x in used:
if x.is_numeral() and il.is_uninterpreted_sort(x.sort):
raise iu.IvyError(None,'Cannot compile numeral {} of uninterpreted sort {}'.format(x,x.sort))
syms = [x for x in used if is_local_sym(x) and not x.is_numeral()]
syms = inputs
iu.dbg('syms')
header.append("class " + caname + "_gen : public gen {\n public:\n")
for sym in syms + list(defed_params):
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
declare_symbol(header,sym,classname=classname)
iu.dbg('[str(s) for s in defed_params]')
decld = set()
def get_root(f):
return get_root(f.args[0]) if len(f.args) == 1 else f
for sym in syms:
if sym in fsyms:
sym = get_root(fsyms[sym])
if sym not in decld:
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
declare_symbol(header,sym,classname=classname)
decld.add(sym)
header.append(" {}_gen();\n".format(caname))
header.append(" bool generate(" + classname + "&);\n");
header.append(" void execute(" + classname + "&);\n};\n");
@ -967,12 +1045,13 @@ def emit_action_gen(header,impl,name,action,classname):
indent_level += 1
for sym in syms:
if not sym.name.startswith('__ts') and sym not in pre_clauses.defidx and sym.name != '*>':
emit_eval(impl,sym,classname=classname)
if sym not in defed_params:
emit_eval(impl,sym,classname=classname,lhs=fsyms.get(sym,sym))
ssyms = set()
for sym in all_state_symbols():
if sym_is_member(sym):
ssyms.add(sym)
emit_defined_inputs(orig_pre,param_defs,impl,classname,ssyms)
emit_defined_inputs(orig_pre,param_defs,impl,classname,ssyms,fsyms)
indent_level -= 2
impl.append("""
}""")