BN_BLINDING multi-threading fix.

Submitted by: Emilia Kasper (Google)
diff --git a/CHANGES b/CHANGES
index d47e9b9..1bcd6f3 100644
--- a/CHANGES
+++ b/CHANGES
@@ -461,6 +461,16 @@
 
  Changes between 1.0.0e and 1.0.0f [xx XXX xxxx]
 
+  *) Fix handling of BN_BLINDING: now BN_BLINDING_invert_ex (rather than
+     BN_BLINDING_invert_ex) calls BN_BLINDING_update, ensuring that concurrent
+     threads won't reuse the same blinding coefficients.
+
+     This also avoids the need to obtain the CRYPTO_LOCK_RSA_BLINDING
+     lock to call BN_BLINDING_invert_ex, and avoids one use of
+     BN_BLINDING_update for each BN_BLINDING structure (previously,
+     the last update always remained unused).
+     [Emilia Käsper (Google)]
+
   *) In ssl3_clear, preserve s3->init_extra along with s3->rbuf.
      [Bob Buckholz (Google)]
 
@@ -1371,8 +1381,15 @@
   
  Changes between 0.9.8r and 0.9.8s [xx XXX xxxx]
 
-  *) In ssl3_clear, preserve s3->init_extra along with s3->rbuf.
-     [Bob Buckholz (Google)]
+  *) Fix handling of BN_BLINDING: now BN_BLINDING_invert_ex (rather than
+     BN_BLINDING_invert_ex) calls BN_BLINDING_update, ensuring that concurrent
+     threads won't reuse the same blinding coefficients.
+
+     This also avoids the need to obtain the CRYPTO_LOCK_RSA_BLINDING
+     lock to call BN_BLINDING_invert_ex, and avoids one use of
+     BN_BLINDING_update for each BN_BLINDING structure (previously,
+     the last update always remained unused).
+     [Emilia Käsper (Google)]
 
   *) Fix SSL memory handling for (EC)DH ciphersuites, in particular
      for multi-threaded use of ECDH.
diff --git a/crypto/bn/bn_blind.c b/crypto/bn/bn_blind.c
index 2dc677c..c1ce161 100644
--- a/crypto/bn/bn_blind.c
+++ b/crypto/bn/bn_blind.c
@@ -128,7 +128,7 @@
 				  * used only by crypto/rsa/rsa_eay.c, rsa_lib.c */
 #endif
 	CRYPTO_THREADID tid;
-	unsigned int  counter;
+	int counter;
 	unsigned long flags;
 	BN_MONT_CTX *m_ctx;
 	int (*bn_mod_exp)(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
@@ -162,7 +162,10 @@
 	if (BN_get_flags(mod, BN_FLG_CONSTTIME) != 0)
 		BN_set_flags(ret->mod, BN_FLG_CONSTTIME);
 
-	ret->counter = BN_BLINDING_COUNTER;
+	/* Set the counter to the special value -1
+	 * to indicate that this is never-used fresh blinding
+	 * that does not need updating before first use. */
+	ret->counter = -1;
 	CRYPTO_THREADID_current(&ret->tid);
 	return(ret);
 err:
@@ -192,7 +195,10 @@
 		goto err;
 		}
 
-	if (--(b->counter) == 0 && b->e != NULL &&
+	if (b->counter == -1)
+		b->counter = 0;
+
+	if (++b->counter == BN_BLINDING_COUNTER && b->e != NULL &&
 		!(b->flags & BN_BLINDING_NO_RECREATE))
 		{
 		/* re-create blinding parameters */
@@ -207,8 +213,8 @@
 
 	ret=1;
 err:
-	if (b->counter == 0)
-		b->counter = BN_BLINDING_COUNTER;
+	if (b->counter == BN_BLINDING_COUNTER)
+		b->counter = 0;
 	return(ret);
 	}
 
@@ -229,6 +235,12 @@
 		return(0);
 		}
 
