[PATCH v2 02/15] Port TinyAES-128 to be thread safe.

Anton Lundin glance at acc.umu.se
Wed Dec 17 14:10:59 PST 2014


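Move the file-scope AES state (state pointer, RoundKey, Key and, when CBC
is enabled, Iv) into a struct aes_state_t. Each public entry point now
keeps that struct on its own stack and passes it to the internal helpers,
so concurrent callers no longer share mutable globals.

Since nothing survives between calls any more, the CBC functions can no
longer treat a NULL key or iv as "reuse the previous one"; callers must
pass both on every call.

Signed-off-by: Anton Lundin <glance at acc.umu.se>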
---
 src/aes.c | 251 ++++++++++++++++++++++++++++++++------------------------------
 1 file changed, 129 insertions(+), 122 deletions(-)

diff --git a/src/aes.c b/src/aes.c
index 4c7645f..4315a65 100644
--- a/src/aes.c
+++ b/src/aes.c
@@ -63,18 +63,21 @@ NOTE:   String length must be evenly divisible by 16byte (str_len % 16 == 0)
 /*****************************************************************************/
 // state - array holding the intermediate results during decryption.
 typedef uint8_t state_t[4][4];
-static state_t* state;
 
-// The array that stores the round keys.
-static uint8_t RoundKey[176];
+typedef struct aes_state_t { // per-call AES context; replaces the former file-scope globals
+	state_t* state; // the 16-byte block being transformed in place (points into the caller's output)
 
-// The Key input to the AES Program
-static const uint8_t* Key;
+	// The array that stores the round keys.
+	uint8_t RoundKey[176]; // filled by KeyExpansion(); holds expanded key material
+
+	// The Key input to the AES Program
+	const uint8_t* Key; // caller-owned key bytes, read by KeyExpansion()
 
 #if defined(CBC) && CBC
-  // Initial Vector used only for CBC mode
-  static uint8_t* Iv;
+	// Initialization vector (CBC only): the caller's iv, then the previous ciphertext block
+	uint8_t* Iv;
 #endif
+} aes_state_t;
 
 // The lookup-tables are marked const so they can be placed in read-only storage instead of RAM
 // The numbers below can be computed dynamically trading ROM for RAM - 
@@ -153,7 +156,7 @@ static uint8_t getSBoxInvert(uint8_t num)
 }
 
 // This function produces Nb(Nr+1) round keys. The round keys are used in each round to decrypt the states. 
-static void KeyExpansion(void)
+static void KeyExpansion(aes_state_t *state) /* reads state->Key, fills state->RoundKey */
 {
   uint32_t i, j, k;
   uint8_t tempa[4]; // Used for the column/row operations
@@ -161,10 +164,10 @@ static void KeyExpansion(void)
   // The first round key is the key itself.
   for(i = 0; i < Nk; ++i)
   {
-    RoundKey[(i * 4) + 0] = Key[(i * 4) + 0];
-    RoundKey[(i * 4) + 1] = Key[(i * 4) + 1];
-    RoundKey[(i * 4) + 2] = Key[(i * 4) + 2];
-    RoundKey[(i * 4) + 3] = Key[(i * 4) + 3];
+    state->RoundKey[(i * 4) + 0] = state->Key[(i * 4) + 0];
+    state->RoundKey[(i * 4) + 1] = state->Key[(i * 4) + 1];
+    state->RoundKey[(i * 4) + 2] = state->Key[(i * 4) + 2];
+    state->RoundKey[(i * 4) + 3] = state->Key[(i * 4) + 3];
   }
 
   // All other round keys are found from the previous round keys.
@@ -172,7 +175,7 @@ static void KeyExpansion(void)
   {
     for(j = 0; j < 4; ++j)
     {
-      tempa[j]=RoundKey[(i-1) * 4 + j];
+      tempa[j]=state->RoundKey[(i-1) * 4 + j];
     }
     if (i % Nk == 0)
     {
@@ -211,37 +214,37 @@ static void KeyExpansion(void)
         tempa[3] = getSBoxValue(tempa[3]);
       }
     }
-    RoundKey[i * 4 + 0] = RoundKey[(i - Nk) * 4 + 0] ^ tempa[0];
-    RoundKey[i * 4 + 1] = RoundKey[(i - Nk) * 4 + 1] ^ tempa[1];
-    RoundKey[i * 4 + 2] = RoundKey[(i - Nk) * 4 + 2] ^ tempa[2];
-    RoundKey[i * 4 + 3] = RoundKey[(i - Nk) * 4 + 3] ^ tempa[3];
+    state->RoundKey[i * 4 + 0] = state->RoundKey[(i - Nk) * 4 + 0] ^ tempa[0];
+    state->RoundKey[i * 4 + 1] = state->RoundKey[(i - Nk) * 4 + 1] ^ tempa[1];
+    state->RoundKey[i * 4 + 2] = state->RoundKey[(i - Nk) * 4 + 2] ^ tempa[2];
+    state->RoundKey[i * 4 + 3] = state->RoundKey[(i - Nk) * 4 + 3] ^ tempa[3];
   }
 }
 
 // This function adds the round key to state.
 // The round key is added to the state by an XOR function.
-static void AddRoundKey(uint8_t round)
+static void AddRoundKey(aes_state_t *state, uint8_t round) /* XORs round key 'round' into *state->state */
 {
   uint8_t i,j;
   for(i=0;i<4;++i)
   {
     for(j = 0; j < 4; ++j)
     {
-      (*state)[i][j] ^= RoundKey[round * Nb * 4 + i * Nb + j];
+      (*state->state)[i][j] ^= state->RoundKey[round * Nb * 4 + i * Nb + j];
     }
   }
 }
 
 // The SubBytes Function Substitutes the values in the
 // state matrix with values in an S-box.
-static void SubBytes(void)
+static void SubBytes(aes_state_t *state) /* transforms *state->state in place */
 {
   uint8_t i, j;
   for(i = 0; i < 4; ++i)
   {
     for(j = 0; j < 4; ++j)
     {
-      (*state)[j][i] = getSBoxValue((*state)[j][i]);
+      (*state->state)[j][i] = getSBoxValue((*state->state)[j][i]);
     }
   }
 }
@@ -249,32 +252,32 @@ static void SubBytes(void)
 // The ShiftRows() function shifts the rows in the state to the left.
 // Each row is shifted with different offset.
 // Offset = Row number. So the first row is not shifted.
-static void ShiftRows(void)
+static void ShiftRows(aes_state_t *state) /* transforms *state->state in place */
 {
   uint8_t temp;
 
   // Rotate first row 1 columns to left  
-  temp           = (*state)[0][1];
-  (*state)[0][1] = (*state)[1][1];
-  (*state)[1][1] = (*state)[2][1];
-  (*state)[2][1] = (*state)[3][1];
-  (*state)[3][1] = temp;
+  temp           = (*state->state)[0][1];
+  (*state->state)[0][1] = (*state->state)[1][1];
+  (*state->state)[1][1] = (*state->state)[2][1];
+  (*state->state)[2][1] = (*state->state)[3][1];
+  (*state->state)[3][1] = temp;
 
   // Rotate second row 2 columns to left  
-  temp           = (*state)[0][2];
-  (*state)[0][2] = (*state)[2][2];
-  (*state)[2][2] = temp;
+  temp           = (*state->state)[0][2];
+  (*state->state)[0][2] = (*state->state)[2][2];
+  (*state->state)[2][2] = temp;
 
-  temp       = (*state)[1][2];
-  (*state)[1][2] = (*state)[3][2];
-  (*state)[3][2] = temp;
+  temp       = (*state->state)[1][2];
+  (*state->state)[1][2] = (*state->state)[3][2];
+  (*state->state)[3][2] = temp;
 
   // Rotate third row 3 columns to left
-  temp       = (*state)[0][3];
-  (*state)[0][3] = (*state)[3][3];
-  (*state)[3][3] = (*state)[2][3];
-  (*state)[2][3] = (*state)[1][3];
-  (*state)[1][3] = temp;
+  temp       = (*state->state)[0][3];
+  (*state->state)[0][3] = (*state->state)[3][3];
+  (*state->state)[3][3] = (*state->state)[2][3];
+  (*state->state)[2][3] = (*state->state)[1][3];
+  (*state->state)[1][3] = temp;
 }
 
 static uint8_t xtime(uint8_t x)
@@ -283,18 +286,18 @@ static uint8_t xtime(uint8_t x)
 }
 
 // MixColumns function mixes the columns of the state matrix
-static void MixColumns(void)
+static void MixColumns(aes_state_t *state) /* transforms *state->state in place */
 {
   uint8_t i;
   uint8_t Tmp,Tm,t;
   for(i = 0; i < 4; ++i)
   {  
-    t   = (*state)[i][0];
-    Tmp = (*state)[i][0] ^ (*state)[i][1] ^ (*state)[i][2] ^ (*state)[i][3] ;
-    Tm  = (*state)[i][0] ^ (*state)[i][1] ; Tm = xtime(Tm);  (*state)[i][0] ^= Tm ^ Tmp ;
-    Tm  = (*state)[i][1] ^ (*state)[i][2] ; Tm = xtime(Tm);  (*state)[i][1] ^= Tm ^ Tmp ;
-    Tm  = (*state)[i][2] ^ (*state)[i][3] ; Tm = xtime(Tm);  (*state)[i][2] ^= Tm ^ Tmp ;
-    Tm  = (*state)[i][3] ^ t ;        Tm = xtime(Tm);  (*state)[i][3] ^= Tm ^ Tmp ;
+    t   = (*state->state)[i][0];
+    Tmp = (*state->state)[i][0] ^ (*state->state)[i][1] ^ (*state->state)[i][2] ^ (*state->state)[i][3] ;
+    Tm  = (*state->state)[i][0] ^ (*state->state)[i][1] ; Tm = xtime(Tm);  (*state->state)[i][0] ^= Tm ^ Tmp ;
+    Tm  = (*state->state)[i][1] ^ (*state->state)[i][2] ; Tm = xtime(Tm);  (*state->state)[i][1] ^= Tm ^ Tmp ;
+    Tm  = (*state->state)[i][2] ^ (*state->state)[i][3] ; Tm = xtime(Tm);  (*state->state)[i][2] ^= Tm ^ Tmp ;
+    Tm  = (*state->state)[i][3] ^ t ;        Tm = xtime(Tm);  (*state->state)[i][3] ^= Tm ^ Tmp ;
   }
 }
 
@@ -321,117 +324,117 @@ static uint8_t Multiply(uint8_t x, uint8_t y)
 // MixColumns function mixes the columns of the state matrix.
 // The method used to multiply may be difficult to understand for the inexperienced.
 // Please use the references to gain more information.
-static void InvMixColumns(void)
+static void InvMixColumns(aes_state_t *state) /* transforms *state->state in place */
 {
   int i;
   uint8_t a,b,c,d;
   for(i=0;i<4;++i)
   { 
-    a = (*state)[i][0];
-    b = (*state)[i][1];
-    c = (*state)[i][2];
-    d = (*state)[i][3];
-
-    (*state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09);
-    (*state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d);
-    (*state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b);
-    (*state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e);
+    a = (*state->state)[i][0];
+    b = (*state->state)[i][1];
+    c = (*state->state)[i][2];
+    d = (*state->state)[i][3];
+
+    (*state->state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09);
+    (*state->state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d);
+    (*state->state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b);
+    (*state->state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e);
   }
 }
 
 
 // The SubBytes Function Substitutes the values in the
 // state matrix with values in an S-box.
-static void InvSubBytes(void)
+static void InvSubBytes(aes_state_t *state) /* transforms *state->state in place */
 {
   uint8_t i,j;
   for(i=0;i<4;++i)
   {
     for(j=0;j<4;++j)
     {
-      (*state)[j][i] = getSBoxInvert((*state)[j][i]);
+      (*state->state)[j][i] = getSBoxInvert((*state->state)[j][i]);
     }
   }
 }
 
-static void InvShiftRows(void)
+static void InvShiftRows(aes_state_t *state) /* transforms *state->state in place */
 {
   uint8_t temp;
 
   // Rotate first row 1 columns to right  
-  temp=(*state)[3][1];
-  (*state)[3][1]=(*state)[2][1];
-  (*state)[2][1]=(*state)[1][1];
-  (*state)[1][1]=(*state)[0][1];
-  (*state)[0][1]=temp;
+  temp=(*state->state)[3][1];
+  (*state->state)[3][1]=(*state->state)[2][1];
+  (*state->state)[2][1]=(*state->state)[1][1];
+  (*state->state)[1][1]=(*state->state)[0][1];
+  (*state->state)[0][1]=temp;
 
   // Rotate second row 2 columns to right 
-  temp=(*state)[0][2];
-  (*state)[0][2]=(*state)[2][2];
-  (*state)[2][2]=temp;
+  temp=(*state->state)[0][2];
+  (*state->state)[0][2]=(*state->state)[2][2];
+  (*state->state)[2][2]=temp;
 
-  temp=(*state)[1][2];
-  (*state)[1][2]=(*state)[3][2];
-  (*state)[3][2]=temp;
+  temp=(*state->state)[1][2];
+  (*state->state)[1][2]=(*state->state)[3][2];
+  (*state->state)[3][2]=temp;
 
   // Rotate third row 3 columns to right
-  temp=(*state)[0][3];
-  (*state)[0][3]=(*state)[1][3];
-  (*state)[1][3]=(*state)[2][3];
-  (*state)[2][3]=(*state)[3][3];
-  (*state)[3][3]=temp;
+  temp=(*state->state)[0][3];
+  (*state->state)[0][3]=(*state->state)[1][3];
+  (*state->state)[1][3]=(*state->state)[2][3];
+  (*state->state)[2][3]=(*state->state)[3][3];
+  (*state->state)[3][3]=temp;
 }
 
 
 // Cipher is the main function that encrypts the PlainText.
-static void Cipher(void)
+static void Cipher(aes_state_t *state) /* encrypts *state->state in place; needs a filled state->RoundKey */
 {
   uint8_t round = 0;
 
   // Add the First round key to the state before starting the rounds.
-  AddRoundKey(0); 
+  AddRoundKey(state, 0);
   
   // There will be Nr rounds.
   // The first Nr-1 rounds are identical.
   // These Nr-1 rounds are executed in the loop below.
   for(round = 1; round < Nr; ++round)
   {
-    SubBytes();
-    ShiftRows();
-    MixColumns();
-    AddRoundKey(round);
+    SubBytes(state);
+    ShiftRows(state);
+    MixColumns(state);
+    AddRoundKey(state, round);
   }
   
   // The last round is given below.
   // The MixColumns function is not here in the last round.
-  SubBytes();
-  ShiftRows();
-  AddRoundKey(Nr);
+  SubBytes(state);
+  ShiftRows(state);
+  AddRoundKey(state, Nr);
 }
 
-static void InvCipher(void)
+static void InvCipher(aes_state_t *state) /* decrypts *state->state in place; needs a filled state->RoundKey */
 {
   uint8_t round=0;
 
   // Add the First round key to the state before starting the rounds.
-  AddRoundKey(Nr); 
+  AddRoundKey(state, Nr);
 
   // There will be Nr rounds.
   // The first Nr-1 rounds are identical.
   // These Nr-1 rounds are executed in the loop below.
   for(round=Nr-1;round>0;round--)
   {
-    InvShiftRows();
-    InvSubBytes();
-    AddRoundKey(round);
-    InvMixColumns();
+    InvShiftRows(state);
+    InvSubBytes(state);
+    AddRoundKey(state, round);
+    InvMixColumns(state);
   }
   
   // The last round is given below.
   // The MixColumns function is not here in the last round.
-  InvShiftRows();
-  InvSubBytes();
-  AddRoundKey(0);
+  InvShiftRows(state);
+  InvSubBytes(state);
+  AddRoundKey(state, 0);
 }
 
 static void BlockCopy(uint8_t* output, uint8_t* input)
@@ -453,28 +456,30 @@ static void BlockCopy(uint8_t* output, uint8_t* input)
 
 void AES128_ECB_encrypt(uint8_t* input, const uint8_t* key, uint8_t* output)
 {
+  aes_state_t state; // per-call context; no file-scope state is touched
   // Copy input to output, and work in-memory on output
   BlockCopy(output, input);
-  state = (state_t*)output;
+  state.state = (state_t*)output;
 
-  Key = key;
-  KeyExpansion();
+  state.Key = key;
+  KeyExpansion(&state);
 
   // The next function call encrypts the PlainText with the Key using AES algorithm.
-  Cipher();
+  Cipher(&state);
 }
 
 void AES128_ECB_decrypt(uint8_t* input, const uint8_t* key, uint8_t *output)
 {
+  aes_state_t state; // per-call context; no file-scope state is touched
   // Copy input to output, and work in-memory on output
   BlockCopy(output, input);
-  state = (state_t*)output;
+  state.state = (state_t*)output;
 
   // The KeyExpansion routine must be called before encryption.
-  Key = key;
-  KeyExpansion();
+  state.Key = key;
+  KeyExpansion(&state);
 
-  InvCipher();
+  InvCipher(&state);
 }
 
 
@@ -487,12 +492,12 @@ void AES128_ECB_decrypt(uint8_t* input, const uint8_t* key, uint8_t *output)
 #if defined(CBC) && CBC
 
 
-static void XorWithIv(uint8_t* buf)
+static void XorWithIv(aes_state_t *state, uint8_t* buf) /* buf[0..KEYLEN) ^= state->Iv[0..KEYLEN) */
 {
   uint8_t i;
   for(i = 0; i < KEYLEN; ++i)
   {
-    buf[i] ^= Iv[i];
+    buf[i] ^= state->Iv[i];
   }
 }
 
@@ -500,29 +505,30 @@ void AES128_CBC_encrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length,
 {
   intptr_t i;
   uint8_t remainders = length % KEYLEN; /* Remaining bytes in the last non-full block */
+  aes_state_t state;
+
+  // State is per call now: a NULL key/iv can no longer mean "reuse the
+  // previous one" (it would read uninitialized data), so do nothing.
+  if(0 == key || 0 == iv)
+  {
+    return;
+  }
 
   BlockCopy(output, input);
-  state = (state_t*)output;
+  state.state = (state_t*)output;
 
-  // Skip the key expansion if key is passed as 0
-  if(0 != key)
-  {
-    Key = key;
-    KeyExpansion();
-  }
-
-  if(iv != 0)
-  {
-    Iv = (uint8_t*)iv;
-  }
+  state.Key = key;
+  KeyExpansion(&state);
+  state.Iv = (uint8_t*)iv;
 
   for(i = 0; i < length; i += KEYLEN)
   {
-    XorWithIv(input);
-    BlockCopy(output, input);
-    state = (state_t*)output;
-    Cipher();
-    Iv = output;
+    // XOR the IV into the copy, not into the caller's (possibly shared) input.
+    BlockCopy(output, input);
+    XorWithIv(&state, output);
+    state.state = (state_t*)output;
+    Cipher(&state);
+    state.Iv = output;
     input += KEYLEN;
     output += KEYLEN;
   }
@@ -531,8 +537,8 @@ void AES128_CBC_encrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length,
   {
     BlockCopy(output, input);
     memset(output + remainders, 0, KEYLEN - remainders); /* add 0-padding */
-    state = (state_t*)output;
-    Cipher();
+    state.state = (state_t*)output;
+    Cipher(&state); // NOTE(review): no XorWithIv() on this padded block (unchanged by this port) -- confirm intended
   }
 }
 
@@ -540,30 +546,31 @@ void AES128_CBC_decrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length,
 {
   intptr_t i;
   uint8_t remainders = length % KEYLEN; /* Remaining bytes in the last non-full block */
+  aes_state_t state;
+
+  // State is per call now: a NULL key/iv can no longer mean "reuse the
+  // previous one" (it would read uninitialized data), so do nothing.
+  if(0 == key || 0 == iv)
+  {
+    return;
+  }
   
   BlockCopy(output, input);
-  state = (state_t*)output;
+  state.state = (state_t*)output;
 
-  // Skip the key expansion if key is passed as 0
-  if(0 != key)
-  {
-    Key = key;
-    KeyExpansion();
-  }
-
-  // If iv is passed as 0, we continue to encrypt without re-setting the Iv
-  if(iv != 0)
-  {
-    Iv = (uint8_t*)iv;
-  }
+  state.Key = key;
+  KeyExpansion(&state);
+  state.Iv = (uint8_t*)iv;
 
   for(i = 0; i < length; i += KEYLEN)
   {
     BlockCopy(output, input);
-    state = (state_t*)output;
-    InvCipher();
-    XorWithIv(output);
-    Iv = input;
+    state.state = (state_t*)output;
+    InvCipher(&state);
+    XorWithIv(&state, output);
+    // NOTE: state.Iv = input assumes output != input; decrypting in place
+    // would make the next IV point at already-decrypted plaintext.
+    state.Iv = input;
     input += KEYLEN;
     output += KEYLEN;
   }
@@ -572,8 +579,8 @@ void AES128_CBC_decrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length,
   {
     BlockCopy(output, input);
     memset(output+remainders, 0, KEYLEN - remainders); /* add 0-padding */
-    state = (state_t*)output;
-    InvCipher();
+    state.state = (state_t*)output;
+    InvCipher(&state); // NOTE(review): no XorWithIv() after this block (unchanged by this port) -- confirm intended
   }
 }
 
-- 
2.1.0



More information about the devel mailing list