This library provides a way to easily handle arbitrary large integers.

This library provides the following operations :

  • addition, substraction, multiplication, division and modulo
  • bits operators (AND, OR, XOR, left and right shifts)
  • boolean operators
  • modular exponentiation (using montgomery algorithm)
  • modular inverse

Example

In this example, we use a 1024 bits long RSA key to encrypt and decrypt a message. We first encrypt the value 0x41 (65 in decimal) and then decrypt it. At the end, m should be equal to 0x41. The encryption is fast (0, 4 second) while the decryption is really slow. This code will take between 30 seconds and 2 minutes to execute depending on the compiler and optimization flags.

main.cpp

#include "mbed.h"
#include "BigInt.h"
#include <stdlib.h>
#include <stdio.h>

uint8_t modbits[] = {
0xd9, 0x4d, 0x88, 0x9e, 0x88, 0x85, 0x3d, 0xd8, 0x97, 0x69, 0xa1, 0x80, 0x15, 0xa0, 0xa2, 0xe6,
0xbf, 0x82, 0xbf, 0x35, 0x6f, 0xe1, 0x4f, 0x25, 0x1f, 0xb4, 0xf5, 0xe2, 0xdf, 0x0d, 0x9f, 0x9a,
0x94, 0xa6, 0x8a, 0x30, 0xc4, 0x28, 0xb3, 0x9e, 0x33, 0x62, 0xfb, 0x37, 0x79, 0xa4, 0x97, 0xec,
0xea, 0xea, 0x37, 0x10, 0x0f, 0x26, 0x4d, 0x7f, 0xb9, 0xfb, 0x1a, 0x97, 0xfb, 0xf6, 0x21, 0x13,
0x3d, 0xe5, 0x5f, 0xdc, 0xb9, 0xb1, 0xad, 0x0d, 0x7a, 0x31, 0xb3, 0x79, 0x21, 0x6d, 0x79, 0x25,
0x2f, 0x5c, 0x52, 0x7b, 0x9b, 0xc6, 0x3d, 0x83, 0xd4, 0xec, 0xf4, 0xd1, 0xd4, 0x5c, 0xbf, 0x84,
0x3e, 0x84, 0x74, 0xba, 0xbc, 0x65, 0x5e, 0x9b, 0xb6, 0x79, 0x9c, 0xba, 0x77, 0xa4, 0x7e, 0xaf,
0xa8, 0x38, 0x29, 0x64, 0x74, 0xaf, 0xc2, 0x4b, 0xeb, 0x9c, 0x82, 0x5b, 0x73, 0xeb, 0xf5, 0x49
};

uint8_t dbits[] = {
0x04, 0x7b, 0x9c, 0xfd, 0xe8, 0x43, 0x17, 0x6b, 0x88, 0x74, 0x1d, 0x68, 0xcf, 0x09, 0x69, 0x52,
0xe9, 0x50, 0x81, 0x31, 0x51, 0x05, 0x8c, 0xe4, 0x6f, 0x2b, 0x04, 0x87, 0x91, 0xa2, 0x6e, 0x50,
0x7a, 0x10, 0x95, 0x79, 0x3c, 0x12, 0xba, 0xe1, 0xe0, 0x9d, 0x82, 0x21, 0x3a, 0xd9, 0x32, 0x69,
0x28, 0xcf, 0x7c, 0x23, 0x50, 0xac, 0xb1, 0x9c, 0x98, 0xf1, 0x9d, 0x32, 0xd5, 0x77, 0xd6, 0x66,
0xcd, 0x7b, 0xb8, 0xb2, 0xb5, 0xba, 0x62, 0x9d, 0x25, 0xcc, 0xf7, 0x2a, 0x5c, 0xeb, 0x8a, 0x8d,
0xa0, 0x38, 0x90, 0x6c, 0x84, 0xdc, 0xdb, 0x1f, 0xe6, 0x77, 0xdf, 0xfb, 0x2c, 0x02, 0x9f, 0xd8,
0x92, 0x63, 0x18, 0xee, 0xde, 0x1b, 0x58, 0x27, 0x2a, 0xf2, 0x2b, 0xda, 0x5c, 0x52, 0x32, 0xbe,
0x06, 0x68, 0x39, 0x39, 0x8e, 0x42, 0xf5, 0x35, 0x2d, 0xf5, 0x88, 0x48, 0xad, 0xad, 0x11, 0xa1
};