+	if (b->counter == -1)
+		/* Fresh blinding, doesn't need updating. */
+		b->counter = 0;
+	else if (!BN_BLINDING_update(b,ctx))
+		return(0);
+
 	if (r != NULL)
 		{
 		if (!BN_copy(r, b->Ai)) ret=0;
@@ -249,22 +261,19 @@
 	int ret;
 
 	bn_check_top(n);
-	if ((b->A == NULL) || (b->Ai == NULL))
-		{
-		BNerr(BN_F_BN_BLINDING_INVERT_EX,BN_R_NOT_INITIALIZED);
-		return(0);
-		}
 
 	if (r != NULL)
 		ret = BN_mod_mul(n, n, r, b->mod, ctx);
 	else
-		ret = BN_mod_mul(n, n, b->Ai, b->mod, ctx);
-
-	if (ret >= 0)
 		{
-		if (!BN_BLINDING_update(b,ctx))
+		if (b->Ai == NULL)
+			{
+			BNerr(BN_F_BN_BLINDING_INVERT_EX,BN_R_NOT_INITIALIZED);
 			return(0);
+			}
+		ret = BN_mod_mul(n, n, b->Ai, b->mod, ctx);
 		}
+
 	bn_check_top(n);
 	return(ret);
 	}
diff --git a/crypto/rsa/rsa_eay.c b/crypto/rsa/rsa_eay.c
index 325efb9..16f000f 100644
--- a/crypto/rsa/rsa_eay.c
+++ b/crypto/rsa/rsa_eay.c
@@ -334,45 +334,51 @@
 	return ret;
 }
 
-static int rsa_blinding_convert(BN_BLINDING *b, int local, BIGNUM *f,
-	BIGNUM *r, BN_CTX *ctx)
-{
-	if (local)
+static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
+	BN_CTX *ctx)
+	{
+	if (unblind == NULL)
+		/* Local blinding: store the unblinding factor
+		 * in BN_BLINDING. */
 		return BN_BLINDING_convert_ex(f, NULL, b, ctx);
 	else
 		{
-		int ret;
-		CRYPTO_r_lock(CRYPTO_LOCK_RSA_BLINDING);
-		ret = BN_BLINDING_convert_ex(f, r, b, ctx);
-		CRYPTO_r_unlock(CRYPTO_LOCK_RSA_BLINDING);
-		return ret;
-		}
-}
-
-static int rsa_blinding_invert(BN_BLINDING *b, int local, BIGNUM *f,
-	BIGNUM *r, BN_CTX *ctx)
-{
-	if (local)
-		return BN_BLINDING_invert_ex(f, NULL, b, ctx);
-	else
-		{
+		/* Shared blinding: store the unblinding factor
+		 * outside BN_BLINDING. */
 		int ret;
 		CRYPTO_w_lock(CRYPTO_LOCK_RSA_BLINDING);
-		ret = BN_BLINDING_invert_ex(f, r, b, ctx);
+		ret = BN_BLINDING_convert_ex(f, unblind, b, ctx);
 		CRYPTO_w_unlock(CRYPTO_LOCK_RSA_BLINDING);
 		return ret;
 		}
-}
+	}
+
+static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
+	BN_CTX *ctx)
+	{
+	/* For local blinding, unblind is set to NULL, and BN_BLINDING_invert_ex
+	 * will use the unblinding factor stored in BN_BLINDING.
+	 * If BN_BLINDING is shared between threads, unblind must be non-null:
+	 * BN_BLINDING_invert_ex will then use the local unblinding factor,
+	 * and will only read the modulus from BN_BLINDING.
+	 * In both cases it's safe to access the blinding without a lock.
+	 */
+	return BN_BLINDING_invert_ex(f, unblind, b, ctx);
+	}
 
 /* signing */
 static int RSA_eay_private_encrypt(int flen, const unsigned char *from,
 	     unsigned char *to, RSA *rsa, int padding)
 	{
-	BIGNUM *f, *ret, *br, *res;
+	BIGNUM *f, *ret, *res;
 	int i,j,k,num=0,r= -1;
 	unsigned char *buf=NULL;
 	BN_CTX *ctx=NULL;
 	int local_blinding = 0;
+	/* Used only if the blinding structure is shared. A non-NULL unblind
+	 * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
+	 * the unblinding factor outside the blinding structure. */
+	BIGNUM *unblind = NULL;
 	BN_BLINDING *blinding = NULL;
 
 #ifdef OPENSSL_FIPS
@@ -393,7 +399,6 @@
 	if ((ctx=BN_CTX_new()) == NULL) goto err;
 	BN_CTX_start(ctx);
 	f   = BN_CTX_get(ctx);
-	br  = BN_CTX_get(ctx);
 	ret = BN_CTX_get(ctx);
 	num = BN_num_bytes(rsa->n);
 	buf = OPENSSL_malloc(num);
@@ -441,8 +446,15 @@
 		}
 	
 	if (blinding != NULL)
