diff --git a/ipc/ipdl/ipdl/ast.py b/ipc/ipdl/ipdl/ast.py index 1bd030325edd..7787a2b5aeae 100644 --- a/ipc/ipdl/ipdl/ast.py +++ b/ipc/ipdl/ipdl/ast.py @@ -280,11 +280,11 @@ class TransitionStmt(Node): self.transitions = transitions class Transition(Node): - def __init__(self, loc, trigger, msg, toState): + def __init__(self, loc, trigger, msg, toStates): Node.__init__(self, loc) self.trigger = trigger self.msg = msg - self.toState = toState + self.toStates = toStates @staticmethod def nameToTrigger(name): @@ -321,14 +321,14 @@ class State(Node): self.name = name self.start = start def __eq__(self, o): - return (isinstance(o, State) - and o.name == self.name - and o.start == self.start) + return (isinstance(o, State) + and o.name == self.name + and o.start == self.start) def __hash__(self): return hash(repr(self)) def __ne__(self, o): return not (self == o) - def __repr__(self): return ''% (self.name, self.start) + def __repr__(self): return ''% (self.name, self.start) def __str__(self): return ''% (self.name, self.start) class Param(Node): diff --git a/ipc/ipdl/ipdl/parser.py b/ipc/ipdl/ipdl/parser.py index 7d0602259202..fa4f72a295ad 100644 --- a/ipc/ipdl/ipdl/parser.py +++ b/ipc/ipdl/ipdl/parser.py @@ -156,6 +156,7 @@ reserved = set(( 'manager', 'manages', 'namespace', + 'or', 'parent', 'protocol', 'recv', @@ -421,7 +422,7 @@ def p_Transitions(p): p[0] = [ p[1] ] def p_Transition(p): - """Transition : Trigger MessageId GOTO State ';'""" + """Transition : Trigger MessageId GOTO StateList ';'""" loc, trigger = p[1] p[0] = Transition(loc, trigger, p[2], p[4]) @@ -432,6 +433,15 @@ def p_Trigger(p): | ANSWER""" p[0] = [ locFromTok(p, 1), Transition.nameToTrigger(p[1]) ] +def p_StateList(p): + """StateList : StateList OR State + | State""" + if 2 == len(p): + p[0] = [ p[1] ] + else: + p[1].append(p[3]) + p[0] = p[1] + def p_State(p): """State : ID""" p[0] = State(locFromTok(p, 1), p[1]) diff --git a/ipc/ipdl/ipdl/type.py b/ipc/ipdl/ipdl/type.py index c11c403660fb..70ea4236c546 100644 --- a/ipc/ipdl/ipdl/type.py +++ b/ipc/ipdl/ipdl/type.py @@ -134,8 +134,10 @@ class IPDLType(Type): or o.isSync() and self.isRpc()) class StateType(IPDLType): - def __init__(self): pass - def isState(self): return True + def __init__(self, start=False): + self.start = start + def isState(self): + return True class MessageType(IPDLType): def __init__(self, sendSemantics, direction, @@ -470,6 +472,7 @@ class GatherDecls(TcheckVisitor): mgdname, p.name) p.states = { } + if len(p.transitionStmts): p.startStates = [ ts for ts in p.transitionStmts if ts.state.start ] @@ -481,10 +484,11 @@ class GatherDecls(TcheckVisitor): p.states[trans.state] = trans trans.state.decl = self.declare( loc=trans.state.loc, - type=StateType(), + type=StateType(trans.state.start), progname=trans.state.name) for trans in p.transitionStmts: + self.seentriggers = set() trans.accept(self) # visit the message decls once more and resolve the state names @@ -661,21 +665,19 @@ class GatherDecls(TcheckVisitor): md.protocolDecl = self.currentProtocolDecl + def visitTransitionStmt(self, ts): + self.seentriggers = set() + TcheckVisitor.visitTransitionStmt(self, ts) + def visitTransition(self, t): loc = t.loc - sname = t.toState.name - sdecl = self.symtab.lookup(sname) - if sdecl is None: - self.error(loc, "state `%s' has not been declared", sname) - elif not sdecl.type.isState(): - self.error( - loc, "`%s' should have state type, but instead has type `%s'", - sname, sdecl.type.typename()) - else: - t.toState.decl = sdecl - + # check the trigger message mname = t.msg + if mname in self.seentriggers: + self.error(loc, "trigger `%s' appears multiple times", mname) + self.seentriggers.add(mname) + mdecl = self.symtab.lookup(mname) if mdecl is None: self.error(loc, "message `%s' has not been declared", mname) @@ -687,6 +689,28 @@ class GatherDecls(TcheckVisitor): else: t.msg = mdecl + # check the to-states + seenstates = set() + for toState in t.toStates: + sname = toState.name + sdecl = self.symtab.lookup(sname) + + if sname in seenstates: + self.error(loc, "to-state `%s' appears multiple times", sname) + seenstates.add(sname) + + if sdecl is None: + self.error(loc, "state `%s' has not been declared", sname) + elif not sdecl.type.isState(): + self.error( + loc, "`%s' should have state type, but instead has type `%s'", + sname, sdecl.type.typename()) + else: + toState.decl = sdecl + toState.start = sdecl.type.start + + t.toStates = set(t.toStates) + ##----------------------------------------------------------------------------- class CheckTypes(TcheckVisitor): @@ -808,6 +832,18 @@ class CheckTypes(TcheckVisitor): ##----------------------------------------------------------------------------- +def unique_pairs(s): + n = len(s) + for i, e1 in enumerate(s): + for j in xrange(i+1, n): + yield (e1, s[j]) + +def cross_product(s1, s2): + for e1 in s1: + for e2 in s2: + yield (e1, e2) + + class CheckStateMachine(TcheckVisitor): def __init__(self, errors): # don't need the symbol table, we just want the error reporting @@ -859,10 +895,26 @@ class CheckStateMachine(TcheckVisitor): # and t1 transitions to state T1 and t2 to T2, # then the following must be true: # T2 allows the trigger t1, transitioning to state U - # T1 allows the trigger t2, transitioning to state U""" + # T1 allows the trigger t2, transitioning to state U # # This is a more formal way of expressing "it doesn't matter # in which order the triggers t1 and t2 occur / are processed." + # + # The presence of triggers with multiple out states complicates + # this check slightly, but doesn't fundamentally change it. + # + # from a state S, + # for any pair of triggers t1 and t2, + # where t1 and t2 have opposite direction, + # for each pair of states (T1, T2) \in t1_out x t2_out, + # where t1_out is the set of outstates from t1 + # t2_out is the set of outstates from t2 + # t1_out x t2_out is their Cartesian product + # and t1 transitions to state T1 and t2 to T2, + # then the following must be true: + # T2 allows the trigger t1, with out-state set { U } + # T1 allows the trigger t2, with out-state set { U } + # syncdirection = None syncok = True for trans in ts.transitions: @@ -879,42 +931,43 @@ class CheckStateMachine(TcheckVisitor): if not syncok: return - def triggerTarget(S, t): - '''Return the state transitioned to from state |S| -upon trigger |t|, or None if |t| is not a trigger in |S|.''' + def triggerTargets(S, t): + '''Return the set of states transitioned to from state |S| +upon trigger |t|, or { } if |t| is not a trigger in |S|.''' for trans in self.p.states[S].transitions: if t.trigger is trans.trigger and t.msg is trans.msg: - return trans.toState - return None + return trans.toStates + return set() - ntrans = len(ts.transitions) - for i, t1 in enumerate(ts.transitions): - for j in xrange(i+1, ntrans): - t2 = ts.transitions[j] - # if the triggers have the same direction, they can't race, - # since only one endpoint can initiate either (and delivery - # is in-order) - if t1.trigger.direction() == t2.trigger.direction(): - continue - T1 = t1.toState - T2 = t2.toState + for (t1, t2) in unique_pairs(ts.transitions): + # if the triggers have the same direction, they can't race, + # since only one endpoint can initiate either (and delivery + # is in-order) + if t1.trigger.direction() == t2.trigger.direction(): + continue - U1 = triggerTarget(T1, t2) - U2 = triggerTarget(T2, t1) + t1_out = t1.toStates + t2_out = t2.toStates - if U1 is None or U1 != U2: + for (T1, T2) in cross_product(t1_out, t2_out): + U1 = triggerTargets(T1, t2) + U2 = triggerTargets(T2, t1) + + if (0 == len(U1) + or 1 < len(U1) or 1 < len(U2) + or U1 != U2): self.error( t2.loc, - "trigger `%s' potentially races (does not commute) with `%s' at state `%s' in protocol `%s'", - t1.msg.progname, t2.msg.progname, - ts.state.name, self.p.name) + "in protocol `%s' state `%s', trigger `%s' potentially races (does not commute) with `%s'", + self.p.name, ts.state.name, + t1.msg.progname, t2.msg.progname) # don't report more than one Diamond Rule - # violation per state. there may be O(n^2) total, - # way too many for a human to parse + # violation per state. there may be O(n^4) + # total, way too many for a human to parse # - # XXX/cjones: could set a limit on #printed and stop after - # that limit ... + # XXX/cjones: could set a limit on #printed + # and stop after that limit ... return def checkReachability(self, p): @@ -924,7 +977,8 @@ upon trigger |t|, or None if |t| is not a trigger in |S|.''' return visited.add(ts.state) for outedge in ts.transitions: - explore(p.states[outedge.toState]) + for toState in outedge.toStates: + explore(p.states[toState]) for root in p.startStates: explore(root)