diff --git a/ssh.c b/ssh.c index 78549fed..796949b0 100644 --- a/ssh.c +++ b/ssh.c @@ -680,6 +680,7 @@ struct ssh_portfwd { sfree((pf)->sserv), sfree((pf)->dserv)) : (void)0 ), sfree(pf) ) struct Packet { + int refcount; long length; /* length of packet: see below */ long forcepad; /* SSH-2: force padding to at least this length */ int type; /* only used for incoming packets */ @@ -752,7 +753,7 @@ static struct Packet *ssh2_gss_authpacket(Ssh ssh, Ssh_gss_ctx gss_ctx, static void do_ssh2_transport(Ssh ssh, const void *vin, int inlen, struct Packet *pktin); static void ssh2_msg_unexpected(Ssh ssh, struct Packet *pktin); -static void ssh_free_packet(struct Packet *pkt); +static void ssh_unref_packet(struct Packet *pkt); struct PacketQueueNode { struct PacketQueueNode *next, *prev; @@ -816,7 +817,7 @@ static void pq_clear(struct PacketQueue *pq) { struct Packet *pkt; while ((pkt = pq_pop(pq)) != NULL) - ssh_free_packet(pkt); + ssh_unref_packet(pkt); } struct rdpkt1_state_tag { @@ -1331,10 +1332,12 @@ static void c_write_str(Ssh ssh, const char *buf) c_write(ssh, buf, strlen(buf)); } -static void ssh_free_packet(struct Packet *pkt) +static void ssh_unref_packet(struct Packet *pkt) { - sfree(pkt->data); - sfree(pkt); + if (--pkt->refcount <= 0) { + sfree(pkt->data); + sfree(pkt); + } } static struct Packet *ssh_new_packet(void) { @@ -1342,6 +1345,7 @@ static struct Packet *ssh_new_packet(void) pkt->body = pkt->data = NULL; pkt->maxlen = 0; + pkt->refcount = 1; return pkt; } @@ -1487,7 +1491,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data, if (st->biglen < 0) { bombout(("Extremely large packet length from server suggests" " data stream corruption")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } @@ -1512,7 +1516,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data, if (ssh->cipher && detect_attack(ssh->crcda_ctx, st->pktin->data, st->biglen, NULL)) { bombout(("Network attack (CRC compensation) detected!")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } @@ -1523,7 +1527,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data, st->gotcrc = GET_32BIT(st->pktin->data + st->biglen - 4); if (st->gotcrc != st->realcrc) { bombout(("Incorrect CRC received on packet")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } @@ -1536,7 +1540,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data, st->pktin->body - 1, st->pktin->length + 1, &decompblk, &decomplen)) { bombout(("Zlib decompression encountered invalid data")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } @@ -1800,7 +1804,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data, break; if (st->packetlen >= OUR_V2_PACKETLIMIT) { bombout(("No valid incoming packet found")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } } @@ -1839,7 +1843,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data, if (st->len < 0 || st->len > OUR_V2_PACKETLIMIT || st->len % st->cipherblk != 0) { bombout(("Incoming packet length field was garbled")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } @@ -1873,7 +1877,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data, && !ssh->scmac->verify(ssh->sc_mac_ctx, st->pktin->data, st->len + 4, st->incoming_sequence)) { bombout(("Incorrect MAC received on packet")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } @@ -1912,7 +1916,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data, if (st->len < 0 || st->len > OUR_V2_PACKETLIMIT || (st->len + 4) % st->cipherblk != 0) { bombout(("Incoming packet was garbled on decryption")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } @@ -1952,7 +1956,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data, && !ssh->scmac->verify(ssh->sc_mac_ctx, st->pktin->data, st->len + 4, st->incoming_sequence)) { bombout(("Incorrect MAC received on packet")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } } @@ -1960,7 +1964,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data, st->pad = st->pktin->data[4]; if (st->pad < 4 || st->len - st->pad < 1) { bombout(("Invalid padding length on received packet")); - ssh_free_packet(st->pktin); + ssh_unref_packet(st->pktin); crStop(NULL); } /* @@ -2152,7 +2156,7 @@ static void s_wrpkt(Ssh ssh, struct Packet *pkt) backlog = s_write(ssh, pkt->data + offset, len); if (backlog > SSH_MAX_BACKLOG) ssh_throttle_all(ssh, 1, backlog); - ssh_free_packet(pkt); + ssh_unref_packet(pkt); } static void s_wrpkt_defer(Ssh ssh, struct Packet *pkt) @@ -2168,7 +2172,7 @@ static void s_wrpkt_defer(Ssh ssh, struct Packet *pkt) memcpy(ssh->deferred_send_data + ssh->deferred_len, pkt->data + offset, len); ssh->deferred_len += len; - ssh_free_packet(pkt); + ssh_unref_packet(pkt); } /* @@ -2562,7 +2566,7 @@ static void ssh2_pkt_send_noqueue(Ssh ssh, struct Packet *pkt) ssh->outgoing_data_size > ssh->max_data_size) do_ssh2_transport(ssh, "too much data sent", -1, NULL); - ssh_free_packet(pkt); + ssh_unref_packet(pkt); } /* @@ -2592,7 +2596,7 @@ static void ssh2_pkt_defer_noqueue(Ssh ssh, struct Packet *pkt, int noignore) memcpy(ssh->deferred_send_data + ssh->deferred_len, pkt->body, len); ssh->deferred_len += len; ssh->deferred_data_size += pkt->encrypted_len; - ssh_free_packet(pkt); + ssh_unref_packet(pkt); } /* @@ -3459,7 +3463,7 @@ static void ssh_process_incoming_data(Ssh ssh, pktin = ssh->s_rdpkt(ssh, data, datalen); if (pktin) { ssh->protocol(ssh, NULL, 0, pktin); - ssh_free_packet(pktin); + ssh_unref_packet(pktin); } } @@ -11925,7 +11929,7 @@ static struct Packet *ssh2_gss_authpacket(Ssh ssh, Ssh_gss_ctx gss_ctx, buf.value = (char *)p->data + micoffset; buf.length = p->length - micoffset; ssh->gsslib->get_mic(ssh->gsslib, gss_ctx, &buf, &mic); - ssh_free_packet(p); + ssh_unref_packet(p); /* Now we can build the real packet */ if (strcmp(authtype, "gssapi-with-mic") == 0) { @@ -12404,7 +12408,7 @@ static void ssh_free(void *handle) sfree(ssh->savedhost); while (ssh->queuelen-- > 0) - ssh_free_packet(ssh->queue[ssh->queuelen]); + ssh_unref_packet(ssh->queue[ssh->queuelen]); sfree(ssh->queue); while (ssh->qhead) {