Signed-off-by: Anton Lundin <glance@acc.umu.se> --- src/aes.c | 251 ++++++++++++++++++++++++++++++++------------------------------ 1 file changed, 129 insertions(+), 122 deletions(-)
diff --git a/src/aes.c b/src/aes.c index 4c7645f..4315a65 100644 --- a/src/aes.c +++ b/src/aes.c @@ -63,18 +63,21 @@ NOTE: String length must be evenly divisible by 16byte (str_len % 16 == 0) /*****************************************************************************/ // state - array holding the intermediate results during decryption. typedef uint8_t state_t[4][4]; -static state_t* state;
-// The array that stores the round keys. -static uint8_t RoundKey[176]; +typedef struct aes_state_t { /* per-call AES context; its members replace the former file-scope globals state, RoundKey, Key and Iv */ + state_t* state;
-// The Key input to the AES Program -static const uint8_t* Key; + // The array that stores the round keys. + uint8_t RoundKey[176]; + + // The Key input to the AES Program + const uint8_t* Key;
#if defined(CBC) && CBC - // Initial Vector used only for CBC mode - static uint8_t* Iv; + // Initial Vector used only for CBC mode + uint8_t* Iv; #endif +} aes_state_t;
// The lookup-tables are marked const so they can be placed in read-only storage instead of RAM // The numbers below can be computed dynamically trading ROM for RAM - @@ -153,7 +156,7 @@ static uint8_t getSBoxInvert(uint8_t num) }
// This function produces Nb(Nr+1) round keys. The round keys are used in each round to decrypt the states. -static void KeyExpansion(void) +static void KeyExpansion(aes_state_t *state) /* expands state->Key into state->RoundKey */ { uint32_t i, j, k; uint8_t tempa[4]; // Used for the column/row operations @@ -161,10 +164,10 @@ static void KeyExpansion(void) // The first round key is the key itself. for(i = 0; i < Nk; ++i) { - RoundKey[(i * 4) + 0] = Key[(i * 4) + 0]; - RoundKey[(i * 4) + 1] = Key[(i * 4) + 1]; - RoundKey[(i * 4) + 2] = Key[(i * 4) + 2]; - RoundKey[(i * 4) + 3] = Key[(i * 4) + 3]; + state->RoundKey[(i * 4) + 0] = state->Key[(i * 4) + 0]; + state->RoundKey[(i * 4) + 1] = state->Key[(i * 4) + 1]; + state->RoundKey[(i * 4) + 2] = state->Key[(i * 4) + 2]; + state->RoundKey[(i * 4) + 3] = state->Key[(i * 4) + 3]; }
// All other round keys are found from the previous round keys. @@ -172,7 +175,7 @@ static void KeyExpansion(void) { for(j = 0; j < 4; ++j) { - tempa[j]=RoundKey[(i-1) * 4 + j]; + tempa[j]=state->RoundKey[(i-1) * 4 + j]; } if (i % Nk == 0) { @@ -211,37 +214,37 @@ static void KeyExpansion(void) tempa[3] = getSBoxValue(tempa[3]); } } - RoundKey[i * 4 + 0] = RoundKey[(i - Nk) * 4 + 0] ^ tempa[0]; - RoundKey[i * 4 + 1] = RoundKey[(i - Nk) * 4 + 1] ^ tempa[1]; - RoundKey[i * 4 + 2] = RoundKey[(i - Nk) * 4 + 2] ^ tempa[2]; - RoundKey[i * 4 + 3] = RoundKey[(i - Nk) * 4 + 3] ^ tempa[3]; + state->RoundKey[i * 4 + 0] = state->RoundKey[(i - Nk) * 4 + 0] ^ tempa[0]; + state->RoundKey[i * 4 + 1] = state->RoundKey[(i - Nk) * 4 + 1] ^ tempa[1]; + state->RoundKey[i * 4 + 2] = state->RoundKey[(i - Nk) * 4 + 2] ^ tempa[2]; + state->RoundKey[i * 4 + 3] = state->RoundKey[(i - Nk) * 4 + 3] ^ tempa[3]; } }
// This function adds the round key to state. // The round key is added to the state by an XOR function. -static void AddRoundKey(uint8_t round) +static void AddRoundKey(aes_state_t *state, uint8_t round) /* XORs round key number `round` from state->RoundKey into *state->state */ { uint8_t i,j; for(i=0;i<4;++i) { for(j = 0; j < 4; ++j) { - (*state)[i][j] ^= RoundKey[round * Nb * 4 + i * Nb + j]; + (*state->state)[i][j] ^= state->RoundKey[round * Nb * 4 + i * Nb + j]; } } }
// The SubBytes Function Substitutes the values in the // state matrix with values in an S-box. -static void SubBytes(void) +static void SubBytes(aes_state_t *state) /* transforms *state->state in place */ { uint8_t i, j; for(i = 0; i < 4; ++i) { for(j = 0; j < 4; ++j) { - (*state)[j][i] = getSBoxValue((*state)[j][i]); + (*state->state)[j][i] = getSBoxValue((*state->state)[j][i]); } } } @@ -249,32 +252,32 @@ static void SubBytes(void) // The ShiftRows() function shifts the rows in the state to the left. // Each row is shifted with different offset. // Offset = Row number. So the first row is not shifted. -static void ShiftRows(void) +static void ShiftRows(aes_state_t *state) /* transforms *state->state in place */ { uint8_t temp;
// Rotate first row 1 columns to left - temp = (*state)[0][1]; - (*state)[0][1] = (*state)[1][1]; - (*state)[1][1] = (*state)[2][1]; - (*state)[2][1] = (*state)[3][1]; - (*state)[3][1] = temp; + temp = (*state->state)[0][1]; + (*state->state)[0][1] = (*state->state)[1][1]; + (*state->state)[1][1] = (*state->state)[2][1]; + (*state->state)[2][1] = (*state->state)[3][1]; + (*state->state)[3][1] = temp;
// Rotate second row 2 columns to left - temp = (*state)[0][2]; - (*state)[0][2] = (*state)[2][2]; - (*state)[2][2] = temp; + temp = (*state->state)[0][2]; + (*state->state)[0][2] = (*state->state)[2][2]; + (*state->state)[2][2] = temp;
- temp = (*state)[1][2]; - (*state)[1][2] = (*state)[3][2]; - (*state)[3][2] = temp; + temp = (*state->state)[1][2]; + (*state->state)[1][2] = (*state->state)[3][2]; + (*state->state)[3][2] = temp;
// Rotate third row 3 columns to left - temp = (*state)[0][3]; - (*state)[0][3] = (*state)[3][3]; - (*state)[3][3] = (*state)[2][3]; - (*state)[2][3] = (*state)[1][3]; - (*state)[1][3] = temp; + temp = (*state->state)[0][3]; + (*state->state)[0][3] = (*state->state)[3][3]; + (*state->state)[3][3] = (*state->state)[2][3]; + (*state->state)[2][3] = (*state->state)[1][3]; + (*state->state)[1][3] = temp; }
static uint8_t xtime(uint8_t x) @@ -283,18 +286,18 @@ static uint8_t xtime(uint8_t x) }
// MixColumns function mixes the columns of the state matrix -static void MixColumns(void) +static void MixColumns(aes_state_t *state) /* transforms *state->state in place */ { uint8_t i; uint8_t Tmp,Tm,t; for(i = 0; i < 4; ++i) { - t = (*state)[i][0]; - Tmp = (*state)[i][0] ^ (*state)[i][1] ^ (*state)[i][2] ^ (*state)[i][3] ; - Tm = (*state)[i][0] ^ (*state)[i][1] ; Tm = xtime(Tm); (*state)[i][0] ^= Tm ^ Tmp ; - Tm = (*state)[i][1] ^ (*state)[i][2] ; Tm = xtime(Tm); (*state)[i][1] ^= Tm ^ Tmp ; - Tm = (*state)[i][2] ^ (*state)[i][3] ; Tm = xtime(Tm); (*state)[i][2] ^= Tm ^ Tmp ; - Tm = (*state)[i][3] ^ t ; Tm = xtime(Tm); (*state)[i][3] ^= Tm ^ Tmp ; + t = (*state->state)[i][0]; + Tmp = (*state->state)[i][0] ^ (*state->state)[i][1] ^ (*state->state)[i][2] ^ (*state->state)[i][3] ; + Tm = (*state->state)[i][0] ^ (*state->state)[i][1] ; Tm = xtime(Tm); (*state->state)[i][0] ^= Tm ^ Tmp ; + Tm = (*state->state)[i][1] ^ (*state->state)[i][2] ; Tm = xtime(Tm); (*state->state)[i][1] ^= Tm ^ Tmp ; + Tm = (*state->state)[i][2] ^ (*state->state)[i][3] ; Tm = xtime(Tm); (*state->state)[i][2] ^= Tm ^ Tmp ; + Tm = (*state->state)[i][3] ^ t ; Tm = xtime(Tm); (*state->state)[i][3] ^= Tm ^ Tmp ; } }
@@ -321,117 +324,117 @@ static uint8_t Multiply(uint8_t x, uint8_t y) // MixColumns function mixes the columns of the state matrix. // The method used to multiply may be difficult to understand for the inexperienced. // Please use the references to gain more information. -static void InvMixColumns(void) +static void InvMixColumns(aes_state_t *state) /* transforms *state->state in place */ { int i; uint8_t a,b,c,d; for(i=0;i<4;++i) { - a = (*state)[i][0]; - b = (*state)[i][1]; - c = (*state)[i][2]; - d = (*state)[i][3]; - - (*state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09); - (*state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d); - (*state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b); - (*state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e); + a = (*state->state)[i][0]; + b = (*state->state)[i][1]; + c = (*state->state)[i][2]; + d = (*state->state)[i][3]; + + (*state->state)[i][0] = Multiply(a, 0x0e) ^ Multiply(b, 0x0b) ^ Multiply(c, 0x0d) ^ Multiply(d, 0x09); + (*state->state)[i][1] = Multiply(a, 0x09) ^ Multiply(b, 0x0e) ^ Multiply(c, 0x0b) ^ Multiply(d, 0x0d); + (*state->state)[i][2] = Multiply(a, 0x0d) ^ Multiply(b, 0x09) ^ Multiply(c, 0x0e) ^ Multiply(d, 0x0b); + (*state->state)[i][3] = Multiply(a, 0x0b) ^ Multiply(b, 0x0d) ^ Multiply(c, 0x09) ^ Multiply(d, 0x0e); } }
// The SubBytes Function Substitutes the values in the // state matrix with values in an S-box. -static void InvSubBytes(void) +static void InvSubBytes(aes_state_t *state) /* transforms *state->state in place */ { uint8_t i,j; for(i=0;i<4;++i) { for(j=0;j<4;++j) { - (*state)[j][i] = getSBoxInvert((*state)[j][i]); + (*state->state)[j][i] = getSBoxInvert((*state->state)[j][i]); } } }
-static void InvShiftRows(void) +static void InvShiftRows(aes_state_t *state) /* transforms *state->state in place */ { uint8_t temp;
// Rotate first row 1 columns to right - temp=(*state)[3][1]; - (*state)[3][1]=(*state)[2][1]; - (*state)[2][1]=(*state)[1][1]; - (*state)[1][1]=(*state)[0][1]; - (*state)[0][1]=temp; + temp=(*state->state)[3][1]; + (*state->state)[3][1]=(*state->state)[2][1]; + (*state->state)[2][1]=(*state->state)[1][1]; + (*state->state)[1][1]=(*state->state)[0][1]; + (*state->state)[0][1]=temp;
// Rotate second row 2 columns to right - temp=(*state)[0][2]; - (*state)[0][2]=(*state)[2][2]; - (*state)[2][2]=temp; + temp=(*state->state)[0][2]; + (*state->state)[0][2]=(*state->state)[2][2]; + (*state->state)[2][2]=temp;
- temp=(*state)[1][2]; - (*state)[1][2]=(*state)[3][2]; - (*state)[3][2]=temp; + temp=(*state->state)[1][2]; + (*state->state)[1][2]=(*state->state)[3][2]; + (*state->state)[3][2]=temp;
// Rotate third row 3 columns to right - temp=(*state)[0][3]; - (*state)[0][3]=(*state)[1][3]; - (*state)[1][3]=(*state)[2][3]; - (*state)[2][3]=(*state)[3][3]; - (*state)[3][3]=temp; + temp=(*state->state)[0][3]; + (*state->state)[0][3]=(*state->state)[1][3]; + (*state->state)[1][3]=(*state->state)[2][3]; + (*state->state)[2][3]=(*state->state)[3][3]; + (*state->state)[3][3]=temp; }
// Cipher is the main function that encrypts the PlainText. -static void Cipher(void) +static void Cipher(aes_state_t *state) /* encrypts *state->state in place using state->RoundKey */ { uint8_t round = 0;
// Add the First round key to the state before starting the rounds. - AddRoundKey(0); + AddRoundKey(state, 0);
// There will be Nr rounds. // The first Nr-1 rounds are identical. // These Nr-1 rounds are executed in the loop below. for(round = 1; round < Nr; ++round) { - SubBytes(); - ShiftRows(); - MixColumns(); - AddRoundKey(round); + SubBytes(state); + ShiftRows(state); + MixColumns(state); + AddRoundKey(state, round); }
// The last round is given below. // The MixColumns function is not here in the last round. - SubBytes(); - ShiftRows(); - AddRoundKey(Nr); + SubBytes(state); + ShiftRows(state); + AddRoundKey(state, Nr); }
-static void InvCipher(void) +static void InvCipher(aes_state_t *state) /* decrypts *state->state in place using state->RoundKey */ { uint8_t round=0;
// Add the First round key to the state before starting the rounds. - AddRoundKey(Nr); + AddRoundKey(state, Nr);
// There will be Nr rounds. // The first Nr-1 rounds are identical. // These Nr-1 rounds are executed in the loop below. for(round=Nr-1;round>0;round--) { - InvShiftRows(); - InvSubBytes(); - AddRoundKey(round); - InvMixColumns(); + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(state, round); + InvMixColumns(state); }
// The last round is given below. // The MixColumns function is not here in the last round. - InvShiftRows(); - InvSubBytes(); - AddRoundKey(0); + InvShiftRows(state); + InvSubBytes(state); + AddRoundKey(state, 0); }
static void BlockCopy(uint8_t* output, uint8_t* input) @@ -453,28 +456,30 @@ static void BlockCopy(uint8_t* output, uint8_t* input)
void AES128_ECB_encrypt(uint8_t* input, const uint8_t* key, uint8_t* output) { + aes_state_t state; /* NOTE(review): the expanded round keys in state.RoundKey are not cleared before returning; consider wiping them */ // Copy input to output, and work in-memory on output BlockCopy(output, input); - state = (state_t*)output; + state.state = (state_t*)output;
- Key = key; - KeyExpansion(); + state.Key = key; + KeyExpansion(&state);
// The next function call encrypts the PlainText with the Key using AES algorithm. - Cipher(); + Cipher(&state); }
void AES128_ECB_decrypt(uint8_t* input, const uint8_t* key, uint8_t *output) { + aes_state_t state; /* NOTE(review): the expanded round keys in state.RoundKey are not cleared before returning; consider wiping them */ // Copy input to output, and work in-memory on output BlockCopy(output, input); - state = (state_t*)output; + state.state = (state_t*)output;
// The KeyExpansion routine must be called before encryption. - Key = key; - KeyExpansion(); + state.Key = key; + KeyExpansion(&state);
- InvCipher(); + InvCipher(&state); }
@@ -487,12 +492,12 @@ void AES128_ECB_decrypt(uint8_t* input, const uint8_t* key, uint8_t *output) #if defined(CBC) && CBC
-static void XorWithIv(uint8_t* buf) +static void XorWithIv(aes_state_t *state, uint8_t* buf) /* buf[i] ^= state->Iv[i] for the first KEYLEN bytes; state->Iv must point at valid data */ { uint8_t i; for(i = 0; i < KEYLEN; ++i) { - buf[i] ^= Iv[i]; + buf[i] ^= state->Iv[i]; } }
@@ -500,29 +505,30 @@ void AES128_CBC_encrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length, { intptr_t i; uint8_t remainders = length % KEYLEN; /* Remaining bytes in the last non-full block */ + static aes_state_t state; /* static on purpose: when key == 0 or iv == 0 this function reuses the round keys / IV pointer left by its previous call; an automatic aes_state_t would leave them uninitialized (undefined behavior). CBC encryption therefore stays non-reentrant, as it was with the old file-scope globals; note it no longer shares key/IV state with AES128_CBC_decrypt_buffer. */
BlockCopy(output, input); - state = (state_t*)output; + state.state = (state_t*)output;
// Skip the key expansion if key is passed as 0 if(0 != key) { - Key = key; - KeyExpansion(); + state.Key = key; + KeyExpansion(&state); }
if(iv != 0) { - Iv = (uint8_t*)iv; + state.Iv = (uint8_t*)iv; }
for(i = 0; i < length; i += KEYLEN) { - XorWithIv(input); + XorWithIv(&state, input); BlockCopy(output, input); - state = (state_t*)output; - Cipher(); - Iv = output; + state.state = (state_t*)output; + Cipher(&state); + state.Iv = output; input += KEYLEN; output += KEYLEN; } @@ -531,8 +537,8 @@ void AES128_CBC_encrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length, { BlockCopy(output, input); memset(output + remainders, 0, KEYLEN - remainders); /* add 0-padding */ - state = (state_t*)output; - Cipher(); + state.state = (state_t*)output; + Cipher(&state); } }
@@ -540,30 +546,31 @@ void AES128_CBC_decrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length, { intptr_t i; uint8_t remainders = length % KEYLEN; /* Remaining bytes in the last non-full block */ + static aes_state_t state; /* static on purpose: when key == 0 or iv == 0 this function reuses the round keys / IV pointer left by its previous call; an automatic aes_state_t would leave them uninitialized (undefined behavior). CBC decryption therefore stays non-reentrant, as it was with the old file-scope globals; note it no longer shares key/IV state with AES128_CBC_encrypt_buffer. */
BlockCopy(output, input); - state = (state_t*)output; + state.state = (state_t*)output;
// Skip the key expansion if key is passed as 0 if(0 != key) { - Key = key; - KeyExpansion(); + state.Key = key; + KeyExpansion(&state); }
// If iv is passed as 0, we continue to encrypt without re-setting the Iv if(iv != 0) { - Iv = (uint8_t*)iv; + state.Iv = (uint8_t*)iv; }
for(i = 0; i < length; i += KEYLEN) { BlockCopy(output, input); - state = (state_t*)output; - InvCipher(); - XorWithIv(output); - Iv = input; + state.state = (state_t*)output; + InvCipher(&state); + XorWithIv(&state, output); + state.Iv = input; input += KEYLEN; output += KEYLEN; } @@ -572,8 +579,8 @@ void AES128_CBC_decrypt_buffer(uint8_t* output, uint8_t* input, uint32_t length, { BlockCopy(output, input); memset(output+remainders, 0, KEYLEN - remainders); /* add 0-padding */ - state = (state_t*)output; - InvCipher(); + state.state = (state_t*)output; + InvCipher(&state); } }