x/crypto/otr: clear key slots when handshaking.

The OTR implementation had a bug where key slots would not be marked as
unused after a rehandshake. Since handshaking resets the key ids, some
key slots would be left over with much higher key ids. These key slots
would lead to an error when the code ran out of slots.

Fixes agl/xmpp-client#96.

Change-Id: I013bbc4eaf0616373ab52f14b7f7757c353983ca
Reviewed-on: https://go-review.googlesource.com/16934
Reviewed-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
Adam Langley 2015-11-15 11:34:54 -08:00
Родитель 346896d577
Коммит d438f321d3
2 изменённых файлов: 136 добавлений и 79 удалений

Просмотреть файл

@ -277,7 +277,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
in = in[len(msgPrefix) : len(in)-1]
} else if version := isQuery(in); version > 0 {
c.authState = authStateAwaitingDHKey
c.myKeyId = 0
c.reset()
toSend = c.encode(c.generateDHCommit())
return
} else {
@ -311,7 +311,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
if err = c.processDHCommit(msg); err != nil {
return
}
c.myKeyId = 0
c.reset()
toSend = c.encode(c.generateDHKey())
return
case authStateAwaitingDHKey:
@ -330,7 +330,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
if err = c.processDHCommit(msg); err != nil {
return
}
c.myKeyId = 0
c.reset()
toSend = c.encode(c.generateDHKey())
return
}
@ -343,7 +343,7 @@ func (c *Conversation) Receive(in []byte) (out []byte, encrypted bool, change Se
if err = c.processDHCommit(msg); err != nil {
return
}
c.myKeyId = 0
c.reset()
toSend = c.encode(c.generateDHKey())
c.authState = authStateAwaitingRevealSig
default:
@ -1036,8 +1036,7 @@ func (c *Conversation) calcDataKeys(myKeyId, theirKeyId uint32) (slot *keySlot,
}
}
if slot == nil {
err = errors.New("otr: internal error: no key slots")
return
return nil, errors.New("otr: internal error: no more key slots")
}
var myPriv, myPub, theirPub *big.Int
@ -1163,6 +1162,14 @@ func (c *Conversation) encode(msg []byte) [][]byte {
return ret
}
func (c *Conversation) reset() {
c.myKeyId = 0
for i := range c.keySlots {
c.keySlots[i].used = false
}
}
type PublicKey struct {
dsa.PublicKey
}

Просмотреть файл

