Add a reference count in 'struct Packet'.

This is another piece of not-yet-used infrastructure, which later on
will simplify my life when I start processing PacketQueues and adding
some of their packets to other PacketQueues, because this way the code
can unref every packet removed from the source queue in the same way,
whether or not the packet is actually finished with.
This commit is contained in:
Simon Tatham 2018-05-18 07:22:57 +01:00
Родитель 9d96c3eb02
Коммит cfc3386a15
1 изменённых файлов: 26 добавлений и 22 удалений

44
ssh.c
Просмотреть файл

@ -680,6 +680,7 @@ struct ssh_portfwd {
sfree((pf)->sserv), sfree((pf)->dserv)) : (void)0 ), sfree(pf) )
struct Packet {
int refcount;
long length; /* length of packet: see below */
long forcepad; /* SSH-2: force padding to at least this length */
int type; /* only used for incoming packets */
@ -752,7 +753,7 @@ static struct Packet *ssh2_gss_authpacket(Ssh ssh, Ssh_gss_ctx gss_ctx,
static void do_ssh2_transport(Ssh ssh, const void *vin, int inlen,
struct Packet *pktin);
static void ssh2_msg_unexpected(Ssh ssh, struct Packet *pktin);
static void ssh_free_packet(struct Packet *pkt);
static void ssh_unref_packet(struct Packet *pkt);
struct PacketQueueNode {
struct PacketQueueNode *next, *prev;
@ -816,7 +817,7 @@ static void pq_clear(struct PacketQueue *pq)
{
struct Packet *pkt;
while ((pkt = pq_pop(pq)) != NULL)
ssh_free_packet(pkt);
ssh_unref_packet(pkt);
}
struct rdpkt1_state_tag {
@ -1331,10 +1332,12 @@ static void c_write_str(Ssh ssh, const char *buf)
c_write(ssh, buf, strlen(buf));
}
static void ssh_free_packet(struct Packet *pkt)
static void ssh_unref_packet(struct Packet *pkt)
{
if (--pkt->refcount <= 0) {
sfree(pkt->data);
sfree(pkt);
}
}
static struct Packet *ssh_new_packet(void)
{
@ -1342,6 +1345,7 @@ static struct Packet *ssh_new_packet(void)
pkt->body = pkt->data = NULL;
pkt->maxlen = 0;
pkt->refcount = 1;
return pkt;
}
@ -1487,7 +1491,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
if (st->biglen < 0) {
bombout(("Extremely large packet length from server suggests"
" data stream corruption"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
@ -1512,7 +1516,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
if (ssh->cipher && detect_attack(ssh->crcda_ctx, st->pktin->data,
st->biglen, NULL)) {
bombout(("Network attack (CRC compensation) detected!"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
@ -1523,7 +1527,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
st->gotcrc = GET_32BIT(st->pktin->data + st->biglen - 4);
if (st->gotcrc != st->realcrc) {
bombout(("Incorrect CRC received on packet"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
@ -1536,7 +1540,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
st->pktin->body - 1, st->pktin->length + 1,
&decompblk, &decomplen)) {
bombout(("Zlib decompression encountered invalid data"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
@ -1800,7 +1804,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
break;
if (st->packetlen >= OUR_V2_PACKETLIMIT) {
bombout(("No valid incoming packet found"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
}
@ -1839,7 +1843,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
if (st->len < 0 || st->len > OUR_V2_PACKETLIMIT ||
st->len % st->cipherblk != 0) {
bombout(("Incoming packet length field was garbled"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
@ -1873,7 +1877,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
&& !ssh->scmac->verify(ssh->sc_mac_ctx, st->pktin->data,
st->len + 4, st->incoming_sequence)) {
bombout(("Incorrect MAC received on packet"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
@ -1912,7 +1916,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
if (st->len < 0 || st->len > OUR_V2_PACKETLIMIT ||
(st->len + 4) % st->cipherblk != 0) {
bombout(("Incoming packet was garbled on decryption"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
@ -1952,7 +1956,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
&& !ssh->scmac->verify(ssh->sc_mac_ctx, st->pktin->data,
st->len + 4, st->incoming_sequence)) {
bombout(("Incorrect MAC received on packet"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
}
@ -1960,7 +1964,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
st->pad = st->pktin->data[4];
if (st->pad < 4 || st->len - st->pad < 1) {
bombout(("Invalid padding length on received packet"));
ssh_free_packet(st->pktin);
ssh_unref_packet(st->pktin);
crStop(NULL);
}
/*
@ -2152,7 +2156,7 @@ static void s_wrpkt(Ssh ssh, struct Packet *pkt)
backlog = s_write(ssh, pkt->data + offset, len);
if (backlog > SSH_MAX_BACKLOG)
ssh_throttle_all(ssh, 1, backlog);
ssh_free_packet(pkt);
ssh_unref_packet(pkt);
}
static void s_wrpkt_defer(Ssh ssh, struct Packet *pkt)
@ -2168,7 +2172,7 @@ static void s_wrpkt_defer(Ssh ssh, struct Packet *pkt)
memcpy(ssh->deferred_send_data + ssh->deferred_len,
pkt->data + offset, len);
ssh->deferred_len += len;
ssh_free_packet(pkt);
ssh_unref_packet(pkt);
}
/*
@ -2562,7 +2566,7 @@ static void ssh2_pkt_send_noqueue(Ssh ssh, struct Packet *pkt)
ssh->outgoing_data_size > ssh->max_data_size)
do_ssh2_transport(ssh, "too much data sent", -1, NULL);
ssh_free_packet(pkt);
ssh_unref_packet(pkt);
}
/*
@ -2592,7 +2596,7 @@ static void ssh2_pkt_defer_noqueue(Ssh ssh, struct Packet *pkt, int noignore)
memcpy(ssh->deferred_send_data + ssh->deferred_len, pkt->body, len);
ssh->deferred_len += len;
ssh->deferred_data_size += pkt->encrypted_len;
ssh_free_packet(pkt);
ssh_unref_packet(pkt);
}
/*
@ -3459,7 +3463,7 @@ static void ssh_process_incoming_data(Ssh ssh,
pktin = ssh->s_rdpkt(ssh, data, datalen);
if (pktin) {
ssh->protocol(ssh, NULL, 0, pktin);
ssh_free_packet(pktin);
ssh_unref_packet(pktin);
}
}
@ -11925,7 +11929,7 @@ static struct Packet *ssh2_gss_authpacket(Ssh ssh, Ssh_gss_ctx gss_ctx,
buf.value = (char *)p->data + micoffset;
buf.length = p->length - micoffset;
ssh->gsslib->get_mic(ssh->gsslib, gss_ctx, &buf, &mic);
ssh_free_packet(p);
ssh_unref_packet(p);
/* Now we can build the real packet */
if (strcmp(authtype, "gssapi-with-mic") == 0) {
@ -12404,7 +12408,7 @@ static void ssh_free(void *handle)
sfree(ssh->savedhost);
while (ssh->queuelen-- > 0)
ssh_free_packet(ssh->queue[ssh->queuelen]);
ssh_unref_packet(ssh->queue[ssh->queuelen]);
sfree(ssh->queue);
while (ssh->qhead) {