diff --git a/ssh.h b/ssh.h index 02c4a1ae..6dbe31bf 100644 --- a/ssh.h +++ b/ssh.h @@ -535,11 +535,10 @@ struct ssh_rsa_kex_extra { struct RSAKey *ssh_rsakex_newkey(ptrlen data); void ssh_rsakex_freekey(struct RSAKey *key); int ssh_rsakex_klen(struct RSAKey *key); -void ssh_rsakex_encrypt(const struct ssh_hashalg *h, - unsigned char *in, int inlen, - unsigned char *out, int outlen, struct RSAKey *key); -mp_int *ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext, - struct RSAKey *rsa); +strbuf *ssh_rsakex_encrypt( + struct RSAKey *key, const struct ssh_hashalg *h, ptrlen plaintext); +mp_int *ssh_rsakex_decrypt( + struct RSAKey *key, const struct ssh_hashalg *h, ptrlen ciphertext); /* * SSH2 ECDH key exchange functions diff --git a/ssh2kex-client.c b/ssh2kex-client.c index f2e5dc13..15ab9eb7 100644 --- a/ssh2kex-client.c +++ b/ssh2kex-client.c @@ -557,9 +557,7 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) int klen = ssh_rsakex_klen(s->rsa_kex_key); int nbits = klen - (2*s->kex_alg->hash->hlen*8 + 49); int i, byte = 0; - strbuf *buf; - unsigned char *outstr; - int outstrlen; + strbuf *buf, *outstr; s->K = mp_power_2(nbits - 1); @@ -579,22 +577,19 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) /* * Encrypt it with the given RSA key. */ - outstrlen = (klen + 7) / 8; - outstr = snewn(outstrlen, unsigned char); - ssh_rsakex_encrypt(s->kex_alg->hash, buf->u, buf->len, - outstr, outstrlen, s->rsa_kex_key); + outstr = ssh_rsakex_encrypt(s->rsa_kex_key, s->kex_alg->hash, + ptrlen_from_strbuf(buf)); /* * And send it off in a return packet. */ pktout = ssh_bpp_new_pktout(s->ppl.bpp, SSH2_MSG_KEXRSA_SECRET); - put_string(pktout, outstr, outstrlen); + put_stringpl(pktout, ptrlen_from_strbuf(outstr)); pq_push(s->ppl.out_pq, pktout); - put_string(s->exhash, outstr, outstrlen); + put_stringsb(s->exhash, outstr); /* frees outstr */ strbuf_free(buf); - sfree(outstr); } ssh_rsakex_freekey(s->rsa_kex_key); diff --git a/ssh2kex-server.c b/ssh2kex-server.c index a5c3fb5c..5d8cd410 100644 --- a/ssh2kex-server.c +++ b/ssh2kex-server.c @@ -279,7 +279,7 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) ptrlen encrypted_secret = get_string(pktin); put_stringpl(s->exhash, encrypted_secret); s->K = ssh_rsakex_decrypt( - s->kex_alg->hash, encrypted_secret, s->rsa_kex_key); + s->rsa_kex_key, s->kex_alg->hash, encrypted_secret); } if (!s->K) { diff --git a/sshrsa.c b/sshrsa.c index 8c768646..a1abcfae 100644 --- a/sshrsa.c +++ b/sshrsa.c @@ -743,9 +743,8 @@ static void oaep_mask(const struct ssh_hashalg *h, void *seed, int seedlen, } } -void ssh_rsakex_encrypt(const struct ssh_hashalg *h, - unsigned char *in, int inlen, - unsigned char *out, int outlen, struct RSAKey *rsa) +strbuf *ssh_rsakex_encrypt( + struct RSAKey *rsa, const struct ssh_hashalg *h, ptrlen in) { mp_int *b1, *b2; int k, i; @@ -783,10 +782,12 @@ void ssh_rsakex_encrypt(const struct ssh_hashalg *h, k = (7 + mp_get_nbits(rsa->modulus)) / 8; /* The length of the input data must be at most k - 2hLen - 2. */ - assert(inlen > 0 && inlen <= k - 2*HLEN - 2); + assert(in.len > 0 && in.len <= k - 2*HLEN - 2); /* The length of the output data wants to be precisely k. */ - assert(outlen == k); + strbuf *toret = strbuf_new(); + int outlen = k; + unsigned char *out = strbuf_append(toret, outlen); /* * Now perform EME-OAEP encoding. First set up all the unmasked @@ -806,8 +807,8 @@ void ssh_rsakex_encrypt(const struct ssh_hashalg *h, /* A bunch of zero octets */ memset(out + 2*HLEN + 1, 0, outlen - (2*HLEN + 1)); /* A single 1 octet, followed by the input message data. */ - out[outlen - inlen - 1] = 1; - memcpy(out + outlen - inlen, in, inlen); + out[outlen - in.len - 1] = 1; + memcpy(out + outlen - in.len, in.ptr, in.len); /* * Now use the seed data to mask the block DB. @@ -835,10 +836,11 @@ void ssh_rsakex_encrypt(const struct ssh_hashalg *h, /* * And we're done. */ + return toret; } -mp_int *ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext, - struct RSAKey *rsa) +mp_int *ssh_rsakex_decrypt( + struct RSAKey *rsa, const struct ssh_hashalg *h, ptrlen ciphertext) { mp_int *b1, *b2; int outlen, i;