-		if (!rsa_blinding_convert(blinding, local_blinding, f, br, ctx))
+		{
+		if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL))
+			{
+			RSAerr(RSA_F_RSA_EAY_PRIVATE_ENCRYPT,ERR_R_MALLOC_FAILURE);
 			goto err;
+			}
+		if (!rsa_blinding_convert(blinding, f, unblind, ctx))
+			goto err;
+		}
 
 	if ( (rsa->flags & RSA_FLAG_EXT_PKEY) ||
 		((rsa->p != NULL) &&
@@ -476,7 +488,7 @@
 		}
 
 	if (blinding)
-		if (!rsa_blinding_invert(blinding, local_blinding, ret, br, ctx))
+		if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
 			goto err;
 
 	if (padding == RSA_X931_PADDING)
@@ -515,12 +527,16 @@
 static int RSA_eay_private_decrypt(int flen, const unsigned char *from,
 	     unsigned char *to, RSA *rsa, int padding)
 	{
-	BIGNUM *f, *ret, *br;
+	BIGNUM *f, *ret;
 	int j,num=0,r= -1;
 	unsigned char *p;
 	unsigned char *buf=NULL;
 	BN_CTX *ctx=NULL;
 	int local_blinding = 0;
+	/* Used only if the blinding structure is shared. A non-NULL unblind
+	 * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
+	 * the unblinding factor outside the blinding structure. */
+	BIGNUM *unblind = NULL;
 	BN_BLINDING *blinding = NULL;
 
 #ifdef OPENSSL_FIPS
@@ -541,7 +557,6 @@
 	if((ctx = BN_CTX_new()) == NULL) goto err;
 	BN_CTX_start(ctx);
 	f   = BN_CTX_get(ctx);
-	br  = BN_CTX_get(ctx);
 	ret = BN_CTX_get(ctx);
 	num = BN_num_bytes(rsa->n);
 	buf = OPENSSL_malloc(num);
@@ -579,8 +594,15 @@
 		}
 	
 	if (blinding != NULL)
-		if (!rsa_blinding_convert(blinding, local_blinding, f, br, ctx))
+		{
+		if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL))
+			{
+			RSAerr(RSA_F_RSA_EAY_PRIVATE_DECRYPT,ERR_R_MALLOC_FAILURE);
 			goto err;
+			}
+		if (!rsa_blinding_convert(blinding, f, unblind, ctx))
+			goto err;
+		}
 
 	/* do the decrypt */
 	if ( (rsa->flags & RSA_FLAG_EXT_PKEY) ||
@@ -614,7 +636,7 @@
 		}
 
 	if (blinding)
-		if (!rsa_blinding_invert(blinding, local_blinding, ret, br, ctx))
+		if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
 			goto err;
 
 	p=buf;