From f15866be625193548fd549c710ccbd4b181b7ed8 Mon Sep 17 00:00:00 2001 From: "julien.pierre.bugs%sun.com" Date: Mon, 25 Jul 2005 20:39:14 +0000 Subject: [PATCH] Fix for bug 292151 . Prevent strsclnt from starting threads for each connection. Allow specifying a ratio of full handshakes . r=nelson --- security/nss/cmd/strsclnt/strsclnt.c | 314 ++++++++++++++++++--------- 1 file changed, 206 insertions(+), 108 deletions(-) diff --git a/security/nss/cmd/strsclnt/strsclnt.c b/security/nss/cmd/strsclnt/strsclnt.c index 35ed739d671..621e2a93257 100644 --- a/security/nss/cmd/strsclnt/strsclnt.c +++ b/security/nss/cmd/strsclnt/strsclnt.c @@ -139,6 +139,8 @@ int ssl3CipherSuites[] = { 0 }; +#define NO_FULLHS_PERCENTAGE -1 + /* This global string is so that client main can see * which ciphers to use. */ @@ -148,8 +150,26 @@ static const char *cipherString; static int certsTested; static int MakeCertOK; static int NoReuse; +static int fullhs = NO_FULLHS_PERCENTAGE; /* percentage of full handshakes to + ** perform */ +static PRInt32 globalconid = 0; /* atomically set */ +static int total_connections; /* total number of connections to perform */ +static int total_connections_rounded_down_to_hundreds; +static int total_connections_modulo_100; + static PRBool NoDelay; static PRBool QuitOnTimeout = PR_FALSE; +static PRBool ThrottleUp = PR_FALSE; + +static PRLock * threadLock; /* protects the global variables below */ +static PRTime lastConnectFailure; +static PRTime lastConnectSuccess; +static PRTime lastThrottleUp; +static int remaining_connections; /* number of connections left */ +static int active_threads = 8; /* number of threads currently trying to + ** connect */ +static int numUsed; +/* end of variables protected by threadLock */ static SSL3Statistics * ssl3stats; @@ -181,15 +201,19 @@ Usage(const char *progName) { fprintf(stderr, "Usage: %s [-n nickname] [-p port] [-d dbdir] [-c connections]\n" - " [-3DNTovq] [-2 filename]\n" + " [-3DTovq] [-2 filename] [-P fullhandshakespercentage | -N]\n" " [-w dbpasswd] [-C cipher(s)] [-t threads] hostname\n" " where -v means verbose\n" " -o flag is interpreted as follows:\n" " 1 -o means override the result of server certificate validation.\n" " 2 -o's mean skip server certificate validation altogether.\n" + " -3 means disable SSL3\n" " -D means no TCP delays\n" " -q means quit when server gone (timeout rather than retry forever)\n" - " -N means no session reuse\n", + " -N means no session reuse\n" + " -P means do a specified percentage of full handshakes (0-100)\n" + " -T means disable TLS\n" + " -U means enable throttling up threads\n", progName); exit(1); } @@ -355,26 +379,18 @@ printSecurityInfo(PRFileDesc *fd) typedef int startFn(void *a, void *b, int c); -PRLock * threadLock; -PRCondVar * threadStartQ; -PRCondVar * threadEndQ; -int numUsed; -int numRunning; -PRInt32 numConnected; -int max_threads = 8; /* default much less than max. */ - -typedef enum { rs_idle = 0, rs_running = 1, rs_zombie = 2 } runState; +static PRInt32 numConnected; +static int max_threads; /* peak threads allowed */ typedef struct perThreadStr { void * a; void * b; - int c; + int tid; int rv; startFn * startFunc; PRThread * prThread; PRBool inUse; - runState running; } perThread; perThread threads[MAX_THREADS]; @@ -383,25 +399,61 @@ void thread_wrapper(void * arg) { perThread * slot = (perThread *)arg; + PRBool die = PR_FALSE; - /* wait for parent to finish launching us before proceeding. */ - PR_Lock(threadLock); - PR_Unlock(threadLock); + do { + PRBool doop = PR_FALSE; + PRBool dosleep = PR_FALSE; + PRTime now = PR_Now(); - slot->rv = (* slot->startFunc)(slot->a, slot->b, slot->c); - - /* Handle cleanup of thread here. */ - PRINTF("strsclnt: Thread in slot %d returned %d\n", - slot - threads, slot->rv); - - PR_Lock(threadLock); - slot->running = rs_idle; - --numRunning; - - /* notify the thread launcher. */ - PR_NotifyCondVar(threadStartQ); - - PR_Unlock(threadLock); + PR_Lock(threadLock); + if (! (slot->tid < active_threads)) { + /* this thread isn't supposed to be running */ + if (!ThrottleUp) { + /* we'll never need this thread again, so abort it */ + die = PR_TRUE; + } else if (remaining_connections > 0) { + /* we may still need this thread, so just sleep for 1s */ + dosleep = PR_TRUE; + /* the conditions to trigger a throttle up are : + ** 1. last PR_Connect failure must have happened more than + ** 10s ago + ** 2. last throttling up must have happened more than 0.5s ago + ** 3. there must be a more recent PR_Connect success than + ** failure + */ + if ( (now - lastConnectFailure > 10 * PR_USEC_PER_SEC) && + ( (!lastThrottleUp) || ( (now - lastThrottleUp) >= + (PR_USEC_PER_SEC/2)) ) && + (lastConnectSuccess > lastConnectFailure) ) { + /* try throttling up by one thread */ + active_threads = PR_MIN(max_threads, active_threads+1); + fprintf(stderr,"active_threads set up to %d\n", + active_threads); + lastThrottleUp = PR_MAX(now, lastThrottleUp); + } + } else { + /* no more connections left, we are done */ + die = PR_TRUE; + } + } else { + /* this thread should run */ + if (--remaining_connections >= 0) { + doop = PR_TRUE; + } else { + die = PR_TRUE; + } + } + PR_Unlock(threadLock); + if (doop) { + slot->rv = (* slot->startFunc)(slot->a, slot->b, slot->tid); + PRINTF("strsclnt: Thread in slot %d returned %d\n", + slot->tid, slot->rv); + } + if (dosleep) { + PR_Sleep(PR_SecondsToInterval(1)); + } + } while (!die); } SECStatus @@ -409,46 +461,30 @@ launch_thread( startFn * startFunc, void * a, void * b, - int c) + int tid) { perThread * slot; int i; - if (!threadStartQ) { - threadLock = PR_NewLock(); - threadStartQ = PR_NewCondVar(threadLock); - threadEndQ = PR_NewCondVar(threadLock); - } PR_Lock(threadLock); - while (numRunning >= max_threads) { - PR_WaitCondVar(threadStartQ, PR_INTERVAL_NO_TIMEOUT); - } - for (i = 0; i < numUsed; ++i) { - if (threads[i].running == rs_idle) - break; - } - if (i >= numUsed) { - if (i >= MAX_THREADS) { - /* something's really wrong here. */ - PORT_Assert(i < MAX_THREADS); - PR_Unlock(threadLock); - return SECFailure; - } - ++numUsed; - PORT_Assert(numUsed == i + 1); + + PORT_Assert(numUsed < MAX_THREADS); + if (! (numUsed < MAX_THREADS)) { + PR_Unlock(threadLock); + return SECFailure; } - slot = threads + i; + slot = &threads[numUsed++]; slot->a = a; slot->b = b; - slot->c = c; + slot->tid = tid; slot->startFunc = startFunc; slot->prThread = PR_CreateThread(PR_USER_THREAD, thread_wrapper, slot, PR_PRIORITY_NORMAL, PR_GLOBAL_THREAD, - PR_UNJOINABLE_THREAD, 0); + PR_JOINABLE_THREAD, 0); if (slot->prThread == NULL) { PR_Unlock(threadLock); printf("strsclnt: Failed to launch thread!\n"); @@ -456,37 +492,25 @@ launch_thread( } slot->inUse = 1; - slot->running = 1; - ++numRunning; PR_Unlock(threadLock); PRINTF("strsclnt: Launched thread in slot %d \n", i); return SECSuccess; } -/* Wait until numRunning == 0 */ +/* join all the threads */ int reap_threads(void) { perThread * slot; int i; - if (!threadLock) - return 0; - PR_Lock(threadLock); - while (numRunning > 0) { - PR_WaitCondVar(threadStartQ, PR_INTERVAL_NO_TIMEOUT); + for (i = 0; i < MAX_THREADS; ++i) { + if (threads[i].prThread) { + PR_JoinThread(threads[i].prThread); + threads[i].prThread = NULL; + } } - - /* Safety Sam sez: make sure count is right. */ - for (i = 0; i < numUsed; ++i) { - slot = threads + i; - if (slot->running != rs_idle) { - FPRINTF(stderr, "strsclnt: Thread in slot %d is in state %d!\n", - i, slot->running); - } - } - PR_Unlock(threadLock); return 0; } @@ -495,20 +519,18 @@ destroy_thread_data(void) { PORT_Memset(threads, 0, sizeof threads); - if (threadEndQ) { - PR_DestroyCondVar(threadEndQ); - threadEndQ = NULL; - } - if (threadStartQ) { - PR_DestroyCondVar(threadStartQ); - threadStartQ = NULL; - } if (threadLock) { PR_DestroyLock(threadLock); threadLock = NULL; } } +void +init_thread_data(void) +{ + threadLock = PR_NewLock(); +} + /************************************************************************** ** End thread management routines. **************************************************************************/ @@ -668,7 +690,7 @@ cleanup: const char request[] = {"GET /abc HTTP/1.0\r\n\r\n" }; SECStatus -handle_connection( PRFileDesc *ssl_sock, int connection) +handle_connection( PRFileDesc *ssl_sock, int tid) { int countRead = 0; PRInt32 rv; @@ -702,8 +724,9 @@ handle_connection( PRFileDesc *ssl_sock, int connection) } countRead += rv; - FPRINTF(stderr, "strsclnt: connection %d read %d bytes (%d total).\n", - connection, rv, countRead ); + FPRINTF(stderr, + "strsclnt: connection on thread %d read %d bytes (%d total).\n", + tid, rv, countRead ); } PR_Free(buf); buf = 0; @@ -711,12 +734,27 @@ handle_connection( PRFileDesc *ssl_sock, int connection) /* Caller closes the socket. */ FPRINTF(stderr, - "strsclnt: connection %d read %d bytes total. -----------------------\n", - connection, countRead); + "strsclnt: connection on thread %d read %d bytes total. ---------\n", + tid, countRead); return SECSuccess; /* success */ } +#define USE_SOCK_PEER_ID 1 + +#ifdef USE_SOCK_PEER_ID + +PRInt32 lastFullHandshakePeerID; + +SECStatus +myHandshakeCallback(PRFileDesc *socket, void *arg) +{ + PR_AtomicSet(&lastFullHandshakePeerID, (PRInt32) arg); + return SECSuccess; +} + +#endif + /* one copy of this function is launched in a separate thread for each ** connection to be made. */ @@ -724,7 +762,7 @@ int do_connects( void * a, void * b, - int connection) + int tid) { PRNetAddr * addr = (PRNetAddr *) a; PRFileDesc * model_sock = (PRFileDesc *) b; @@ -765,16 +803,26 @@ retry: prStatus = PR_Connect(tcp_sock, addr, PR_INTERVAL_NO_TIMEOUT); if (prStatus != PR_SUCCESS) { - PRErrorCode err = PR_GetError(); - if ((err == PR_CONNECT_REFUSED_ERROR) || + PRErrorCode err = PR_GetError(); /* save error code */ + if (ThrottleUp) { + PRTime now = PR_Now(); + PR_Lock(threadLock); + lastConnectFailure = PR_MAX(now, lastConnectFailure); + PR_Unlock(threadLock); + } + if ((err == PR_CONNECT_REFUSED_ERROR) || (err == PR_CONNECT_RESET_ERROR) ) { int connections = numConnected; PR_Close(tcp_sock); - if (connections > 2 && max_threads >= connections) { - max_threads = connections - 1; - fprintf(stderr,"max_threads set down to %d\n", max_threads); - } + PR_Lock(threadLock); + if (connections > 2 && active_threads >= connections) { + active_threads = connections - 1; + fprintf(stderr,"active_threads set down to %d\n", + active_threads); + } + PR_Unlock(threadLock); + if (QuitOnTimeout && sleepInterval > 40000) { fprintf(stderr, "strsclnt: Client timed out waiting for connection to server.\n"); @@ -787,6 +835,13 @@ retry: errWarn("PR_Connect"); rv = SECFailure; goto done; + } else { + if (ThrottleUp) { + PRTime now; + PR_Lock(threadLock); + lastConnectSuccess = PR_MAX(now, lastConnectSuccess); + PR_Unlock(threadLock); + } } ssl_sock = SSL_ImportFD(model_sock, tcp_sock); @@ -795,7 +850,37 @@ retry: PR_Close(tcp_sock); return SECSuccess; } - + if (fullhs != NO_FULLHS_PERCENTAGE) { +#ifdef USE_SOCK_PEER_ID + char sockPeerIDString[512]; + static PRInt32 sockPeerID = 0; /* atomically incremented */ + PRInt32 thisPeerID; +#endif + PRInt32 savid = PR_AtomicIncrement(&globalconid); + PRInt32 conid = 1 + (savid - 1) % 100; + /* don't change peer ID on the very first handshake, which is always + a full, so the session gets stored into the client cache */ + if ( (savid != 1) && + ( ( (savid <= total_connections_rounded_down_to_hundreds) && + (conid <= fullhs) ) || + (conid*100 <= total_connections_modulo_100*fullhs ) ) ) { +#ifdef USE_SOCK_PEER_ID + /* force a full handshake by changing the socket peer ID */ + thisPeerID = PR_AtomicIncrement(&sockPeerID); + } else { + /* reuse previous sockPeerID for restart handhsake */ + thisPeerID = lastFullHandshakePeerID; + } + PR_snprintf(sockPeerIDString, sizeof(sockPeerIDString), "ID%d", + thisPeerID); + SSL_SetSockPeerID(ssl_sock, sockPeerIDString); + SSL_HandshakeCallback(ssl_sock, myHandshakeCallback, (void*)thisPeerID); +#else + /* force a full handshake by setting the no cache option */ + SSL_OptionSet(ssl_sock, SSL_NO_CACHE, 1); + } +#endif + } rv = SSL_ResetHandshake(ssl_sock, /* asServer */ 0); if (rv != SECSuccess) { errWarn("SSL_ResetHandshake"); @@ -805,9 +890,9 @@ retry: PR_AtomicIncrement(&numConnected); if (bigBuf.data != NULL) { - result = handle_fdx_connection( ssl_sock, connection); + result = handle_fdx_connection( ssl_sock, tid); } else { - result = handle_connection( ssl_sock, connection); + result = handle_connection( ssl_sock, tid); } PR_AtomicDecrement(&numConnected); @@ -918,7 +1003,7 @@ StressClient_GetClientAuthData(void * arg, PR_Unlock(Cert_And_Key->lock); if (!*pRetCert || !*pRetKey) { /* one or both of them failed to copy. Either the source was NULL, or there was - an out of memory condition. Free any allocated copy and fail */ + ** an out of memory condition. Free any allocated copy and fail */ if (*pRetCert) { CERT_DestroyCertificate(*pRetCert); *pRetCert = NULL; @@ -1126,20 +1211,26 @@ client_main( /* end of ssl configuration. */ - i = 1; + init_thread_data(); + + remaining_connections = total_connections = connections; + total_connections_modulo_100 = total_connections % 100; + total_connections_rounded_down_to_hundreds = + total_connections - total_connections_modulo_100; + if (!NoReuse) { - rv = launch_thread(do_connects, &addr, model_sock, i); - --connections; - ++i; + remaining_connections = 1; + rv = launch_thread(do_connects, &addr, model_sock, 0); /* wait for the first connection to terminate, then launch the rest. */ reap_threads(); + remaining_connections = total_connections - 1 ; } - if (connections > 0) { - /* Start up the connections */ - do { + if (remaining_connections > 0) { + active_threads = PR_MIN(active_threads, remaining_connections); + /* Start up the threads */ + for (i=0;i 0); + } reap_threads(); } destroy_thread_data(); @@ -1214,7 +1305,7 @@ main(int argc, char **argv) progName = progName ? progName + 1 : tmp; - optstate = PL_CreateOptState(argc, argv, "2:3C:DNTc:d:n:op:qt:vw:"); + optstate = PL_CreateOptState(argc, argv, "2:3C:DNP:TUc:d:n:op:qt:vw:"); while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) { switch(optstate->option) { @@ -1227,8 +1318,12 @@ main(int argc, char **argv) case 'D': NoDelay = PR_TRUE; break; case 'N': NoReuse = 1; break; + + case 'P': fullhs = PORT_Atoi(optstate->value); break; case 'T': disableTLS = PR_TRUE; break; + + case 'U': ThrottleUp = PR_TRUE; break; case 'c': connections = PORT_Atoi(optstate->value); break; @@ -1245,7 +1340,7 @@ main(int argc, char **argv) case 't': tmpInt = PORT_Atoi(optstate->value); if (tmpInt > 0 && tmpInt < MAX_THREADS) - max_threads = tmpInt; + max_threads = active_threads = tmpInt; break; case 'v': verbose++; break; @@ -1269,6 +1364,9 @@ main(int argc, char **argv) if (!hostName || status == PL_OPT_BAD) Usage(progName); + if (fullhs!= NO_FULLHS_PERCENTAGE && (fullhs < 0 || fullhs>100 || NoReuse) ) + Usage(progName); + if (port == 0) Usage(progName);