add IPDL front-end support for transitioning to one of a set of states

This commit is contained in:
Chris Jones 2009-08-19 21:21:46 -05:00
Родитель 809d7382e0
Коммит 24b18f8473
3 изменённых файлов: 113 добавлений и 49 удалений

Просмотреть файл

@ -280,11 +280,11 @@ class TransitionStmt(Node):
self.transitions = transitions
class Transition(Node):
def __init__(self, loc, trigger, msg, toState):
def __init__(self, loc, trigger, msg, toStates):
Node.__init__(self, loc)
self.trigger = trigger
self.msg = msg
self.toState = toState
self.toStates = toStates
@staticmethod
def nameToTrigger(name):
@ -328,7 +328,7 @@ class State(Node):
return hash(repr(self))
def __ne__(self, o):
return not (self == o)
def __repr__(self): return '<State %r start=%s>'% (self.name, self.start)
def __repr__(self): return '<State %r start=%r>'% (self.name, self.start)
def __str__(self): return '<State %s start=%s>'% (self.name, self.start)
class Param(Node):

Просмотреть файл

@ -156,6 +156,7 @@ reserved = set((
'manager',
'manages',
'namespace',
'or',
'parent',
'protocol',
'recv',
@ -421,7 +422,7 @@ def p_Transitions(p):
p[0] = [ p[1] ]
def p_Transition(p):
"""Transition : Trigger MessageId GOTO State ';'"""
"""Transition : Trigger MessageId GOTO StateList ';'"""
loc, trigger = p[1]
p[0] = Transition(loc, trigger, p[2], p[4])
@ -432,6 +433,15 @@ def p_Trigger(p):
| ANSWER"""
p[0] = [ locFromTok(p, 1), Transition.nameToTrigger(p[1]) ]
def p_StateList(p):
"""StateList : StateList OR State
| State"""
if 2 == len(p):
p[0] = [ p[1] ]
else:
p[1].append(p[3])
p[0] = p[1]
def p_State(p):
"""State : ID"""
p[0] = State(locFromTok(p, 1), p[1])

Просмотреть файл

@ -134,8 +134,10 @@ class IPDLType(Type):
or o.isSync() and self.isRpc())
class StateType(IPDLType):
def __init__(self): pass
def isState(self): return True
def __init__(self, start=False):
self.start = start
def isState(self):
return True
class MessageType(IPDLType):
def __init__(self, sendSemantics, direction,
@ -470,6 +472,7 @@ class GatherDecls(TcheckVisitor):
mgdname, p.name)
p.states = { }
if len(p.transitionStmts):
p.startStates = [ ts for ts in p.transitionStmts
if ts.state.start ]
@ -481,10 +484,11 @@ class GatherDecls(TcheckVisitor):
p.states[trans.state] = trans
trans.state.decl = self.declare(
loc=trans.state.loc,
type=StateType(),
type=StateType(trans.state.start),
progname=trans.state.name)
for trans in p.transitionStmts:
self.seentriggers = set()
trans.accept(self)
# visit the message decls once more and resolve the state names
@ -661,21 +665,19 @@ class GatherDecls(TcheckVisitor):
md.protocolDecl = self.currentProtocolDecl
def visitTransitionStmt(self, ts):
self.seentriggers = set()
TcheckVisitor.visitTransitionStmt(self, ts)
def visitTransition(self, t):
loc = t.loc
sname = t.toState.name
sdecl = self.symtab.lookup(sname)
if sdecl is None:
self.error(loc, "state `%s' has not been declared", sname)
elif not sdecl.type.isState():
self.error(
loc, "`%s' should have state type, but instead has type `%s'",
sname, sdecl.type.typename())
else:
t.toState.decl = sdecl
# check the trigger message
mname = t.msg
if mname in self.seentriggers:
self.error(loc, "trigger `%s' appears multiple times", mname)
self.seentriggers.add(mname)
mdecl = self.symtab.lookup(mname)
if mdecl is None:
self.error(loc, "message `%s' has not been declared", mname)
@ -687,6 +689,28 @@ class GatherDecls(TcheckVisitor):
else:
t.msg = mdecl
# check the to-states
seenstates = set()
for toState in t.toStates:
sname = toState.name
sdecl = self.symtab.lookup(sname)
if sname in seenstates:
self.error(loc, "to-state `%s' appears multiple times", sname)
seenstates.add(sname)
if sdecl is None:
self.error(loc, "state `%s' has not been declared", sname)
elif not sdecl.type.isState():
self.error(
loc, "`%s' should have state type, but instead has type `%s'",
sname, sdecl.type.typename())
else:
toState.decl = sdecl
toState.start = sdecl.type.start
t.toStates = set(t.toStates)
##-----------------------------------------------------------------------------
class CheckTypes(TcheckVisitor):
@ -808,6 +832,18 @@ class CheckTypes(TcheckVisitor):
##-----------------------------------------------------------------------------
def unique_pairs(s):
n = len(s)
for i, e1 in enumerate(s):
for j in xrange(i+1, n):
yield (e1, s[j])
def cross_product(s1, s2):
for e1 in s1:
for e2 in s2:
yield (e1, e2)
class CheckStateMachine(TcheckVisitor):
def __init__(self, errors):
# don't need the symbol table, we just want the error reporting
@ -859,10 +895,26 @@ class CheckStateMachine(TcheckVisitor):
# and t1 transitions to state T1 and t2 to T2,
# then the following must be true:
# T2 allows the trigger t1, transitioning to state U
# T1 allows the trigger t2, transitioning to state U"""
# T1 allows the trigger t2, transitioning to state U
#
# This is a more formal way of expressing "it doesn't matter
# in which order the triggers t1 and t2 occur / are processed."
#
# The presence of triggers with multiple out states complicates
# this check slightly, but doesn't fundamentally change it.
#
# from a state S,
# for any pair of triggers t1 and t2,
# where t1 and t2 have opposite direction,
# for each pair of states (T1, T2) \in t1_out x t2_out,
# where t1_out is the set of outstates from t1
# t2_out is the set of outstates from t2
# t1_out x t2_out is their Cartesian product
# and t1 transitions to state T1 and t2 to T2,
# then the following must be true:
# T2 allows the trigger t1, with out-state set { U }
# T1 allows the trigger t2, with out-state set { U }
#
syncdirection = None
syncok = True
for trans in ts.transitions:
@ -879,42 +931,43 @@ class CheckStateMachine(TcheckVisitor):
if not syncok:
return
def triggerTarget(S, t):
'''Return the state transitioned to from state |S|
upon trigger |t|, or None if |t| is not a trigger in |S|.'''
def triggerTargets(S, t):
'''Return the set of states transitioned to from state |S|
upon trigger |t|, or { } if |t| is not a trigger in |S|.'''
for trans in self.p.states[S].transitions:
if t.trigger is trans.trigger and t.msg is trans.msg:
return trans.toState
return None
return trans.toStates
return set()
ntrans = len(ts.transitions)
for i, t1 in enumerate(ts.transitions):
for j in xrange(i+1, ntrans):
t2 = ts.transitions[j]
for (t1, t2) in unique_pairs(ts.transitions):
# if the triggers have the same direction, they can't race,
# since only one endpoint can initiate either (and delivery
# is in-order)
if t1.trigger.direction() == t2.trigger.direction():
continue
T1 = t1.toState
T2 = t2.toState
t1_out = t1.toStates
t2_out = t2.toStates
U1 = triggerTarget(T1, t2)
U2 = triggerTarget(T2, t1)
for (T1, T2) in cross_product(t1_out, t2_out):
U1 = triggerTargets(T1, t2)
U2 = triggerTargets(T2, t1)
if U1 is None or U1 != U2:
if (0 == len(U1)
or 1 < len(U1) or 1 < len(U2)
or U1 != U2):
self.error(
t2.loc,
"trigger `%s' potentially races (does not commute) with `%s' at state `%s' in protocol `%s'",
t1.msg.progname, t2.msg.progname,
ts.state.name, self.p.name)
"in protocol `%s' state `%s', trigger `%s' potentially races (does not commute) with `%s'",
self.p.name, ts.state.name,
t1.msg.progname, t2.msg.progname)
# don't report more than one Diamond Rule
# violation per state. there may be O(n^2) total,
# way too many for a human to parse
# violation per state. there may be O(n^4)
# total, way too many for a human to parse
#
# XXX/cjones: could set a limit on #printed and stop after
# that limit ...
# XXX/cjones: could set a limit on #printed
# and stop after that limit ...
return
def checkReachability(self, p):
@ -924,7 +977,8 @@ upon trigger |t|, or None if |t| is not a trigger in |S|.'''
return
visited.add(ts.state)
for outedge in ts.transitions:
explore(p.states[outedge.toState])
for toState in outedge.toStates:
explore(p.states[toState])
for root in p.startStates:
explore(root)