@ -121,11 +121,12 @@ func TestSignVerify(t *testing.T) {
}
}
func TestConversation(t *testing.T) {
func setupConversation(t *testing.T) (alice, bob *Conversation) {
alicePrivateKey, _ := hex.DecodeString(alicePrivateKeyHex)
bobPrivateKey, _ := hex.DecodeString(bobPrivateKeyHex)
var alice, bob Conversation
alice, bob = new(Conversation), new(Conversation)
alice.PrivateKey = new(PrivateKey)
bob.PrivateKey = new(PrivateKey)
alice.PrivateKey.Parse(alicePrivateKey)
@ -133,12 +134,6 @@ func TestConversation(t *testing.T) {
alice.FragmentSize = 100
bob.FragmentSize = 100
var alicesMessage, bobsMessage [][]byte
var out []byte
var aliceChange, bobChange SecurityChange
var err error
alicesMessage = append(alicesMessage, []byte(QueryMessage))
if alice.IsEncrypted() {
t.Error("Alice believes that the conversation is secure before we've started")
}
@ -146,6 +141,17 @@ func TestConversation(t *testing.T) {
t.Error("Bob believes that the conversation is secure before we've started")
}
performHandshake(t, alice, bob)
return alice, bob
}
func performHandshake(t *testing.T, alice, bob *Conversation) {
var alicesMessage, bobsMessage [][]byte
var out []byte
var aliceChange, bobChange SecurityChange
var err error
alicesMessage = append(alicesMessage, []byte(QueryMessage))
for round := 0; len(alicesMessage) > 0 || len(bobsMessage) > 0; round++ {
bobsMessage = nil
for i, msg := range alicesMessage {
@ -193,77 +199,106 @@ func TestConversation(t *testing.T) {
if !bob.IsEncrypted() {
t.Error("Bob doesn't believe that the conversation is secure")
}
}
const (
firstRoundTrip = iota
subsequentRoundTrip
noMACKeyCheck
)
func roundTrip(t *testing.T, alice, bob *Conversation, message []byte, macKeyCheck int) {
alicesMessage, err := alice.Send(message)
if err != nil {
t.Errorf("Error from Alice sending message: %s", err)
}
if len(alice.oldMACs) != 0 {
t.Errorf("Alice has not revealed all MAC keys")
}
for i, msg := range alicesMessage {
out, encrypted, _, _, err := bob.Receive(msg)
if err != nil {
t.Errorf("Error generated while processing test message: %s", err.Error())
}
if len(out) > 0 {
if i != len(alicesMessage)-1 {
t.Fatal("Bob produced a message while processing a fragment of Alice's")
}
if !encrypted {
t.Errorf("Message was not marked as encrypted")
}
if !bytes.Equal(out, message) {
t.Errorf("Message corrupted: got %x, want %x", out, message)
}
}
}
switch macKeyCheck {
case firstRoundTrip:
if len(bob.oldMACs) != 0 {
t.Errorf("Bob should not have MAC keys to reveal")
}
case subsequentRoundTrip:
if len(bob.oldMACs) != 40 {
t.Errorf("Bob has %d bytes of MAC keys to reveal, but should have 40", len(bob.oldMACs))
}
}
bobsMessage, err := bob.Send(message)
if err != nil {
t.Errorf("Error from Bob sending message: %s", err)
}
if len(bob.oldMACs) != 0 {
t.Errorf("Bob has not revealed all MAC keys")
}
for i, msg := range bobsMessage {
out, encrypted, _, _, err := alice.Receive(msg)
if err != nil {
t.Errorf("Error generated while processing test message: %s", err.Error())
}
if len(out) > 0 {
if i != len(bobsMessage)-1 {
t.Fatal("Alice produced a message while processing a fragment of Bob's")
}
if !encrypted {
t.Errorf("Message was not marked as encrypted")
}
if !bytes.Equal(out, message) {
t.Errorf("Message corrupted: got %x, want %x", out, message)
}
}
}
switch macKeyCheck {
case firstRoundTrip:
if len(alice.oldMACs) != 20 {
t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 20", len(alice.oldMACs))
}
case subsequentRoundTrip:
if len(alice.oldMACs) != 40 {
t.Errorf("Alice has %d bytes of MAC keys to reveal, but should have 40", len(alice.oldMACs))
}
}
}
func TestConversation(t *testing.T) {
alice, bob := setupConversation(t)
var testMessages = [][]byte{
[]byte("hello"), []byte("bye"),
}
for j, testMessage := range testMessages {
alicesMessage, err = alice.Send(testMessage)
roundTripType := firstRoundTrip
if len(alice.oldMACs) != 0 {
t.Errorf("Alice has not revealed all MAC keys")
}
for i, msg := range alicesMessage {
out, encrypted, _, _, err := bob.Receive(msg)
if err != nil {
t.Errorf("Error generated while processing test message: %s", err.Error())
}
if len(out) > 0 {
if i != len(alicesMessage)-1 {
t.Fatal("Bob produced a message while processing a fragment of Alice's")
}
if !encrypted {
t.Errorf("Message was not marked as encrypted")
}
if !bytes.Equal(out, testMessage) {
t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
}
}
}
if j == 0 {
if len(bob.oldMACs) != 0 {
t.Errorf("Bob should not have MAC keys to reveal")
}
} else if len(bob.oldMACs) != 40 {
t.Errorf("Bob does not have MAC keys to reveal")
}
bobsMessage, err = bob.Send(testMessage)
if len(bob.oldMACs) != 0 {
t.Errorf("Bob has not revealed all MAC keys")
}
for i, msg := range bobsMessage {
out, encrypted, _, _, err := alice.Receive(msg)
if err != nil {
t.Errorf("Error generated while processing test message: %s", err.Error())
}
if len(out) > 0 {
if i != len(bobsMessage)-1 {
t.Fatal("Alice produced a message while processing a fragment of Bob's")
}
if !encrypted {
t.Errorf("Message was not marked as encrypted")
}
if !bytes.Equal(out, testMessage) {
t.Errorf("Message corrupted: got %x, want %x", out, testMessage)
}
}
}
if j == 0 {
if len(alice.oldMACs) != 20 {
t.Errorf("Alice does not have MAC keys to reveal")
}
} else if len(alice.oldMACs) != 40 {
t.Errorf("Alice does not have MAC keys to reveal")
}
for _, testMessage := range testMessages {
roundTrip(t, alice, bob, testMessage, roundTripType)
roundTripType = subsequentRoundTrip
}
}
@ -348,6 +383,21 @@ func TestBadSMP(t *testing.T) {
}
}
func TestRehandshaking(t *testing.T) {
alice, bob := setupConversation(t)
roundTrip(t, alice, bob, []byte("test"), firstRoundTrip)
roundTrip(t, alice, bob, []byte("test 2"), subsequentRoundTrip)
roundTrip(t, alice, bob, []byte("test 3"), subsequentRoundTrip)
roundTrip(t, alice, bob, []byte("test 4"), subsequentRoundTrip)
roundTrip(t, alice, bob, []byte("test 5"), subsequentRoundTrip)
roundTrip(t, alice, bob, []byte("test 6"), subsequentRoundTrip)
roundTrip(t, alice, bob, []byte("test 7"), subsequentRoundTrip)
roundTrip(t, alice, bob, []byte("test 8"), subsequentRoundTrip)
performHandshake(t, alice, bob)
roundTrip(t, alice, bob, []byte("test"), noMACKeyCheck)
roundTrip(t, alice, bob, []byte("test 2"), noMACKeyCheck)
}
func TestAgainstLibOTR(t *testing.T) {
// This test requires otr.c.test to be built as /tmp/a.out.
// If enabled, this tests runs forever performing OTR handshakes in a