From d438f321d375f387b7cd2b157bada3741899a9dd Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Sun, 15 Nov 2015 11:34:54 -0800 Subject: [PATCH] x/crypto/otr: clear key slots when handshaking. The OTR implementation had a bug where key slots would not be marked as unused after a rehandshake. Since handshaking resets the key ids, some key slots would be left over with much higher key ids. These key slots would lead to an error when the code ran out of slots. Fixes agl/xmpp-client#96. Change-Id: I013bbc4eaf0616373ab52f14b7f7757c353983ca Reviewed-on: https://go-review.googlesource.com/16934 Reviewed-by: Andrew Gerrand --- otr/otr.go | 19 +++-- otr/otr_test.go | 196 ++++++++++++++++++++++++++++++------------------ 2 files changed, 136 insertions(+), 79 deletions(-) diff --git a/otr/otr.go b/otr/otr.go index f872a9d7..549be116 100644 --- a/otr/otr.go +++ b/otr/otr.go @@ -277,7 +277,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se in = in[len(msgPrefix) : len(in)-1] } else if version := isQuery(in); version > 0 { c.authState = authStateAwaitingDHKey - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHCommit()) return } else { @@ -311,7 +311,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se if err = c.processDHCommit(msg); err != nil { return } - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHKey()) return case authStateAwaitingDHKey: @@ -330,7 +330,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se if err = c.processDHCommit(msg); err != nil { return } - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHKey()) return } @@ -343,7 +343,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se if err = c.processDHCommit(msg); err != nil { return } - c.myKeyId = 0 + c.reset() toSend = c.encode(c.generateDHKey()) c.authState = authStateAwaitingRevealSig default: @@ -1036,8 +1036,7 @@ func (c *Conversation) calcDataKeys(myKeyId, theirKeyId uint32) (slot *keySlot, } } if slot == nil { - err = errors.New("otr: internal error: no key slots") - return + return nil, errors.New("otr: internal error: no more key slots") } var myPriv, myPub, theirPub *big.Int @@ -1163,6 +1162,14 @@ func (c *Conversation) encode(msg []byte) [][]byte { return ret } +func (c *Conversation) reset() { + c.myKeyId = 0 + + for i := range c.keySlots { + c.keySlots[i].used = false + } +} + type PublicKey struct { dsa.PublicKey } diff --git a/otr/otr_test.go b/otr/otr_test.go index 417a7939..49d2accd 100644 --- a/otr/otr_test.go +++ b/otr/otr_test.go @@ -121,11 +121,12 @@ func TestSignVerify(t *testing.T) { } } -func TestConversation(t *testing.T) { +func setupConversation(t *testing.T) (alice, bob *Conversation) { alicePrivateKey, _ := hex.DecodeString(alicePrivateKeyHex) bobPrivateKey, _ := hex.DecodeString(bobPrivateKeyHex) - var alice, bob Conversation + alice, bob = new(Conversation), new(Conversation) + alice.PrivateKey = new(PrivateKey) bob.PrivateKey = new(PrivateKey) alice.PrivateKey.Parse(alicePrivateKey) @@ -133,12 +134,6 @@ func TestConversation(t *testing.T) { alice.FragmentSize = 100 bob.FragmentSize = 100 - var alicesMessage, bobsMessage [][]byte - var out []byte - var aliceChange, bobChange SecurityChange - var err error - alicesMessage = append(alicesMessage, []byte(QueryMessage)) - if alice.IsEncrypted() { t.Error("Alice believes that the conversation is secure before we've started") } @@ -146,6 +141,17 @@ func TestConversation(t *testing.T) { t.Error("Bob believes that the conversation is secure before we've started") } + performHandshake(t, alice, bob) + return alice, bob +} + +func performHandshake(t *testing.T, alice, bob *Conversation) { + var alicesMessage, bobsMessage [][]byte + var out []byte + var aliceChange, bobChange SecurityChange + var err error + alicesMessage = append(alicesMessage, []byte(QueryMessage)) + for round := 0; len(alicesMessage) > 0 || len(bobsMessage) > 0; round++ { bobsMessage = nil for i, msg := range alicesMessage { @@ -193,77 +199,106 @@ func TestConversation(t *testing.T) { if !bob.IsEncrypted() { t.Error("Bob doesn't believe that the conversation is secure") } +} + +const ( + firstRoundTrip = iota + subsequentRoundTrip + noMACKeyCheck +) + +func roundTrip(t *testing.T, alice, bob *Conversation, message []byte, macKeyCheck int) { + alicesMessage, err := alice.Send(message) + if err != nil { + t.Errorf("Error from Alice sending message: %s", err) + } + + if len(alice.oldMACs) != 0 { + t.Errorf("Alice has not revealed all MAC keys") + } + + for i, msg := range alicesMessage { + out, encrypted, _, _, err := bob.Receive(msg) + + if err != nil { + t.Errorf("Error generated while processing test message: %s", err.Error()) + } + if len(out) > 0 { + if i != len(alicesMessage)-1 { + t.Fatal("Bob produced a message while processing a fragment of Alice's") + } + if !encrypted { + t.Errorf("Message was not marked as encrypted") + } + if !bytes.Equal(out, message) { + t.Errorf("Message corrupted: got %x, want %x", out, message) + } + } + } + + switch macKeyCheck { + case firstRoundTrip: + if len(bob.oldMACs) != 0 { + t.Errorf("Bob should not have MAC keys to reveal") + } + case subsequentRoundTrip: + if len(bob.oldMACs) != 40 { + t.Errorf("Bob has %d bytes of MAC keys to reveal, but should have 40", len(bob.oldMACs)) + } + } + + bobsMessage, err := bob.Send(message) + if err != nil { + t.Errorf("Error from Bob sending message: %s", err) + } + + if len(bob.oldMACs) != 0 { + t.Errorf("Bob has not revealed all MAC keys") + } + + for i, msg := range bobsMessage { + out, encrypted, _, _, err := alice.Receive(msg) + + if err != nil { + t.Errorf("Error generated while processing test message: %s", err.Error()) + } + if len(out) > 0 { + if i != len(bobsMessage)-1 { + t.Fatal("Alice produced a message while processing a fragment of Bob's") + } + if !encrypted { + t.Errorf("Message was not marked as encrypted") + } + if !bytes.Equal(out, message) { + t.Errorf("Message corrupted: got %x, want %x", out, message) + } + } + } + + switch macKeyCheck { + case firstRoundTrip: + if len(alice.oldMACs) != 20 { + t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 20", len(alice.oldMACs)) + } + case subsequentRoundTrip: + if len(alice.oldMACs) != 40 { + t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 40", len(alice.oldMACs)) + } + } +} + +func TestConversation(t *testing.T) { + alice, bob := setupConversation(t) var testMessages = [][]byte{ []byte("hello"), []byte("bye"), } - for j, testMessage := range testMessages { - alicesMessage, err = alice.Send(testMessage) + roundTripType := firstRoundTrip - if len(alice.oldMACs) != 0 { - t.Errorf("Alice has not revealed all MAC keys") - } - - for i, msg := range alicesMessage { - out, encrypted, _, _, err := bob.Receive(msg) - - if err != nil { - t.Errorf("Error generated while processing test message: %s", err.Error()) - } - if len(out) > 0 { - if i != len(alicesMessage)-1 { - t.Fatal("Bob produced a message while processing a fragment of Alice's") - } - if !encrypted { - t.Errorf("Message was not marked as encrypted") - } - if !bytes.Equal(out, testMessage) { - t.Errorf("Message corrupted: got %x, want %x", out, testMessage) - } - } - } - - if j == 0 { - if len(bob.oldMACs) != 0 { - t.Errorf("Bob should not have MAC keys to reveal") - } - } else if len(bob.oldMACs) != 40 { - t.Errorf("Bob does not have MAC keys to reveal") - } - - bobsMessage, err = bob.Send(testMessage) - - if len(bob.oldMACs) != 0 { - t.Errorf("Bob has not revealed all MAC keys") - } - - for i, msg := range bobsMessage { - out, encrypted, _, _, err := alice.Receive(msg) - - if err != nil { - t.Errorf("Error generated while processing test message: %s", err.Error()) - } - if len(out) > 0 { - if i != len(bobsMessage)-1 { - t.Fatal("Alice produced a message while processing a fragment of Bob's") - } - if !encrypted { - t.Errorf("Message was not marked as encrypted") - } - if !bytes.Equal(out, testMessage) { - t.Errorf("Message corrupted: got %x, want %x", out, testMessage) - } - } - } - - if j == 0 { - if len(alice.oldMACs) != 20 { - t.Errorf("Alice does not have MAC keys to reveal") - } - } else if len(alice.oldMACs) != 40 { - t.Errorf("Alice does not have MAC keys to reveal") - } + for _, testMessage := range testMessages { + roundTrip(t, alice, bob, testMessage, roundTripType) + roundTripType = subsequentRoundTrip } } @@ -348,6 +383,21 @@ func TestBadSMP(t *testing.T) { } } +func TestRehandshaking(t *testing.T) { + alice, bob := setupConversation(t) + roundTrip(t, alice, bob, []byte("test"), firstRoundTrip) + roundTrip(t, alice, bob, []byte("test 2"), subsequentRoundTrip) + roundTrip(t, alice, bob, []byte("test 3"), subsequentRoundTrip) + roundTrip(t, alice, bob, []byte("test 4"), subsequentRoundTrip) + roundTrip(t, alice, bob, []byte("test 5"), subsequentRoundTrip) + roundTrip(t, alice, bob, []byte("test 6"), subsequentRoundTrip) + roundTrip(t, alice, bob, []byte("test 7"), subsequentRoundTrip) + roundTrip(t, alice, bob, []byte("test 8"), subsequentRoundTrip) + performHandshake(t, alice, bob) + roundTrip(t, alice, bob, []byte("test"), noMACKeyCheck) + roundTrip(t, alice, bob, []byte("test 2"), noMACKeyCheck) +} + func TestAgainstLibOTR(t *testing.T) { // This test requires otr.c.test to be built as /tmp/a.out. // If enabled, this tests runs forever performing OTR handshakes in a