Fix for bug 204015 - make strsclnt reuse token after it is unplugged . r=nelsonb

This commit is contained in:
jpierre%netscape.com 2003-05-15 17:09:19 +00:00
Родитель afeedbf77b
Коммит 903e186861
1 изменённых файлов: 209 добавлений и 50 удалений

Просмотреть файл

@ -175,10 +175,11 @@ static void
Usage(const char *progName) Usage(const char *progName)
{ {
fprintf(stderr, fprintf(stderr,
"Usage: %s [-n rsa_nickname] [-p port] [-d dbdir] [-c connections]\n" "Usage: %s [-n nickname] [-p port] [-d dbdir] [-c connections]\n"
" [-DNvq] [-f fortezza_nickname] [-2 filename]\n" " [-DNovq] [-2 filename]\n"
" [-w dbpasswd] [-C cipher(s)] [-t threads] hostname\n" " [-w dbpasswd] [-C cipher(s)] [-t threads] hostname\n"
" where -v means verbose\n" " where -v means verbose\n"
" -o means override server certificate validation\n"
" -D means no TCP delays\n" " -D means no TCP delays\n"
" -q means quit when server gone (timeout rather than retry forever)\n" " -q means quit when server gone (timeout rather than retry forever)\n"
" -N means no session reuse\n", " -N means no session reuse\n",
@ -230,6 +231,13 @@ disableAllSSLCiphers(void)
} }
} }
static SECStatus
myGoodSSLAuthCertificate(void *arg, PRFileDesc *fd, PRBool checkSig,
PRBool isServer)
{
return SECSuccess;
}
/* This invokes the "default" AuthCert handler in libssl. /* This invokes the "default" AuthCert handler in libssl.
** The only reason to use this one is that it prints out info as it goes. ** The only reason to use this one is that it prints out info as it goes.
*/ */
@ -250,7 +258,7 @@ mySSLAuthCertificate(void *arg, PRFileDesc *fd, PRBool checkSig,
++certsTested; ++certsTested;
if (rv == SECSuccess) { if (rv == SECSuccess) {
fputs("strsclnt: -- SSL: Server Certificate Validated.\n", stderr); fputs("strsclnt: -- SSL: Server Certificate Validated.\n", stderr);
} }
CERT_DestroyCertificate(peerCert); CERT_DestroyCertificate(peerCert);
/* error, if any, will be displayed by the Bad Cert Handler. */ /* error, if any, will be displayed by the Bad Cert Handler. */
return rv; return rv;
@ -830,14 +838,187 @@ getIPAddress(const char * hostName)
return rv; return rv;
} }
typedef struct {
PRLock* lock;
char* nickname;
CERTCertificate* cert;
SECKEYPrivateKey* key;
char* password;
} cert_and_key;
PRBool FindCertAndKey(cert_and_key* Cert_And_Key)
{
if ( (NULL == Cert_And_Key->nickname) || (0 == strcmp(Cert_And_Key->nickname,"none"))) {
return PR_TRUE;
}
Cert_And_Key->cert = CERT_FindUserCertByUsage(CERT_GetDefaultCertDB(),
Cert_And_Key->nickname, certUsageSSLClient,
PR_FALSE, Cert_And_Key->password);
if (Cert_And_Key->cert) {
Cert_And_Key->key = PK11_FindKeyByAnyCert(Cert_And_Key->cert, Cert_And_Key->password);
}
if (Cert_And_Key->cert && Cert_And_Key->key) {
return PR_TRUE;
} else {
return PR_FALSE;
}
}
PRBool LoggedIn(CERTCertificate* cert, SECKEYPrivateKey* key)
{
if ( (cert->slot) && (key->pkcs11Slot) &&
(PR_TRUE == PK11_IsLoggedIn(cert->slot, NULL)) &&
(PR_TRUE == PK11_IsLoggedIn(key->pkcs11Slot, NULL)) ) {
return PR_TRUE;
}
return PR_FALSE;
}
SECStatus
StressClient_GetClientAuthData(void * arg,
PRFileDesc * socket,
struct CERTDistNamesStr * caNames,
struct CERTCertificateStr ** pRetCert,
struct SECKEYPrivateKeyStr **pRetKey)
{
cert_and_key* Cert_And_Key = (cert_and_key*) arg;
if (!pRetCert || !pRetKey) {
/* bad pointers, can't return a cert or key */
return SECFailure;
}
*pRetCert = NULL;
*pRetKey = NULL;
if (Cert_And_Key && Cert_And_Key->nickname) {
while (PR_TRUE) {
if (Cert_And_Key && Cert_And_Key->lock) {
int timeout = 0;
SECStatus rv = SECSuccess;
PR_Lock(Cert_And_Key->lock);
if (Cert_And_Key->cert) {
*pRetCert = CERT_DupCertificate(Cert_And_Key->cert);
}
if (Cert_And_Key->key) {
*pRetKey = SECKEY_CopyPrivateKey(Cert_And_Key->key);
}
PR_Unlock(Cert_And_Key->lock);
if (!*pRetCert || !*pRetKey) {
/* one or both of them failed to copy. Either the source was NULL, or there was
an out of memory condition. Free any allocated copy and fail */
if (*pRetCert) {
CERT_DestroyCertificate(*pRetCert);
*pRetCert = NULL;
}
if (*pRetKey) {
SECKEY_DestroyPrivateKey(*pRetKey);
*pRetKey = NULL;
}
break;
}
/* now check if those objects are valid */
if ( PR_FALSE == LoggedIn(*pRetCert, *pRetKey) ) {
/* token is no longer logged in, it was removed */
int timeout = 0;
CERTCertificate* oldcert = NULL;
SECKEYPrivateKey* oldkey = NULL;
/* first, delete and clear our invalid local objects */
CERT_DestroyCertificate(*pRetCert);
SECKEY_DestroyPrivateKey(*pRetKey);
*pRetCert = NULL;
*pRetKey = NULL;
PR_Lock(Cert_And_Key->lock);
/* check if another thread already logged back in */
if (PR_TRUE == LoggedIn(Cert_And_Key->cert, Cert_And_Key->key)) {
/* yes : try again */
PR_Unlock(Cert_And_Key->lock);
continue;
}
/* this is the thread to retry */
CERT_DestroyCertificate(Cert_And_Key->cert);
SECKEY_DestroyPrivateKey(Cert_And_Key->key);
Cert_And_Key->cert = NULL;
Cert_And_Key->key = NULL;
/* now look up the cert and key again */
while (PR_FALSE == FindCertAndKey(Cert_And_Key) ) {
PR_Sleep(PR_SecondsToInterval(1));
timeout++;
if (timeout>=60) {
printf("\nToken pulled and not reinserted early enough : aborting.\n");
exit(1);
}
}
PR_Unlock(Cert_And_Key->lock);
continue;
/* try again to reduce code size */
}
return SECSuccess;
}
}
*pRetCert = NULL;
*pRetKey = NULL;
return SECFailure;
} else {
/* no cert configured, automatically find the right cert. */
CERTCertificate * cert = NULL;
SECKEYPrivateKey * privkey = NULL;
CERTCertNicknames * names;
int i;
void * proto_win;
SECStatus rv = SECFailure;
if (Cert_And_Key) {
proto_win = Cert_And_Key->password;
}
names = CERT_GetCertNicknames(CERT_GetDefaultCertDB(),
SEC_CERT_NICKNAMES_USER, proto_win);
if (names != NULL) {
for (i = 0; i < names->numnicknames; i++) {
cert = CERT_FindUserCertByUsage(CERT_GetDefaultCertDB(),
names->nicknames[i], certUsageSSLClient,
PR_FALSE, proto_win);
if ( !cert )
continue;
/* Only check unexpired certs */
if (CERT_CheckCertValidTimes(cert, PR_Now(), PR_TRUE) !=
secCertTimeValid ) {
CERT_DestroyCertificate(cert);
continue;
}
rv = NSS_CmpCertChainWCANames(cert, caNames);
if ( rv == SECSuccess ) {
privkey = PK11_FindKeyByAnyCert(cert, proto_win);
if ( privkey )
break;
}
rv = SECFailure;
CERT_DestroyCertificate(cert);
}
CERT_FreeNicknames(names);
}
if (rv == SECSuccess) {
*pRetCert = cert;
*pRetKey = privkey;
}
return rv;
}
}
void void
client_main( client_main(
unsigned short port, unsigned short port,
int connections, int connections,
SECKEYPrivateKey ** privKey, cert_and_key* Cert_And_Key,
CERTCertificate ** cert, const char * hostName)
const char * hostName,
char * nickName)
{ {
PRFileDesc *model_sock = NULL; PRFileDesc *model_sock = NULL;
int i; int i;
@ -920,11 +1101,10 @@ client_main(
SSL_SetURL(model_sock, hostName); SSL_SetURL(model_sock, hostName);
SSL_AuthCertificateHook(model_sock, mySSLAuthCertificate, SSL_AuthCertificateHook(model_sock, mySSLAuthCertificate,
(void *)CERT_GetDefaultCertDB()); (void *)CERT_GetDefaultCertDB());
SSL_BadCertHook(model_sock, myBadCertHandler, NULL); SSL_BadCertHook(model_sock, myBadCertHandler, NULL);
SSL_GetClientAuthDataHook(model_sock, NSS_GetClientAuthData, nickName); SSL_GetClientAuthDataHook(model_sock, StressClient_GetClientAuthData, (void*)Cert_And_Key);
/* I'm not going to set the HandshakeCallback function. */ /* I'm not going to set the HandshakeCallback function. */
@ -949,7 +1129,6 @@ client_main(
destroy_thread_data(); destroy_thread_data();
PR_Close(model_sock); PR_Close(model_sock);
} }
SECStatus SECStatus
@ -995,15 +1174,12 @@ int
main(int argc, char **argv) main(int argc, char **argv)
{ {
const char * dir = "."; const char * dir = ".";
char * fNickName = NULL;
const char * fileName = NULL; const char * fileName = NULL;
char * hostName = NULL; char * hostName = NULL;
char * nickName = NULL; char * nickName = NULL;
char * progName = NULL; char * progName = NULL;
char * tmp = NULL; char * tmp = NULL;
char * passwd = NULL; char * passwd = NULL;
CERTCertificate * cert [kt_kea_size] = { NULL };
SECKEYPrivateKey * privKey[kt_kea_size] = { NULL };
int connections = 1; int connections = 1;
int exitVal; int exitVal;
int tmpInt; int tmpInt;
@ -1011,6 +1187,7 @@ main(int argc, char **argv)
SECStatus rv; SECStatus rv;
PLOptState * optstate; PLOptState * optstate;
PLOptStatus status; PLOptStatus status;
cert_and_key Cert_And_Key;
/* Call the NSPR initialization routines */ /* Call the NSPR initialization routines */
PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1); PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1);
@ -1021,7 +1198,7 @@ main(int argc, char **argv)
progName = progName ? progName + 1 : tmp; progName = progName ? progName + 1 : tmp;
optstate = PL_CreateOptState(argc, argv, "2:C:DNc:d:f:n:op:t:vqw:"); optstate = PL_CreateOptState(argc, argv, "2:C:DNc:d:n:op:qt:vw:");
while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) { while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) {
switch(optstate->option) { switch(optstate->option) {
@ -1037,8 +1214,6 @@ main(int argc, char **argv)
case 'd': dir = optstate->value; break; case 'd': dir = optstate->value; break;
case 'f': fNickName = PL_strdup(optstate->value); break;
case 'n': nickName = PL_strdup(optstate->value); break; case 'n': nickName = PL_strdup(optstate->value); break;
case 'o': MakeCertOK = 1; break; case 'o': MakeCertOK = 1; break;
@ -1094,53 +1269,37 @@ main(int argc, char **argv)
exit(1); exit(1);
} }
ssl3stats = SSL_GetStatistics(); ssl3stats = SSL_GetStatistics();
Cert_And_Key.lock = PR_NewLock();
Cert_And_Key.nickname = nickName;
Cert_And_Key.password = passwd;
Cert_And_Key.cert = NULL;
Cert_And_Key.key = NULL;
if (nickName && strcmp(nickName, "none")) { if (PR_FALSE == FindCertAndKey(&Cert_And_Key)) {
cert[kt_rsa] = PK11_FindCertFromNickname(nickName, passwd); if (Cert_And_Key.cert == NULL) {
if (cert[kt_rsa] == NULL) { fprintf(stderr, "strsclnt: Can't find certificate %s\n", Cert_And_Key.nickname);
fprintf(stderr, "strsclnt: Can't find certificate %s\n", nickName);
exit(1); exit(1);
} }
privKey[kt_rsa] = PK11_FindKeyByAnyCert(cert[kt_rsa], passwd); if (Cert_And_Key.key == NULL) {
if (privKey[kt_rsa] == NULL) {
fprintf(stderr, "strsclnt: Can't find Private Key for cert %s\n", fprintf(stderr, "strsclnt: Can't find Private Key for cert %s\n",
nickName); Cert_And_Key.nickname);
exit(1); exit(1);
} }
} }
if (fNickName) {
cert[kt_fortezza] = PK11_FindCertFromNickname(fNickName, passwd);
if (cert[kt_fortezza] == NULL) {
fprintf(stderr, "strsclnt: Can't find certificate %s\n", fNickName);
exit(1);
}
privKey[kt_fortezza] = PK11_FindKeyByAnyCert(cert[kt_fortezza], passwd); client_main(port, connections, &Cert_And_Key, hostName);
if (privKey[kt_fortezza] == NULL) {
fprintf(stderr, "strsclnt: Can't find Private Key for cert %s\n",
fNickName);
exit(1);
}
}
client_main(port, connections, privKey, cert, hostName, nickName);
/* clean up */ /* clean up */
if (cert[kt_rsa]) { if (Cert_And_Key.cert) {
CERT_DestroyCertificate(cert[kt_rsa]); CERT_DestroyCertificate(Cert_And_Key.cert);
} }
if (cert[kt_fortezza]) { if (Cert_And_Key.key) {
CERT_DestroyCertificate(cert[kt_fortezza]); SECKEY_DestroyPrivateKey(Cert_And_Key.key);
}
if (privKey[kt_rsa]) {
SECKEY_DestroyPrivateKey(privKey[kt_rsa]);
}
if (privKey[kt_fortezza]) {
SECKEY_DestroyPrivateKey(privKey[kt_fortezza]);
} }
PR_DestroyLock(Cert_And_Key.lock);
/* some final stats. */ /* some final stats. */
if (ssl3stats->hsh_sid_cache_hits + ssl3stats->hsh_sid_cache_misses + if (ssl3stats->hsh_sid_cache_hits + ssl3stats->hsh_sid_cache_misses +