int main() 
{
    BigInt e = 65537, mod, d;
    mod.importData(modbits, sizeof(modbits));
    d.importData(dbits, sizeof(dbits));

    BigInt c = modPow(0x41,e,mod);
    c.print();
    BigInt m = modPow(c,d,mod);
    m.print();
    printf("done\n");
    
    return 0;
}

Revision:
26:94e26bcd229d
Parent:
25:3d5c1f299da2
--- a/BigInt.cpp	Sun Apr 13 07:35:47 2014 +0000
+++ b/BigInt.cpp	Sun May 11 10:33:20 2014 +0000
@@ -27,13 +27,22 @@
 
 
 BigInt::BigInt():
+sign(POS),
 size(0),
 bits(0)
 {
 }
 
-BigInt::BigInt(const uint32_t a)
+BigInt::BigInt(int32_t a)
 {
+    if(a < 0) 
+    {
+        a = -a;
+        sign = NEG;
+    }
+    else
+        sign = POS;
+
     if(a >> 24)
         size = 4;
     else if(a >> 16)
@@ -47,6 +56,7 @@
 }
 
 BigInt::BigInt(const BigInt &a):
+sign(a.sign),
 size(a.size)
 {
     uint32_t l = num(size);
@@ -66,6 +76,7 @@
 
 BigInt& BigInt::operator=(const BigInt& a)
 {
+    sign = a.sign;
     size = a.size;
     uint32_t l = num(size);
     if(bits)
@@ -76,8 +87,9 @@
     return *this;
 }
 
-void BigInt::importData(uint8_t *data, uint32_t length)
+void BigInt::importData(uint8_t *data, uint32_t length, bool sign)
 {
+    this->sign = sign;
     size = length;
     if(bits)
         delete[] bits;
@@ -89,12 +101,14 @@
     trim();
 }
 
-void BigInt::exportData(uint8_t *data, uint32_t length)
+void BigInt::exportData(uint8_t *data, uint32_t length, bool &sign)
 {
     assert(isValid() && data != 0);
     
     if(length < size)
         return;
+        
+    sign = this->sign;
     uint32_t offset = length-size;
     memset(data, 0, offset);
     for(int i = size-1; i >= 0; --i)
@@ -105,29 +119,14 @@
 {
     assert(a.isValid() && b.isValid());
 
-    BigInt result;
-        
-    result.size = std::max(a.size, b.size) + 1;
-    size_t l = num(result.size);
-    result.bits = new uint32_t[l];
-    memset(result.bits, 0, sizeof(uint32_t)*l);
-    uint32_t al = num(a.size);
-    uint32_t bl = num(b.size);
-    uint32_t carry = 0;
-    for(int i = 0; i < (int)l; ++i)
-    {
-        uint32_t tmpA = 0, tmpB = 0;
-        if(i < (int)al)
-            tmpA = a.bits[i];
-        if(i < (int)bl)
-            tmpB = b.bits[i];
-        result.bits[i] = tmpA + tmpB + carry;
-        carry = result.bits[i] < std::max(tmpA, tmpB);
-    }
-
-    result.trim();
-
-    return result;
+    if(a.sign == POS && b.sign == POS)      // a+b
+        return add(a, b);
+    if(a.sign == NEG && b.sign == NEG)      // (-a)+(-b) = -(a+b)
+        return -add(a, b);
+    else if(a.sign == POS)  // a + (-b) = a-b
+        return a - (-b);
+    else                    // (-a) + b = b-a
+        return b - (-a);
 }
 
 BigInt& BigInt::operator+=(const BigInt &b)
@@ -147,49 +146,33 @@
     return t;
 }
 
-// a - b, if b >= a, returns 0
-// No negative number allowed
 BigInt operator-(const BigInt& a, const BigInt& b)
 {
     assert(a.isValid() && b.isValid());
-
-    if(b >= a)
-        return 0;
-
-    BigInt result;
-    result.size = a.size;
-    uint32_t l = num(a.size);
-    result.bits = new uint32_t[l];
-    memset(result.bits, 0, sizeof(uint32_t)*l);
-    uint32_t bl = num(b.size);
-    uint8_t borrow = 0;
-    for(uint32_t i = 0; i < l; ++i)
+    
+    if(a.sign == POS && b.sign == POS)
     {
-        uint32_t tmpA = a.bits[i], tmpB = 0;
-        if(i < bl)
-            tmpB = b.bits[i];
-            
-        if(borrow)  
-        {
-            if(tmpA == 0)
-                tmpA = 0xFFFFFFFF;
-            else
-            {
-                --tmpA;
-                borrow = 0;
-            }
-        }
-        if(tmpA >= tmpB)
-            result.bits[i] = tmpA - tmpB;
-        else 
-        {
-            result.bits[i] = 0xFFFFFFFF - tmpB;
-            result.bits[i] += tmpA + 1;
-            borrow = 1;
-        }
+        if(equals(a, b))
+            return 0;
+        else if(greater(a, b))
+            return sub(a, b);
+        else
+            return -sub(b, a);
     }
-    result.trim();
-        
+    else if(a.sign == NEG && b.sign == NEG)
+        return (-b) - (-a);
+    else if(a.sign == NEG && b.sign == POS)
+        return -add(a, b);
+    else 
+        return add(a, b);
+}
+
+BigInt operator-(const BigInt &a)
+{
+    assert(a.isValid());
+    
+    BigInt result = a;
+    result.sign = !a.sign;
     return result;
 }
 
@@ -217,37 +200,19 @@
     // if a == 0 or b == 0 then result = 0
     if(!a || !b)
         return 0;
-    
-    // if a == 1, then result = b
-    if(a == 1)
-        return b;
-    
-    // if b == 1, then result = a
-    if(b == 1)
-        return a;
-    
-    BigInt result;          
-    result.size = a.size + b.size;
-    result.bits = new uint32_t[num(result.size)];
-    memset(result.bits, 0, sizeof(uint32_t)*num(result.size));
-    for(int i = 0; i < (int)num(a.size); ++i)
-    {
-        uint64_t carry = 0;
-        for(int j = 0; j < (int)num(b.size); ++j)
-        {
-            uint64_t tmp = (uint64_t)a.bits[i] * (uint64_t)b.bits[j] + carry;        
-            uint32_t t = result.bits[i+j];
-            result.bits[i+j] += tmp;
-            carry = tmp >> 32;     
-            if(t > result.bits[i+j])
-                ++carry;                    
-        }
-        if(carry != 0)
-            result.bits[i+num(b.size)] += carry;
-    }
-    
-    result.trim();
-    
+
+    BigInt result;
+    if(equals(a, 1))
+        result = b;
+    else if(equals(b, 1))
+        result = a;
+    else
+        result = mul(a, b);
+
+    if(a.sign == b.sign)
+        result.sign = POS;
+    else
+        result.sign = NEG;
     return result;
 }
 
