diff --git a/security/nss/cmd/tstclnt/tstclnt.c b/security/nss/cmd/tstclnt/tstclnt.c index fb5e786f632..836e22f88c5 100644 --- a/security/nss/cmd/tstclnt/tstclnt.c +++ b/security/nss/cmd/tstclnt/tstclnt.c @@ -46,7 +46,7 @@ #if defined(XP_UNIX) #include #else -#include "ctype.h" /* for isalpha() */ +#include /* for isalpha() */ #endif #include @@ -65,6 +65,11 @@ #include "pk11func.h" #include "plgetopt.h" +#if defined(WIN32) +#include +#include +#endif + #define PRINTF if (verbose) printf #define FPRINTF if (verbose) fprintf @@ -358,6 +363,39 @@ own_GetClientAuthData(void * arg, return NSS_GetClientAuthData(arg, socket, caNames, pRetCert, pRetKey); } +#if defined(WIN32) +void +thread_main(void * arg) +{ + PRFileDesc * ps = (PRFileDesc *)arg; + PRFileDesc * std_in = PR_GetSpecialFD(PR_StandardInput); + int wc, rc; + char buf[256]; + + { + /* Put stdin into O_BINARY mode + ** or else incoming \r\n's will become \n's. + */ + int smrv = _setmode(_fileno(stdin), _O_BINARY); + if (smrv == -1) { + fprintf(stderr, + "%s: Cannot change stdin to binary mode. Use -i option instead.\n", + progName); + /* plow ahead anyway */ + } + } + + do { + rc = PR_Read(std_in, buf, sizeof buf); + if (rc <= 0) + break; + wc = PR_Write(ps, buf, rc); + } while (wc == rc); + PR_Close(ps); +} + +#endif + int main(int argc, char **argv) { PRFileDesc * s; @@ -379,7 +417,6 @@ int main(int argc, char **argv) int disableSSL3 = 0; int disableTLS = 0; int useExportPolicy = 0; - int file_read = 0; PRSocketOptionData opt; PRNetAddr addr; PRHostEnt hp; @@ -387,7 +424,7 @@ int main(int argc, char **argv) char buf[PR_NETDB_BUF_SIZE]; PRBool useCommandLinePassword = PR_FALSE; PRBool pingServerFirst = PR_FALSE; - int error=0; + int error = 0; PLOptState *optstate; PLOptStatus optstatus; PRStatus prStatus; @@ -412,11 +449,7 @@ int main(int argc, char **argv) case 'c': cipherString = strdup(optstate->value); break; case 'h': host = strdup(optstate->value); break; -#ifdef _WINDOWS - case 'f': file_read = 1; break; -#else - case 'f': break; -#endif + case 'f': /* no longer meaningful. */ break; case 'd': certDir = strdup(optstate->value); @@ -460,11 +493,11 @@ int main(int argc, char **argv) PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1); /* set our password function */ - if ( useCommandLinePassword ) { - PK11_SetPasswordFunc(ownPasswd); - } else { + if ( useCommandLinePassword ) { + PK11_SetPasswordFunc(ownPasswd); + } else { PK11_SetPasswordFunc(SECU_GetModulePassword); - } + } /* open the cert DB, the key DB, and the secmod DB. */ rv = NSS_Init(certDir); @@ -692,11 +725,35 @@ int main(int argc, char **argv) npds = 2; std_out = PR_GetSpecialFD(PR_StandardOutput); +#if defined(WIN32) + /* PR_Poll cannot be used with stdin on Windows. (sigh). + ** But use of PR_Poll and non-blocking sockets is a major feature + ** of this program. So, we simulate a pollable stdin with a + ** TCP socket pair and a thread that reads stdin and writes to + ** that socket pair. + */ + { + PRFileDesc * fds[2]; + PRThread * thread; - if (file_read) { - pollset[1].out_flags = PR_POLL_READ; - npds=1; + int nspr_rv = PR_NewTCPSocketPair(fds); + if (nspr_rv != PR_SUCCESS) { + SECU_PrintError(progName, "PR_NewTCPSocketPair failed"); + error = 1; + goto done; } + pollset[1].fd = fds[1]; + + thread = PR_CreateThread(PR_USER_THREAD, thread_main, fds[0], + PR_PRIORITY_NORMAL, PR_GLOBAL_THREAD, + PR_UNJOINABLE_THREAD, 0); + if (!thread) { + SECU_PrintError(progName, "PR_CreateThread failed"); + error = 1; + goto done; + } + } +#endif /* ** Select on stdin and on the socket. Write data from stdin to @@ -709,21 +766,14 @@ int main(int argc, char **argv) int nb; /* num bytes read from stdin. */ pollset[0].out_flags = 0; - if (!file_read) { - pollset[1].out_flags = 0; - } + pollset[1].out_flags = 0; PRINTF("%s: about to call PR_Poll !\n", progName); - if (pollset[1].in_flags && file_read) { - filesReady = PR_Poll(pollset, npds, PR_INTERVAL_NO_WAIT); - filesReady++; - } else { - filesReady = PR_Poll(pollset, npds, PR_INTERVAL_NO_TIMEOUT); - } + filesReady = PR_Poll(pollset, npds, PR_INTERVAL_NO_TIMEOUT); if (filesReady < 0) { - SECU_PrintError(progName, "select failed"); - error=1; - goto done; + SECU_PrintError(progName, "select failed"); + error = 1; + goto done; } if (filesReady == 0) { /* shouldn't happen! */ PRINTF("%s: PR_Poll returned zero!\n", progName); @@ -733,17 +783,15 @@ int main(int argc, char **argv) if (pollset[1].in_flags) { PRINTF("%s: PR_Poll returned 0x%02x for stdin out_flags.\n", progName, pollset[1].out_flags); -#ifndef _WINDOWS } if (pollset[1].out_flags & PR_POLL_READ) { -#endif /* Read from stdin and write to socket */ nb = PR_Read(pollset[1].fd, buf, sizeof(buf)); PRINTF("%s: stdin read %d bytes\n", progName, nb); if (nb < 0) { if (PR_GetError() != PR_WOULD_BLOCK_ERROR) { SECU_PrintError(progName, "read from stdin failed"); - error=1; + error = 1; break; } } else if (nb == 0) { @@ -758,7 +806,7 @@ int main(int argc, char **argv) if (err != PR_WOULD_BLOCK_ERROR) { SECU_PrintError(progName, "write to SSL socket failed"); - error=254; + error = 254; goto done; } cc = 0; @@ -793,7 +841,7 @@ int main(int argc, char **argv) if (nb < 0) { if (PR_GetError() != PR_WOULD_BLOCK_ERROR) { SECU_PrintError(progName, "read from socket failed"); - error=1; + error = 1; goto done; } } else if (nb == 0) {