x/crypto/otr: clear key slots when handshaking.
The OTR implementation had a bug where key slots would not be marked as unused after a rehandshake. Since handshaking resets the key ids, some key slots would be left over with much higher key ids. These key slots would lead to an error when the code ran out of slots. Fixes agl/xmpp-client#96. Change-Id: I013bbc4eaf0616373ab52f14b7f7757c353983ca Reviewed-on: https://go-review.googlesource.com/16934 Reviewed-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
Родитель
346896d577
Коммит
d438f321d3
19
otr/otr.go
19
otr/otr.go
|
@ -277,7 +277,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
|
|||
in = in[len(msgPrefix) : len(in)-1]
|
||||
} else if version := isQuery(in); version > 0 {
|
||||
c.authState = authStateAwaitingDHKey
|
||||
c.myKeyId = 0
|
||||
c.reset()
|
||||
toSend = c.encode(c.generateDHCommit())
|
||||
return
|
||||
} else {
|
||||
|
@ -311,7 +311,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
|
|||
if err = c.processDHCommit(msg); err != nil {
|
||||
return
|
||||
}
|
||||
c.myKeyId = 0
|
||||
c.reset()
|
||||
toSend = c.encode(c.generateDHKey())
|
||||
return
|
||||
case authStateAwaitingDHKey:
|
||||
|
@ -330,7 +330,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
|
|||
if err = c.processDHCommit(msg); err != nil {
|
||||
return
|
||||
}
|
||||
c.myKeyId = 0
|
||||
c.reset()
|
||||
toSend = c.encode(c.generateDHKey())
|
||||
return
|
||||
}
|
||||
|
@ -343,7 +343,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
|
|||
if err = c.processDHCommit(msg); err != nil {
|
||||
return
|
||||
}
|
||||
c.myKeyId = 0
|
||||
c.reset()
|
||||
toSend = c.encode(c.generateDHKey())
|
||||
c.authState = authStateAwaitingRevealSig
|
||||
default:
|
||||
|
@ -1036,8 +1036,7 @@ func (c *Conversation) calcDataKeys(myKeyId, theirKeyId uint32) (slot *keySlot,
|
|||
}
|
||||
}
|
||||
if slot == nil {
|
||||
err = errors.New("otr: internal error: no key slots")
|
||||
return
|
||||
return nil, errors.New("otr: internal error: no more key slots")
|
||||
}
|
||||
|
||||
var myPriv, myPub, theirPub *big.Int
|
||||
|
@ -1163,6 +1162,14 @@ func (c *Conversation) encode(msg []byte) [][]byte {
|
|||
return ret
|
||||
}
|
||||
|
||||
func (c *Conversation) reset() {
|
||||
c.myKeyId = 0
|
||||
|
||||
for i := range c.keySlots {
|
||||
c.keySlots[i].used = false
|
||||
}
|
||||
}
|
||||
|
||||
type PublicKey struct {
|
||||
dsa.PublicKey
|
||||
}
|
||||
|
|
196
otr/otr_test.go
196
otr/otr_test.go
|
@ -121,11 +121,12 @@ func TestSignVerify(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConversation(t *testing.T) {
|
||||
func setupConversation(t *testing.T) (alice, bob *Conversation) {
|
||||
alicePrivateKey, _ := hex.DecodeString(alicePrivateKeyHex)
|
||||
bobPrivateKey, _ := hex.DecodeString(bobPrivateKeyHex)
|
||||
|
||||
var alice, bob Conversation
|
||||
alice, bob = new(Conversation), new(Conversation)
|
||||
|
||||
alice.PrivateKey = new(PrivateKey)
|
||||
bob.PrivateKey = new(PrivateKey)
|
||||
alice.PrivateKey.Parse(alicePrivateKey)
|
||||
|
@ -133,12 +134,6 @@ func TestConversation(t *testing.T) {
|
|||
alice.FragmentSize = 100
|
||||
bob.FragmentSize = 100
|
||||
|
||||
var alicesMessage, bobsMessage [][]byte
|
||||
var out []byte
|
||||
var aliceChange, bobChange SecurityChange
|
||||
var err error
|
||||
alicesMessage = append(alicesMessage, []byte(QueryMessage))
|
||||
|
||||
if alice.IsEncrypted() {
|
||||
t.Error("Alice believes that the conversation is secure before we've started")
|
||||
}
|
||||
|
@ -146,6 +141,17 @@ func TestConversation(t *testing.T) {
|
|||
t.Error("Bob believes that the conversation is secure before we've started")
|
||||
}
|
||||
|
||||
performHandshake(t, alice, bob)
|
||||
return alice, bob
|
||||
}
|
||||
|
||||
func performHandshake(t *testing.T, alice, bob *Conversation) {
|
||||
var alicesMessage, bobsMessage [][]byte
|
||||
var out []byte
|
||||
var aliceChange, bobChange SecurityChange
|
||||
var err error
|
||||
alicesMessage = append(alicesMessage, []byte(QueryMessage))
|
||||
|
||||
for round := 0; len(alicesMessage) > 0 || len(bobsMessage) > 0; round++ {
|
||||
bobsMessage = nil
|
||||
for i, msg := range alicesMessage {
|
||||
|
@ -193,77 +199,106 @@ func TestConversation(t *testing.T) {
|
|||
if !bob.IsEncrypted() {
|
||||
t.Error("Bob doesn't believe that the conversation is secure")
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
firstRoundTrip = iota
|
||||
subsequentRoundTrip
|
||||
noMACKeyCheck
|
||||
)
|
||||
|
||||
func roundTrip(t *testing.T, alice, bob *Conversation, message []byte, macKeyCheck int) {
|
||||
alicesMessage, err := alice.Send(message)
|
||||
if err != nil {
|
||||
t.Errorf("Error from Alice sending message: %s", err)
|
||||
}
|
||||
|
||||
if len(alice.oldMACs) != 0 {
|
||||
t.Errorf("Alice has not revealed all MAC keys")
|
||||
}
|
||||
|
||||
for i, msg := range alicesMessage {
|
||||
out, encrypted, _, _, err := bob.Receive(msg)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error generated while processing test message: %s", err.Error())
|
||||
}
|
||||
if len(out) > 0 {
|
||||
if i != len(alicesMessage)-1 {
|
||||
t.Fatal("Bob produced a message while processing a fragment of Alice's")
|
||||
}
|
||||
if !encrypted {
|
||||
t.Errorf("Message was not marked as encrypted")
|
||||
}
|
||||
if !bytes.Equal(out, message) {
|
||||
t.Errorf("Message corrupted: got %x, want %x", out, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch macKeyCheck {
|
||||
case firstRoundTrip:
|
||||
if len(bob.oldMACs) != 0 {
|
||||
t.Errorf("Bob should not have MAC keys to reveal")
|
||||
}
|
||||
case subsequentRoundTrip:
|
||||
if len(bob.oldMACs) != 40 {
|
||||
t.Errorf("Bob has %d bytes of MAC keys to reveal, but should have 40", len(bob.oldMACs))
|
||||
}
|
||||
}
|
||||
|
||||
bobsMessage, err := bob.Send(message)
|
||||
if err != nil {
|
||||
t.Errorf("Error from Bob sending message: %s", err)
|
||||
}
|
||||
|
||||
if len(bob.oldMACs) != 0 {
|
||||
t.Errorf("Bob has not revealed all MAC keys")
|
||||
}
|
||||
|
||||
for i, msg := range bobsMessage {
|
||||
out, encrypted, _, _, err := alice.Receive(msg)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error generated while processing test message: %s", err.Error())
|
||||
}
|
||||
if len(out) > 0 {
|
||||
if i != len(bobsMessage)-1 {
|
||||
t.Fatal("Alice produced a message while processing a fragment of Bob's")
|
||||
}
|
||||
if !encrypted {
|
||||
t.Errorf("Message was not marked as encrypted")
|
||||
}
|
||||
if !bytes.Equal(out, message) {
|
||||
t.Errorf("Message corrupted: got %x, want %x", out, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch macKeyCheck {
|
||||
case firstRoundTrip:
|
||||
if len(alice.oldMACs) != 20 {
|
||||
t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 20", len(alice.oldMACs))
|
||||
}
|
||||
case subsequentRoundTrip:
|
||||
if len(alice.oldMACs) != 40 {
|
||||
t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 40", len(alice.oldMACs))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConversation(t *testing.T) {
|
||||
alice, bob := setupConversation(t)
|
||||
|
||||
var testMessages = [][]byte{
|
||||
[]byte("hello"), []byte("bye"),
|
||||
}
|
||||
|
||||
for j, testMessage := range testMessages {
|
||||
alicesMessage, err = alice.Send(testMessage)
|
||||
roundTripType := firstRoundTrip
|
||||
|
||||
if len(alice.oldMACs) != 0 {
|
||||
t.Errorf("Alice has not revealed all MAC keys")
|
||||
}
|
||||
|
||||
for i, msg := range alicesMessage {
|
||||
out, encrypted, _, _, err := bob.Receive(msg)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error generated while processing test message: %s", err.Error())
|
||||
}
|
||||
if len(out) > 0 {
|
||||
if i != len(alicesMessage)-1 {
|
||||
t.Fatal("Bob produced a message while processing a fragment of Alice's")
|
||||
}
|
||||
if !encrypted {
|
||||
t.Errorf("Message was not marked as encrypted")
|
||||
}
|
||||
if !bytes.Equal(out, testMessage) {
|
||||
t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if j == 0 {
|
||||
if len(bob.oldMACs) != 0 {
|
||||
t.Errorf("Bob should not have MAC keys to reveal")
|
||||
}
|
||||
} else if len(bob.oldMACs) != 40 {
|
||||
t.Errorf("Bob does not have MAC keys to reveal")
|
||||
}
|
||||
|
||||
bobsMessage, err = bob.Send(testMessage)
|
||||
|
||||
if len(bob.oldMACs) != 0 {
|
||||
t.Errorf("Bob has not revealed all MAC keys")
|
||||
}
|
||||
|
||||
for i, msg := range bobsMessage {
|
||||
out, encrypted, _, _, err := alice.Receive(msg)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error generated while processing test message: %s", err.Error())
|
||||
}
|
||||
if len(out) > 0 {
|
||||
if i != len(bobsMessage)-1 {
|
||||
t.Fatal("Alice produced a message while processing a fragment of Bob's")
|
||||
}
|
||||
if !encrypted {
|
||||
t.Errorf("Message was not marked as encrypted")
|
||||
}
|
||||
if !bytes.Equal(out, testMessage) {
|
||||
t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if j == 0 {
|
||||
if len(alice.oldMACs) != 20 {
|
||||
t.Errorf("Alice does not have MAC keys to reveal")
|
||||
}
|
||||
} else if len(alice.oldMACs) != 40 {
|
||||
t.Errorf("Alice does not have MAC keys to reveal")
|
||||
}
|
||||
for _, testMessage := range testMessages {
|
||||
roundTrip(t, alice, bob, testMessage, roundTripType)
|
||||
roundTripType = subsequentRoundTrip
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -348,6 +383,21 @@ func TestBadSMP(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRehandshaking(t *testing.T) {
|
||||
alice, bob := setupConversation(t)
|
||||
roundTrip(t, alice, bob, []byte("test"), firstRoundTrip)
|
||||
roundTrip(t, alice, bob, []byte("test 2"), subsequentRoundTrip)
|
||||
roundTrip(t, alice, bob, []byte("test 3"), subsequentRoundTrip)
|
||||
roundTrip(t, alice, bob, []byte("test 4"), subsequentRoundTrip)
|
||||
roundTrip(t, alice, bob, []byte("test 5"), subsequentRoundTrip)
|
||||
roundTrip(t, alice, bob, []byte("test 6"), subsequentRoundTrip)
|
||||
roundTrip(t, alice, bob, []byte("test 7"), subsequentRoundTrip)
|
||||
roundTrip(t, alice, bob, []byte("test 8"), subsequentRoundTrip)
|
||||
performHandshake(t, alice, bob)
|
||||
roundTrip(t, alice, bob, []byte("test"), noMACKeyCheck)
|
||||
roundTrip(t, alice, bob, []byte("test 2"), noMACKeyCheck)
|
||||
}
|
||||
|
||||
func TestAgainstLibOTR(t *testing.T) {
|
||||
// This test requires otr.c.test to be built as /tmp/a.out.
|
||||
// If enabled, this tests runs forever performing OTR handshakes in a
|
||||
|
|
Загрузка…
Ссылка в новой задаче