@@ -256,38 +221,28 @@
     return (*this = (*this) * b);
 }
 
-
 BigInt operator/(const BigInt &a, const BigInt &b)
 {
     assert(a.isValid() && b.isValid() && b != 0);
     
-    if(b == 1)
-        return a;
-    if(a < b)
+    if(lesser(a, b))
         return 0;
-    if(a == b)
-        return 1;
-    BigInt u = a; 
-    const uint32_t m = a.numBits() - b.numBits();
 
-    BigInt q;
-    q.size = m/8 + 1;
-    q.bits = new uint32_t[num(q.size)];
-    memset(q.bits, 0, num(q.size)*sizeof(uint32_t));
-    BigInt tmp = b;
-    tmp <<= m;
-    for(int j = m; j >= 0; --j)
-    {
-        if(tmp <= u)
-        {   
-            u -= tmp;
-            q.bits[j/32] |= BITS[j%32]; 
-        }
-        tmp >>= 1;
-    }
-    q.trim();
-
-    return q;
+    BigInt result;
+    
+    if(equals(a, b))
+        result = 1;
+    else if(equals(b, 1))
+        result = a;
+    else
+        result = div(a, b);
+        
+    if(a.sign == b.sign)
+        result.sign = POS;
+    else
+        result.sign = NEG;
+        
+    return result;
 }
 
 BigInt& BigInt::operator/=(const BigInt &b)
