// AKS.cpp -- Implementation of the AKS algorithm

// Compiles under Microsoft Visual C++

//#define _SIEVE

#pragma message ("TODO: Add the path to <NTL\\ZZ.h> and <NTL\\ZZ_pX.h> via Project Properties -> C/C++ -> General -> Additional Include Directories.")
#include <NTL\ZZ.h>
#include <NTL\ZZ_pX.h>

#ifdef _DEBUG
    #pragma comment(lib, "..\\NTL\\x64\\Debug\\NTL.lib")
#else
    #pragma comment(lib, "..\\NTL\\x64\\Release\\NTL.lib")
#endif

#include <cmath>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>

using namespace std;

// Computes A raised to the power B using binary (square-and-multiply) exponentiation.
// All arithmetic is unsigned 64-bit, so results that do not fit wrap modulo 2^64.
// Power(x, 0) is 1 for every x, including 0.
inline unsigned __int64 Power(const unsigned __int64 A, const unsigned int B)
{
    unsigned __int64 result = 1;
    unsigned __int64 base = A;

    // Walk the exponent's bits from least to most significant; base holds A^(2^i) at step i.
    for(unsigned int exponent = B; exponent != 0; exponent >>= 1)
    {
        if((exponent & 1) != 0)
        {
            result *= base;
        }
        base *= base;
    }

    return result;
}

// AKS step 1: returns true when n = a^b for some integer a >= 0 and exponent b >= 2.
// 0 (= 0^2) and 1 (= 1^2) are reported as perfect powers.
inline bool IsNPerfectPower(const unsigned __int64 n)
{
    // Guard: log(0) is -infinity, and converting that to unsigned int below would be
    // undefined behavior. Both 0 and 1 are trivially perfect powers.
    if(n < 2)
    {
        return true;
    }

    // If n = a^b with a >= 2 then b <= log2(n), so no larger exponent needs to be tried.
    const unsigned int nLogTwoN = static_cast<unsigned int>(
        ::log(static_cast<long double>(n)) / ::log(2.0) + 1
        );
    for(unsigned int b = 2; b <= nLogTwoN; b++)
    {
        // Floating-point estimate of the b-th root, rounded to the nearest integer and then
        // confirmed exactly with integer arithmetic.
        const long double a = ::pow(
            static_cast<long double>(n), static_cast<long double>(1) / static_cast<long double>(b)
            );
        if(n == Power(static_cast<unsigned __int64>(::floorl(a + 0.5)), b))
        {
            return true;
        }
    }

    return false;
}

// Greatest common divisor of N and R by Euclid's algorithm.
// GCD(x, 0) == x, and GCD(0, 0) == 0.
inline unsigned __int64 GCD(
    const unsigned __int64 N, const unsigned __int64 R
    )
{
    unsigned __int64 x = N;

    // Replace (x, y) by (y, x mod y) until the second value reaches zero.
    for(unsigned __int64 y = R; y != 0; )
    {
        const unsigned __int64 remainder = x % y;
        x = y;
        y = remainder;
    }

    return x;
}

inline bool IsPrime(const unsigned __int64 n)
{
    // Step 1
    if(IsNPerfectPower(n))
    {
        return false;
    }

    // Step 2
    unsigned __int64 r = 3;
#ifdef _SIEVE
    bool b = false;
    while(r < n)
    {
        if(GCD(n, r) > 1)
        {
            b = true;
            break;
        }
        
        r += 2;
    }

    // Step 3
    if(n == r)
    {
        return true;
    }

    // Step 4
    if(b)
    {
        return false;
    }
#endif

    // Step 5
    char buff[_MAX_PATH];
    sprintf_s(buff, "%I64u", n);
    NTL::ZZ zzn = NTL::to_ZZ(buff);
    NTL::ZZ_p::init(zzn);

    NTL::ZZ_pX zzPxf(static_cast<long>(r), 1);
    zzPxf -= 1;
    const NTL::ZZ_pXModulus zzPxMpf(zzPxf);
    NTL::ZZ_pX zzPxRhs(1, 1);
    NTL::PowerMod(zzPxRhs, zzPxRhs, zzn, zzPxMpf);

    long aUpper = static_cast<long>(
        2 * sqrtl(static_cast<long double>(r)) * ::log(static_cast<long double>(n)) / ::log(2.0)
        );
    for(long a = 1; a <= aUpper; ++a)
    {
        NTL::ZZ_pX zzPxLhs(1, 1);
        zzPxLhs += a;
        NTL::PowerMod(zzPxLhs, zzPxLhs, zzn, zzPxMpf);
        zzPxLhs -= a;
        if(zzPxLhs != zzPxRhs)
        {
            return false;
        }
    }

    // Step 6
    return true;
}

int main(int argc, char* argv[])
{
    try
    {
        string s;
        while(true)
        {           
            cout << "Enter a number > 1 (0 to quit): ";
            cin >> s;
            unsigned __int64 n = ::_atoi64(s.c_str());
            if(n <= 1)
            {
                if(n == 0)
                {
                    break;
                }

                continue;
            }
            cout << n << (IsPrime(n) ? " is prime."  : " is composite.") << endl;
        }
    }
    catch(const exception& e)
    {
        cout << e.what()  << endl;
    }
    catch(...)
    {
        cout << "Errors occured."  << endl;
    }

    return 0;
}