@@ -318,7 +273,8 @@
     }
     
     result.trim();
-        
+    result.checkZero();
+            
     return result;
 }
 
@@ -331,11 +287,11 @@
 {
     assert(a.isValid());
 
+    if(m == 0)
+        return a;
+
     BigInt result;
     
-    if(m == 0)
-        return result = a;
-
     result.size = m/8 + a.size;
     if((m%32)%8 != 0)
         ++result.size;
@@ -346,8 +302,8 @@
     result.bits[m/32] = a.bits[0] << s;
     for(uint32_t i = 1; i < num(a.size); ++i)
         result.bits[m/32+i] = (a.bits[i] << s) | (a.bits[i-1] >> (32-s));
-    if(s != 0)
-        result.bits[num(result.size)-1] = a.bits[num(a.size)-1] >> (32-s);
+    if(s != 0 && num(result.size) != 1)
+        result.bits[num(result.size)-1] |= a.bits[num(a.size)-1] >> (32-s);
 
     result.trim();
     
@@ -362,7 +318,10 @@
 BigInt operator%(const BigInt &a, const BigInt &b)
 {
     assert(a.isValid() && b.isValid() && b > 0);
-    
+    if(a < b)
+        return a;
+    if(a == b)
+        return 0;
     return a - (a/b)*b;
 }
 
@@ -375,14 +334,7 @@
 {
     assert(a.isValid() && b.isValid());
 
-    if(a.size != b.size)
-        return false;
-        
-    uint32_t l = num(a.size);
-    for(int i = 0; i < (int)l; ++i)
-        if(a.bits[i] != b.bits[i])
-            return false;
-    return true;
+    return a.sign == b.sign && equals(a, b);
 }
 
 bool operator!=(const BigInt &a, const BigInt &b)
@@ -393,20 +345,15 @@
 bool operator<(const BigInt &a, const BigInt &b)
 {
     assert(a.isValid() && b.isValid());
-
-    if(a.size < b.size)
+    
+    if(a.sign == NEG && b.sign == NEG)
+        return !lesser(a, b);
+    else if(a.sign == NEG && b.sign == POS)
         return true;
-    if(a.size > b.size)
+    else if(a.sign == POS && b.sign == NEG)
         return false;
-    uint32_t l = num(a.size);
-    for(int i = l-1; i >= 0; --i)
-    {
-        if(a.bits[i] < b.bits[i])
-            return true;
-        else if(a.bits[i] > b.bits[i])
-            return false;
-    }
-    return false;
+    else
+        return lesser(a, b);
 }
 
 bool operator<=(const BigInt &a, const BigInt &b)
@@ -418,19 +365,14 @@
 {
     assert(a.isValid() && b.isValid());
 
-    if(a.size > b.size)
-        return true;
-    if(a.size < b.size)
+    if(a.sign == NEG && b.sign == NEG)
+        return !greater(a, b);
+    else if(a.sign == NEG && b.sign == POS)
         return false;
-    uint32_t l = num(a.size);
-    for(int i = l-1; i >= 0; --i)
-    {
-        if(a.bits[i] > b.bits[i])
-            return true;
-        else if(a.bits[i] < b.bits[i])
-            return false;       
-    }
-    return false;
+    else if(a.sign == POS && b.sign == NEG)
+        return true;
+    else
+        return greater(a, b);
 }
 
 bool operator>=(const BigInt &a, const BigInt &b)
@@ -544,14 +486,14 @@
     while(r > 0)
     {
         if(a.bits[j/32] & BITS[j%32])
-            result.add(b);
+            result.fastAdd(b);
         
         if(result.bits[0] & BITS[0])
-            result.add(m);
+            result.fastAdd(m);
      
         ++j; 
         --r;
-        result.shr();
+        result.fastShr();
     }
     
     if(result >= m)
@@ -602,6 +544,36 @@
     return montgomeryStep(tmp, 1, modulus, r);
 }
 
+// Implementation as described in FIPS.186-4, Appendix C.1
+BigInt invMod(const BigInt &a, const BigInt &modulus)
+{
+    assert(a.isValid() && modulus.isValid() && 0 < a && a < modulus);
+    
+    BigInt i = modulus;
+    BigInt j = a;
+    BigInt y2 = 0;
+    BigInt y1 = 1;
+    do
+    {
+        BigInt quotient = i / j;
+        BigInt remainder = i - (j * quotient);
+        BigInt y = y2 - (y1 * quotient);
+        i = j; 
+        j = remainder; 
+        y2 = y1;
+        y1 = y;
+    }while(j > 0);
+    
+
+    assert(i == 1);
+    
+    y2 %= modulus;
+    if(y2 < 0)
+        y2 += modulus;
+        
+    return y2;
+}
+
 bool BigInt::isValid() const
 {
     return size != 0 && bits != 0;
@@ -613,12 +585,184 @@
     
     printf("size: %lu bytes\n", size);
     uint32_t n = num(size);
+    if(sign == NEG)
+        printf("- ");
     for(int i = n-1; i >= 0; --i)
         printf("%08x ", (int)bits[i]);
     printf("\n");
 }
 
-void BigInt::add(const BigInt &b)
+// return a + b
+BigInt add(const BigInt &a, const BigInt &b)
+{
+    BigInt result;
+        
+    result.size = std::max(a.size, b.size) + 1;
+    size_t l = num(result.size);
+    result.bits = new uint32_t[l];
+    memset(result.bits, 0, sizeof(uint32_t)*l);
+    uint32_t al = num(a.size);
+    uint32_t bl = num(b.size);
+    uint32_t carry = 0;
+    for(int i = 0; i < (int)l; ++i)
+    {
+        uint32_t tmpA = 0, tmpB = 0;
+        if(i < (int)al)
+            tmpA = a.bits[i];
+        if(i < (int)bl)
+            tmpB = b.bits[i];
+        result.bits[i] = tmpA + tmpB + carry;
+        carry = result.bits[i] < std::max(tmpA, tmpB);
+    }
+
+    result.trim();
+    result.checkZero();
+    
+    return result;
+}
+
+// return a - b
+// Assume that a > b
+BigInt sub(const BigInt &a, const BigInt &b)
+{
+    BigInt result;
+    result.size = a.size;
+    uint32_t l = num(a.size);
+    result.bits = new uint32_t[l];
+    memset(result.bits, 0, sizeof(uint32_t)*l);
+    uint32_t bl = num(b.size);
+    uint8_t borrow = 0;
+    for(uint32_t i = 0; i < l; ++i)
+    {
+        uint32_t tmpA = a.bits[i], tmpB = 0;
+        if(i < bl)
+            tmpB = b.bits[i];
+            
+        if(borrow)  
+        {
+            if(tmpA == 0)
+                tmpA = 0xFFFFFFFF;
+            else
+            {
+                --tmpA;
+                borrow = 0;
+            }
+        }
+        if(tmpA >= tmpB)
+            result.bits[i] = tmpA - tmpB;
+        else 
+        {
+            result.bits[i] = 0xFFFFFFFF - tmpB;
+            result.bits[i] += tmpA + 1;
+            borrow = 1;
+        }
+    }
+    result.trim();
+    result.checkZero();
+    
+    return result;  
+}
+
+BigInt mul(const BigInt &a, const BigInt &b)
+{
+    BigInt result;          
+    result.size = a.size + b.size;
+    result.bits = new uint32_t[num(result.size)];
+    memset(result.bits, 0, sizeof(uint32_t)*num(result.size));
+    for(int i = 0; i < (int)num(a.size); ++i)
+    {
+        uint64_t carry = 0;
+        for(int j = 0; j < (int)num(b.size); ++j)
+        {
+            uint64_t tmp = (uint64_t)a.bits[i] * (uint64_t)b.bits[j] + carry;        
+            uint32_t t = result.bits[i+j];
+            result.bits[i+j] += tmp;
+            carry = tmp >> 32;     
+            if(t > result.bits[i+j])
+                ++carry;                    
+        }
+        if(carry != 0)
+            result.bits[i+num(b.size)] += carry;
+    }
+    
+    result.trim();
+
+    return result;
+}
+
+BigInt div(const BigInt &a, const BigInt &b)
+{
+    BigInt u = a; 
+    const uint32_t m = a.numBits() - b.numBits();
+    BigInt q;
+    q.size = m/8 + 1;
+    q.bits = new uint32_t[num(q.size)];
+    memset(q.bits, 0, num(q.size)*sizeof(uint32_t));
+    BigInt tmp = b;
+    tmp <<= m;
+    for(int j = m; j >= 0; --j)
+    {
+        if(tmp <= u)
+        {   
+            u -= tmp;
+            q.bits[j/32] |= BITS[j%32]; 
+        }
+        tmp >>= 1;
+    }
+    q.trim();
+
+    return q;
+}
+
+bool equals(const BigInt &a, const BigInt &b)
+{
+    if(a.size != b.size)
+        return false;
+        
+    uint32_t l = num(a.size);
+    for(int i = 0; i < (int)l; ++i)
+        if(a.bits[i] != b.bits[i])
+            return false;
+    return true;
+}
+
+bool lesser(const BigInt &a, const BigInt &b)
+{
+    if(a.size < b.size)
+        return true;
+    if(a.size > b.size)
+        return false;
+        
+    uint32_t l = num(a.size);
+    for(int i = l-1; i >= 0; --i)
+    {
+        if(a.bits[i] < b.bits[i])
+            return true;
+        else if(a.bits[i] > b.bits[i])
+            return false;
+    }
+    return false;
+}
+        
+bool greater(const BigInt &a, const BigInt &b)
+{
+    if(a.size > b.size)
+        return true;
+    if(a.size < b.size)
+        return false;
+    uint32_t l = num(a.size);
+    for(int i = l-1; i >= 0; --i)
+    {
+        if(a.bits[i] > b.bits[i])
+            return true;
+        else if(a.bits[i] < b.bits[i])
+            return false;       
+    }
+    return false;
+}
+
+
+void BigInt::fastAdd(const BigInt &b)
 {
     uint32_t al = num(size);
     uint32_t bl = num(b.size);
@@ -647,7 +791,7 @@
     trim();
 }
 
-void BigInt::shr()
+void BigInt::fastShr()
 {
     uint32_t lastBit = 0;
     uint32_t tmp;
@@ -665,6 +809,7 @@
 void BigInt::trim()
 {
     assert(isValid());
+
     
     uint8_t *tmp = (uint8_t*)bits;
     uint32_t newSize = size;
@@ -697,4 +842,13 @@
     n += tmp2;
 
     return n;
-}
\ No newline at end of file
+}
+
+// Ensure that there is no negative zero
+void BigInt::checkZero()
+{
+    assert(isValid());
+    
+    if(size == 1 && bits[0] == 0)
+        sign = POS;
+}