3

I am trying to solve the elliptic curve discrete logarithm problem using Pollard's rho (find k where G = kP). I searched for an implementation in C and found one, but after adding my problem-specific data in the main function I got a segmentation fault (core dumped).

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <gmp.h>
#include <limits.h>
#include <sys/time.h>

#include <openssl/ec.h>
#include <openssl/bn.h>
#include <openssl/obj_mac.h> // for NID_secp256k1

/* Number of precomputed random steps R[i] used by the r-adding walk; a point's
   partition index is reduced modulo this value. */
#define POLLARD_SET_COUNT (16)

/* Symbol-export decoration: dllexport on Windows builds, empty elsewhere. */
#if !defined(_WIN32) && !defined(WIN32)
#define EXPORT
#else
#define EXPORT __declspec(dllexport)
#endif

/* Restart limit; not referenced by the code visible in this file section. */
#define MAX_RESTART (100)

/*
 * Map a point to one of POLLARD_SET_COUNT partitions for the r-adding walk.
 *
 * The index is the last byte of the point's uncompressed octet encoding
 * modulo POLLARD_SET_COUNT (for the point at infinity the encoding is a single
 * byte).  If OpenSSL cannot encode the point, partition 0 is returned instead
 * of reading outside the buffer.
 *
 * Returns a value in [0, POLLARD_SET_COUNT).
 */
int ec_point_partition(const EC_GROUP *ecgrp, const EC_POINT *x) {
    /* With a NULL buffer, point2oct only reports the required length (0 on error). */
    size_t len = EC_POINT_point2oct(ecgrp, x, POINT_CONVERSION_UNCOMPRESSED, NULL, 0, NULL);
    if (len == 0)
        return 0;   /* avoid a zero-length VLA and ret[len - 1] == ret[-1] */

    unsigned char ret[len];
    if (EC_POINT_point2oct(ecgrp, x, POINT_CONVERSION_UNCOMPRESSED, ret, len, NULL) != len)
        return 0;

    return ret[len - 1] % POLLARD_SET_COUNT;
}

// P generator 
// Q result*P
// order of the curve
// result
//Reference: J. Sattler and C. P. Schnorr, "Generating random walks in groups"

int elliptic_pollard_rho_dlog(const EC_GROUP *group, const EC_POINT *P, const EC_POINT *Q, const BIGNUM *order, BIGNUM *res) {

    printf("Pollard rho discrete log algorithm... \n");

    BN_CTX* ctx;
    ctx = BN_CTX_new();

    int i, j;
    int iterations = 0;

    if ( !EC_POINT_is_on_curve(group, P, ctx ) || !EC_POINT_is_on_curve(group, Q, ctx ) ) return 1;

    EC_POINT *X1 = EC_POINT_new(group);
    EC_POINT *X2 = EC_POINT_new(group);

    BIGNUM *c1 = BN_new();
    BIGNUM *d1 = BN_new();
    BIGNUM *c2 = BN_new();
    BIGNUM *d2 = BN_new();

    BIGNUM* a[POLLARD_SET_COUNT];
    BIGNUM* b[POLLARD_SET_COUNT];
    EC_POINT* R[POLLARD_SET_COUNT];

    BN_zero(c1); BN_zero(d1);
    BN_zero(c2); BN_zero(d2);


    for (i = 0; i < POLLARD_SET_COUNT; i++) {   

        a[i] = BN_new();
        b[i] = BN_new();
        R[i] = EC_POINT_new(group);

        BN_rand_range(a[i], order);     
        BN_rand_range(b[i], order);

        // R = aP + bQ

        EC_POINT_mul(group, R[i], a[i], Q, b[i], ctx);
        //ep_norm(R[i], R[i]);
    }

    BN_rand_range(c1, order);       
    BN_rand_range(d1, order);       


    // X1 = c1*P + d1*Q
    EC_POINT_mul(group, X1, c1, Q, d1,  ctx);  
    //ep_norm(X1, X1);

    BN_copy(c2, c1);
    BN_copy(d2, d1);
    EC_POINT_copy(X2, X1);


    double work_time = (double) clock();
    do {
        j = ec_point_partition(group, X1);
        EC_POINT_add(group, X1, X1, R[j], ctx);

        BN_mod_add(c1, c1, a[j], order, ctx); 

        BN_mod_add(d1, d1, b[j], order, ctx); 

        for (i = 0; i < 2; i++) {
            j = ec_point_partition(group, X2);

            EC_POINT_add(group, X2, X2, R[j], ctx);

            BN_mod_add(c2, c2, a[j], order, ctx); 

            BN_mod_add(d2, d2, b[j], order, ctx);
        }

        iterations++;
        printf("Iteration %d \r",iterations );
    } while ( EC_POINT_cmp(group, X1, X2, ctx) != 0 ) ;


    printf("\n ");

    work_time = ( (double) clock() - work_time ) / (double)CLOCKS_PER_SEC;

    printf("Number of iterations %d %f\n",iterations, work_time );

    BN_mod_sub(c1, c1, c2, order, ctx);
    BN_mod_sub(d2, d2, d1, order, ctx);

    if (BN_is_zero(d2) == 1) return 1;


    //d1 = d2^-1 mod order  
    BN_mod_inverse(d1, d2, order, ctx);

    BN_mod_mul(res, c1, d1, order, ctx);

    for (int k = 0; k < POLLARD_SET_COUNT; ++k) {
        BN_free(a[k]); 
        BN_free(b[k]);
        EC_POINT_free(R[k]);
    }
    BN_free(c1); BN_free(d1);
    BN_free(c2); BN_free(d2);
    EC_POINT_free(X1); EC_POINT_free(X2);

    BN_CTX_free(ctx);
    return 0;
}


int main(int argc, char *argv[])
{
    unsigned char *p_str="134747661567386867366256408824228742802669457";
    unsigned char *a_str="-1";
    unsigned char *b_str="0";
    BIGNUM *p = BN_bin2bn(p_str, sizeof(p_str), NULL);
    BIGNUM *a = BN_bin2bn(a_str, sizeof(a_str), NULL);
    BIGNUM *b = BN_bin2bn(b_str, sizeof(b_str), NULL);
    BN_CTX* ctx;
    ctx = BN_CTX_new();
    EC_GROUP* g = EC_GROUP_new(EC_GFp_simple_method());
    EC_GROUP_set_curve_GFp(g,p,a,b,ctx);    
    unsigned char *XP_str="18185174461194872234733581786593019886770620";
    unsigned char *YP_str="74952280828346465277451545812645059041440154";

    BN_CTX* ctx1;
    ctx1 = BN_CTX_new();
    BIGNUM *XP = BN_bin2bn(XP_str, sizeof(XP_str), NULL);
    BIGNUM *YP = BN_bin2bn(YP_str, sizeof(YP_str), NULL);
    EC_POINT* P = EC_POINT_new(g);
    EC_POINT_set_affine_coordinates_GFp(g,P,XP,YP,ctx1);

    unsigned char *XQ_str="76468233972358960368422190121977870066985660";
    unsigned char *YQ_str="33884872380845276447083435959215308764231090";
    BIGNUM* XQ = BN_bin2bn(XQ_str, sizeof(XQ_str), NULL);
    BIGNUM* YQ = BN_bin2bn(YQ_str, sizeof(YQ_str), NULL);
    EC_POINT *Q = EC_POINT_new(g);
    BN_CTX* ctx2;
    ctx2 = BN_CTX_new();
    EC_POINT_set_affine_coordinates_GFp(g,Q,XQ,YQ,ctx2);
    char * str;


    unsigned char *N_str="2902021510595963727029";
    BIGNUM *N = BN_bin2bn(N_str, sizeof(N_str), NULL);
    BIGNUM *res;
    elliptic_pollard_rho_dlog (g,P,Q,N,res);
    BN_bn2mpi(res,str); 
    printf("%s\n", str);


  return 0;
}

This is the statement that causes the segmentation fault:

    BN_bn2mpi(res,str); 
Chaker
  • 1,197
  • 9
  • 22
  • You should use a debugger to see the exact values you are passing to `BN_bn2mpi`. That might reveal the bug. You might also need to make an [MCVE](http://stackoverflow.com/help/mcve). – David Grayson May 16 '15 at 18:46

1 Answers1

-1

Part 1. Python version.

Update: see Part 2 of my answer, where I present a C++ version of the same algorithm as this Python version.

Your task is very interesting!

Maybe you wanted your code to be fixed, but instead I decided to implement solutions from scratch in pure Python (Part 1 of the answer) and pure C++ (Part 2), without using any external non-standard modules. I think from-scratch solutions like these, without dependencies, are very useful for educational purposes.

An algorithm like this is quite complex, and Python is simple enough to make implementing such an algorithm possible in a short time.

In the code below I used Wikipedia as a guide to implement Pollard's Rho Discrete Logarithm and Elliptic Curve Point Multiplication.

The code doesn't depend on any external modules; it uses just a few built-in Python modules. Optionally, you can use the gmpy2 module: install it with python -m pip install gmpy2 and uncomment the line #import gmpy2 in the code.

As you can see, I generate a random base point myself and compute its order. I don't use any external curve such as Bitcoin's secp256k1 or other standard curves.

At the beginning of the main() function you can see that I set bits = 24. This is the number of bits of the curve's prime modulus, and the order of the curve (the number of distinct points) will have about the same bit size. You may set bits = 32 to try solving the task for a bigger curve.

As is well known, the algorithm's complexity is O(sqrt(Curve_Order)): it takes that many elliptic curve point additions. Point additions are not primitive operations and also take some time. So a run for a curve order of bit size bits = 32 takes about 10-15 seconds. bits = 64 would take far too long in Python, but the C++ version (Part 2) will be fast enough to crack 64 bits within an hour or so.

When running the code you may sometimes notice that Pollard Rho fails a few times. This happens when the algorithm tries to find the modular inverse of a non-invertible number (one not coprime to the modulus). That can occur at the last step of Pollard Rho, or when an elliptic curve point addition produces the point at infinity. The same kind of failure also happens from time to time in regular Pollard Rho integer factorization, when the GCD is equal to N.

Try it online!

import random
#random.seed(10)

class ECPoint:
    """Affine point on y^2 = x^3 + A*x + B over Z/NZ, bundled with its curve and order.

    ``x is None`` (together with ``y is None``) represents the point at infinity.
    ``q`` is the point's order (0 when unknown) and ``q_ps`` its prime factors as
    found by ``find_order``.  Arithmetic raises ``InvError`` when a denominator is
    not invertible; ``find_order`` relies on that to detect reaching infinity.
    """
    # Optional acceleration: set to the gmpy2 module to compute with mpz integers.
    gmpy2 = None
    #import gmpy2
    # Exposed as ECPoint.random / cls.random for the helpers below.
    import random
    
    class InvError(Exception):
        """Raised by ``inv``; ``value`` holds (gcd, a, n) for the failed inversion."""
        def __init__(self, *args):
            self.value = args
    
    @classmethod
    def Int(cls, x):
        """Convert x to int, or to gmpy2.mpz when gmpy2 is enabled."""
        return int(x) if cls.gmpy2 is None else cls.gmpy2.mpz(x)
    
    @classmethod
    def fermat_prp(cls, n, trials = 32):
        """Fermat probable-prime test with ``trials`` random bases; exact for n <= 16."""
        # https://en.wikipedia.org/wiki/Fermat_primality_test
        if n <= 16:
            return n in (2, 3, 5, 7, 11, 13)
        for i in range(trials):
            if pow(cls.random.randint(2, n - 2), n - 1, n) != 1:
                return False
        return True
    
    @classmethod
    def rand_prime(cls, bits):
        """Random odd probable prime in [2**(bits - 1), 2**bits)."""
        while True:
            p = cls.random.randrange(1 << (bits - 1), 1 << bits) | 1
            if cls.fermat_prp(p):
                return p
    
    @classmethod
    def base_gen(cls, bits = 128, *, min_order_pfactor = 0):
        """Generate a random curve and base point whose order is known.

        Picks a ``bits``-bit probable prime N with N % 4 == 3, a random point
        (x0, y0) and coefficient A, and solves for B so the point lies on the
        curve.  Retries until ``find_order`` succeeds and every prime factor of
        the order is >= ``min_order_pfactor``.
        """
        while True:
            while True:
                N = cls.rand_prime(bits)
                # N % 4 == 3 lets pow(v, (N + 1) // 4, N) compute a square root.
                if N % 4 != 3:
                    continue
                x0, y0, A = [cls.random.randrange(1, N) for i in range(3)]
                B = (y0 ** 2 - x0 ** 3 - A * x0) % N
                # __init__ asserts y equals this particular root, so retry otherwise.
                y0_calc = pow(x0 ** 3 + A * x0 + B, (N + 1) // 4, N)
                if y0 == y0_calc:
                    break
            bp = ECPoint(A, B, N, x0, y0, calc_q = True)
            if bp.q is not None and min(bp.q_ps) >= min_order_pfactor:
                break
        # Sanity check: (q + 1) * P must wrap around to P.
        assert bp.q > 1 and (bp.q + 1) * bp == bp
        return bp
    
    def __init__(self, A, B, N, x, y, *, q = 0, prepare = True, calc_q = False):
        """Build a point; pass x = y = None for the point at infinity.

        With ``prepare`` the values are converted via ``Int`` and reduced mod N,
        and asserts check: non-zero discriminant, N % 4 == 3, the point lies on
        the curve, and y equals pow(x**3 + A*x + B, (N + 1) // 4, N).  Note that q
        is reduced mod N here as well.  With ``calc_q`` the order and its prime
        factors are computed by ``find_order`` into ``q`` and ``q_ps``.
        """
        if prepare:
            N = self.Int(N)
            assert (x is None) == (y is None), (x, y)
            A, B, x, y, q = [(self.Int(e) % N if e is not None else None) for e in [A, B, x, y, q]]
            assert (4 * A ** 3 + 27 * B ** 2) % N != 0
            assert N % 4 == 3
            if x is not None:
                assert (y ** 2 - x ** 3 - A * x - B) % N == 0, (hex(N), hex((y ** 2 - x ** 3 - A * x) % N))
                assert y == pow(x ** 3 + A * x + B, (N + 1) // 4, N)
        self.A, self.B, self.N, self.x, self.y, self.q = A, B, N, x, y, q
        if calc_q:
            self.q, self.q_ps = self.find_order()
    
    def copy(self):
        """Unvalidated copy of this point (q is kept, q_ps is not)."""
        return ECPoint(self.A, self.B, self.N, self.x, self.y, q = self.q, prepare = False)
    
    def inf(self):
        """The point at infinity on the same curve, with the same q."""
        return ECPoint(self.A, self.B, self.N, None, None, q = self.q, prepare = False)
    
    def find_order(self, *, _m = 1, _ps = []):
        """Search for the order of this point; ``_m``/``_ps`` are recursion state.

        Multiplies by increasing prime powers until point arithmetic raises
        ``InvError`` (which, for prime N, indicates infinity was reached), then
        recurses with the accumulated multiplier.  Returns (m, prime_factors) where
        m is the multiplier for which ``m * self`` raised, or (None, []) when the
        prime bound is exceeded first.  ``base_gen`` double-checks the result.
        """
        # Primary strategy; the else branch is a slower alternative kept for reference.
        if 1:
            try:
                r = _m * self
            except self.InvError:
                return _m, _ps
            B = 2 * self.N
            for p in self.gen_primes():
                if p * p > B * 2:
                    return None, []
                assert _m % p != 0, (_m, p)
                assert p <= B, (p, B)
                hi = 1
                try:
                    # Multiply r by the largest power of p not exceeding B;
                    # cnt ends as the exponent used (or reached when InvError fires).
                    for cnt in range(1, 1 << 60):
                        hi *= p
                        if hi > B:
                            cnt -= 1
                            break
                        r = p * r
                except self.InvError:
                    return self.find_order(_m = hi * _m, _ps = [p] * cnt + _ps)
        else:
            # Alternative slower way
            r = self
            for i in range(1 << 60):
                try:
                    r = r + self
                except self.InvError:
                    return i + 2, []
    
    @classmethod
    def gen_primes(cls, *, ps = [2, 3]):
        """Yield primes in increasing order by trial division.

        The mutable default ``ps`` is deliberately shared across calls as a
        cache of the primes found so far.
        """
        yield from ps
        for p in range(ps[-1] + 2, 1 << 60, 2):
            is_prime = True
            for e in ps:
                if e * e > p:
                    break
                if p % e == 0:
                    is_prime = False
                    break
            if is_prime:
                ps.append(p)
                yield ps[-1]
    
    def __add__(self, other):
        """Chord-and-tangent addition.

        Raises InvError when the slope denominator is not invertible mod N
        (e.g. P + (-P), or doubling a point with y == 0).
        """
        if self.x is None:
            return other.copy()
        if other.x is None:
            return self.copy()
        A, B, N, q = self.A, self.B, self.N, self.q
        Px, Py, Qx, Qy = self.x, self.y, other.x, other.y
        if Px == Qx and Py == Qy:
            # Doubling: tangent slope (3x^2 + A) / (2y).
            s = ((Px * Px * 3 + A) * self.inv(Py * 2, N)) % N
        else:
            # Distinct points: chord slope (Py - Qy) / (Px - Qx).
            s = ((Py - Qy) * self.inv(Px - Qx, N)) % N
        x = (s * s - Px - Qx) % N
        y = (s * (Px - x) - Py) % N
        return ECPoint(A, B, N, x, y, q = q, prepare = False)
    
    def __rmul__(self, other):
        """Scalar multiplication ``other * self`` by double-and-add (other > 0)."""
        assert other > 0, other
        if other == 1:
            return self.copy()
        other = self.Int(other - 1)
        r = self
        # `self` is rebound to successive doublings 2^i * P; r starts at P and
        # accumulates the set bits of other - 1.
        while True:
            if other & 1:
                r = r + self
                if other == 1:
                    return r
            other >>= 1
            self = self + self
    
    @classmethod
    def inv(cls, a, n):
        """Modular inverse of a mod n; raises InvError(gcd, a, n) if none exists."""
        a %= n
        if cls.gmpy2 is None:
            try:
                # pow with exponent -1 needs Python 3.8+.
                return pow(a, -1, n)
            except ValueError:
                import math
                raise cls.InvError(math.gcd(a, n), a, n)
        else:
            g, s, t = cls.gmpy2.gcdext(a, n)
            if g != 1:
                raise cls.InvError(g, a, n)
            return s % n

    def __repr__(self):
        return str(dict(x = self.x, y = self.y, A = self.A, B = self.B, N = self.N, q = self.q))

    def __eq__(self, other):
        """Equal when coordinates, curve parameters and q all match."""
        for i, (a, b) in enumerate([(self.x, other.x), (self.y, other.y), (self.A, other.A), (self.B, other.B), (self.N, other.N), (self.q, other.q)]):
            if a != b:
                return False
        return True

def pollard_rho_ec_log(a, b, bp):
    """Find x with ``b == x * a`` using Pollard's rho with Floyd cycle detection.

    Each walk point is kept as ``ai * a + bi * b``.  The step is chosen by a
    three-way partition on ``(x mod part_p) mod 3``, where ``part_p`` is a fresh
    random prime per attempt: class 0 adds ``b``, class 1 doubles, class 2 adds ``a``.
    When the tortoise ``x_i`` meets the hare ``x_2i``, the logarithm is
    ``(a2i - ai) / (bi - b2i) mod bp.q``.

    Args:
        a: base point; its order is assumed to divide ``bp.q``.
        b: target point, presumably a multiple of ``a``.
        bp: point providing ``q``, ``N``, ``inf()``, ``inv``, ``rand_prime`` and
            ``InvError``.

    Returns:
        x reduced modulo ``bp.q``.  An attempt that hits a non-invertible value
        (``bp.InvError``) is reported and retried with a new partition.
    """
    # https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm_for_logarithms#Algorithm
    import math

    for itry in range(1 << 60):    
        try:
            i = -1
            part_p = bp.rand_prime(max(3, int(math.log2(bp.N) / 2)))
            
            def f(x):
                # Next point of the walk.
                mod3 = ((x.x or 0) % part_p) % 3
                if mod3 == 0:
                    return b + x
                elif mod3 == 1:
                    return x + x
                elif mod3 == 2:
                    return a + x
                else:
                    assert False
            
            def g(x, n):
                # Matching update of the coefficient of `a`.
                mod3 = ((x.x or 0) % part_p) % 3
                if mod3 == 0:
                    return n
                elif mod3 == 1:
                    return (2 * n) % bp.q
                elif mod3 == 2:
                    return (n + 1) % bp.q
                else:
                    assert False
            
            def h(x, n):
                # Matching update of the coefficient of `b`.
                mod3 = ((x.x or 0) % part_p) % 3
                if mod3 == 0:
                    return (n + 1) % bp.q
                elif mod3 == 1:
                    return (2 * n) % bp.q
                elif mod3 == 2:
                    return n
                else:
                    assert False
            
            a0, b0, x0 = 0, 0, bp.inf()
            aim1, bim1, xim1 = a0, b0, x0
            a2im2, b2im2, x2im2 = a0, b0, x0
            
            for i in range(1, 1 << 60):
                # Tortoise: one step.
                xi = f(xim1)
                ai = g(xim1, aim1)
                bi = h(xim1, bim1)
                
                # Hare: two steps.  Evaluate the intermediate point once instead
                # of recomputing f(x2im2) three times (each is a point addition).
                x2im1 = f(x2im2)
                x2i = f(x2im1)
                a2i = g(x2im1, g(x2im2, a2im2))
                b2i = h(x2im1, h(x2im2, b2im2))
                
                if xi == x2i:
                    return (bp.inv(bi - b2i, bp.q) * (a2i - ai)) % bp.q
                
                xim1,  aim1, bim1 = xi, ai, bi
                x2im2, a2im2, b2im2 = x2i, a2i, b2i
        except bp.InvError as ex:
            print(f'Try {itry:>4}, Pollard-Rho failed, invert err at iter {i:>7},', ex.value)

def main():
    """Demo: build a random 24-bit curve, pick a random logarithm, and recover it."""
    import math
    import random
    bits = 24
    print('Generating base point, wait...')
    base = ECPoint.base_gen(bits, min_order_pfactor = 10)
    factor_text = ' * '.join(str(factor) for factor in base.q_ps)
    print('order', base.q, '=', factor_text)
    # Two draws in this order: first the base multiplier, then the secret logarithm.
    k0 = random.randrange(1, base.q)
    secret = random.randrange(1, base.q)
    a = k0 * base
    b = secret * a
    found = pollard_rho_ec_log(a, b, base)
    print('our x', secret, 'found x', found)
    points_match = secret * a == found * a
    print('equal points:', points_match)

if __name__ == '__main__':
    main()

Output:

Generating base point, wait...
order 5805013 = 19 * 109 * 2803
Try    0, Pollard-Rho failed, invert err at iter    1120, (109, 1411441, 5805013)
Try    1, Pollard-Rho failed, invert err at iter    3992, (19, 5231802, 5805013)
our x 990731 found x 990731
equal points: True

Part 2. C++ version.

Almost the same code as above, rewritten in C++.

This C++ version is much faster than the Python one: the C++ code spends around 1 minute on a 1 GHz CPU to crack a 48-bit curve. Python spends the same amount of time on a 32-bit curve.

As a reminder, the complexity is O(sqrt(Curve_Order)). So if C++ spends the same time on 48 bits (sqrt is 2^24) as Python on 32 bits (sqrt is 2^16), then C++ is around 2^24/2^16 = 2^8 = 256 times faster than the Python version.

The following version compiles only with Clang, because it uses 128- and 192-bit integers. GCC also has __int128, but no 192- or 256-bit integers. The 192-bit integer is used only in the BarrettMod() function. If you replace that function's body with return x % n; you no longer need the wider integer and can compile with GCC.

I implemented the Barrett reduction algorithm to replace the modulus operation (% N), which relies on slow division. The Barrett formula uses only multiplication, shifts and subtraction. This speeds up modulus operations several times.

Try it online!

#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

using u64 = uint64_t;
using u128 = unsigned __int128;

// 192-bit unsigned integer, needed only by BarrettMod() for the x * r product.
#if defined(__clang__) && defined(__BITINT_MAXWIDTH__) && __BITINT_MAXWIDTH__ >= 192
using u192 = unsigned _BitInt(192);
#elif defined(__clang__)
using u192 = unsigned _ExtInt(192);
#else
// Portable fallback for compilers without _BitInt/_ExtInt in C++ (e.g. GCC).
// Provides exactly what BarrettMod() uses: construction from u128, multiplication
// by a 64-bit word (mod 2^192), right shift, and explicit narrowing to u128.
struct u192 {
    u64 w0 = 0, w1 = 0, w2 = 0;  // 64-bit limbs, least significant first

    u192() = default;
    u192(u128 v) : w0(u64(v)), w1(u64(v >> 64)), w2(0) {}

    friend u192 operator * (u192 const & a, u64 m) {
        u128 const p0 = u128(a.w0) * m;
        u128 const p1 = u128(a.w1) * m + u64(p0 >> 64);
        u128 const p2 = u128(a.w2) * m + u64(p1 >> 64);
        u192 r;
        r.w0 = u64(p0);
        r.w1 = u64(p1);
        r.w2 = u64(p2);  // carries past 2^192 are dropped, like a fixed-width integer
        return r;
    }

    friend u192 operator >> (u192 a, size_t s) {
        if (s >= 192)
            return u192();
        while (s >= 64) {
            a.w0 = a.w1;
            a.w1 = a.w2;
            a.w2 = 0;
            s -= 64;
        }
        if (s != 0) {
            a.w0 = (a.w0 >> s) | (a.w1 << (64 - s));
            a.w1 = (a.w1 >> s) | (a.w2 << (64 - s));
            a.w2 >>= s;
        }
        return a;
    }

    explicit operator u128 () const { return (u128(w1) << 64) | w0; }
};
#endif

using Word = u64;
using DWord = u128;
using SWord = std::make_signed_t<Word>;
using TWord = u192;

// Statement-like macros (do { } while (0)) so they are safe inside if/else without braces.
#define ASSERT_MSG(cond, msg) do { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); } while (0)
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN do { g_log << " LN " << __LINE__ << " " << std::flush; } while (0)
#define DUMP(x) do { g_log << " " << (#x) << " = " << (x) << " " << std::flush; } while (0)

static auto & g_log = std::cout;

class ECPoint {
public:
    class InvError : public std::runtime_error {
    public:
        InvError(Word const & gcd, Word const & x, Word const & mod)
            : std::runtime_error("(gcd " + std::to_string(gcd) + ", x " + std::to_string(x) +
                ", mod " + std::to_string(mod) + ")") {}
    };
    
    static Word pow_mod(Word a, Word b, Word const & c) {
        // https://en.wikipedia.org/wiki/Modular_exponentiation
        Word r = 1;
        while (b != 0) {
            if (b & 1)
                r = (DWord(r) * a) % c;
            a = (DWord(a) * a) % c;
            b >>= 1;
        }
        return r;
    }
    
    static Word rand_range(Word const & begin, Word const & end) {
        u64 const seed = (u64(std::random_device{}()) << 32) + std::random_device{}();
        thread_local std::mt19937_64 rng{seed};
        ASSERT(begin < end);
        return std::uniform_int_distribution<Word>(begin, end - 1)(rng);
    }
    
    static bool fermat_prp(Word const & n, size_t trials = 32) {
        // https://en.wikipedia.org/wiki/Fermat_primality_test
        if (n <= 16)
            return n == 2 || n == 3 || n == 5 || n == 7 || n == 11 || n == 13;
        for (size_t i = 0; i < trials; ++i)
            if (pow_mod(rand_range(2, n - 2), n - 1, n) != 1)
                return false;
        return true;
    }
    
    static Word rand_prime_range(Word begin, Word end) {
        while (true) {
            Word const p = rand_range(begin, end) | 1;
            if (fermat_prp(p))
                return p;
        }
    }
    
    static Word rand_prime(size_t bits) {
        return rand_prime_range(Word(1) << (bits - 1), Word((DWord(1) << bits) - 1));
    }
    
    std::tuple<Word, size_t> BarrettRS(Word n) {
        size_t constexpr extra = 3;
        for (size_t k = 0; k < sizeof(DWord) * 8; ++k) {
            if (2 * (k + extra) < sizeof(Word) * 8)
                continue;
            if ((DWord(1) << k) <= DWord(n))
                continue;
            k += extra;
            ASSERT_MSG(2 * k < sizeof(DWord) * 8, "k " + std::to_string(k));
            DWord r = (DWord(1) << (2 * k)) / n;
            ASSERT_MSG(DWord(r) < (DWord(1) << (sizeof(Word) * 8)),
                "k " + std::to_string(k) + " n " + std::to_string(n));
            ASSERT(2 * k >= sizeof(Word) * 8);
            return std::make_tuple(Word(r), size_t(2 * k - sizeof(Word) * 8));
        }
        ASSERT(false);
    }
    
    template <bool Adjust>
    static Word BarrettMod(DWord const & x, Word const & n, Word const & r, size_t s) {
        //return x % n;
        DWord const q = DWord(((TWord(x) * r) >> (sizeof(Word) * 8)) >> s);
        Word t = Word(DWord(x) - q * n);
        if constexpr(Adjust) {
            Word const mask = ~Word(SWord(t - n) >> (sizeof(Word) * 8 - 1));
            t -= mask & n;
        }
        return t;
    }
    
    static Word Adjust(Word const & a, Word const & n) {
        return a >= n ? a - n : a;
    }
    
    Word modNn(DWord const & a) const { return BarrettMod<false>(a, N_, N_br_, N_bs_); }
    Word modNa(DWord const & a) const { return BarrettMod<true>(a, N_, N_br_, N_bs_); }
    Word modQn(DWord const & a) const { return BarrettMod<false>(a, q_, Q_br_, Q_bs_); }
    Word modQa(DWord const & a) const { return BarrettMod<true>(a, q_, Q_br_, Q_bs_); }
    
    static Word mod(DWord const & a, Word const & n) { return a % n; }
    
    static ECPoint base_gen(size_t bits = 128, Word min_order_pfactor = 0) {
        while (true) {
            Word const N = rand_prime(bits);
            if (mod(N, 4) != 3)
                continue;
            Word const
                x0 = rand_range(1, N), y0 = rand_range(1, N), A = rand_range(1, N),
                B = mod(mod(DWord(y0) * y0, N) + N * 2 - mod(DWord(mod(DWord(x0) * x0, N)) * x0, N) - mod(DWord(A) * x0, N), N),
                y0_calc = pow_mod(mod(DWord(y0) * y0, N), (N + 1) >> 2, N);
            if (y0 != y0_calc)
                continue;
            auto const bp = ECPoint(A, B, N, x0, y0, 0, true, true);
            
            auto BpCheckOrder = [&]{
                for (auto e: bp.q_ps())
                    if (e < min_order_pfactor)
                        return false;
                return true;
            };
            
            if (!(bp.q() != 0 && !bp.q_ps().empty() && BpCheckOrder()))
                continue;
            ASSERT(bp.q() > 1 && bp * (bp.q() + 1) == bp);
            return bp;
        }
        ASSERT(false);
    }
    
    ECPoint(Word A, Word B, Word N, Word x, Word y, Word q = 0, bool prepare = true, bool calc_q = false) {
        if (prepare) {
            A = mod(A, N); B = mod(B, N); x = mod(x, N); y = mod(y, N); q = mod(q, N);
            ASSERT(mod(4 * mod(DWord(mod(DWord(A) * A, N)) * A, N) + 27 * mod(DWord(B) * B, N), N) != 0);
            ASSERT(mod(N, 4) == 3);
            if (!(x == 0 && y == 0)) {
                ASSERT(mod(mod(DWord(y) * y, N) + 3 * N - mod(DWord(mod(DWord(x) * x, N)) * x, N) - mod(DWord(A) * x, N) - B, N) == 0);
                ASSERT(y == pow_mod(mod(DWord(mod(DWord(x) * x, N)) * x, N) + mod(DWord(A) * x, N) + B, (N + 1) >> 2, N));
            }
            std::tie(N_br_, N_bs_) = BarrettRS(N);
            if (q != 0)
                std::tie(Q_br_, Q_bs_) = BarrettRS(q);
        }
        std::tie(A_, B_, N_, x_, y_, q_) = std::tie(A, B, N, x, y, q);
        if (calc_q) {
            std::tie(q_, q_ps_) = find_order();
            if (q_ != 0)
                std::tie(Q_br_, Q_bs_) = BarrettRS(q_);
        }
    }
    
    auto copy() const {
        return ECPoint(A_, B_, N_, x_, y_, q_, false);
    }
    
    auto inf() const {
        return ECPoint(A_, B_, N_, 0, 0, q_, false);
    }
    
    static auto const & gen_primes(Word const B) {
        thread_local std::vector<Word> ps = {2, 3};
        
        for (Word p = ps.back() + 2; p <= B; p += 2) {
            bool is_prime = true;
            for (auto const e: ps) {
                if (e * e > p)
                    break;
                if (p % e == 0) {
                    is_prime = false;
                    break;
                }
            }
            if (is_prime)
                ps.push_back(p);
        }
        
        return ps;
    }
    
    std::tuple<Word, std::vector<Word>> find_order(Word _m = 1, std::vector<Word> _ps = {}) const {
        ASSERT(_m <= 2 * N_);
        
        if constexpr(1) {
            auto r = *this;
            try {
                r *= _m;
            } catch (InvError const &) {
                return std::make_tuple(_m, _ps);
            }
            Word const B = 2 * N_;
            for (Word const p: gen_primes(std::llround(std::cbrt(B) + 1))) {
                if (p * p * p > B)
                    break;
                ASSERT(p <= B);
                size_t cnt = 0;
                Word hi = 1;
                try {
                    for (cnt = 1;; ++cnt) {
                        if (hi * p > B) {
                            cnt -= 1;
                            break;
                        }
                        hi *= p;
                        r *= p;
                     }
                } catch (InvError const & ex) {
                    _ps.insert(_ps.begin(), cnt, p);
                    return find_order(hi * _m, _ps);
                }
            }
        } else {
            // Alternative slower way
            auto r = *this;
            for (Word i = 0;; ++i)
                try {
                    r += *this;
                } catch (InvError const &) {
                    _ps.clear();
                    return std::make_tuple(i + 2, _ps);
                }
        }
        
        _ps.clear();
        return std::make_tuple(Word(0), _ps);
    }
    
    static std::tuple<Word, SWord, SWord> EGCD(Word const & a, Word const & b) {
        Word ro = 0, r = 0, qu = 0, re = 0;
        SWord so = 0, s = 0;
        std::tie(ro, r, so, s) = std::make_tuple(a, b, 1, 0);
        while (r != 0) {
            std::tie(qu, re) = std::make_tuple(ro / r, ro % r);
            std::tie(ro, r) = std::make_tuple(r, re);
            std::tie(so, s) = std::make_tuple(s, so - s * SWord(qu));
        }
        SWord const to = (SWord(ro) - SWord(a) * so) / SWord(b);
        return std::make_tuple(ro, so, to);
    }
    
    Word inv(Word a, Word const & n, size_t any_n_q = 0) const {
        ASSERT(n > 0);
        a = any_n_q == 0 ? mod(a, n) : any_n_q == 1 ? modNa(a) : any_n_q == 2 ? modQa(a) : 0;
        auto [gcd, s, t] = EGCD(a, n);
        if (gcd != 1)
            throw InvError(gcd, a, n);
        a = Word(SWord(n) + s);
        a = any_n_q == 0 ? mod(a, n) : any_n_q == 1 ? modNa(a) : any_n_q == 2 ? modQa(a) : 0;
        return a;
    }
    
    Word invN(Word a) const { return inv(a, N_, 1); }
    Word invQ(Word a) const { return inv(a, q_, 2); }
    
    ECPoint & operator += (ECPoint const & o) {
        if (x_ == 0 && y_ == 0) {
            *this = o;
            return *this;
        }
        if (o.x_ == 0 && o.y_ == 0)
            return *this;
        Word const Px = x_, Py = y_, Qx = o.x_, Qy = o.y_;
        Word s = 0;
        if ((Adjust(Px, N_) == Adjust(Qx, o.N_)) && (Adjust(Py, N_) == Adjust(Qy, o.N_)))
            s = modNn(DWord(modNn(DWord(Px) * Px * 3) + A_) * invN(Py * 2));
        else
            s = modNn(DWord(Py + 2 * N_ - Qy) * invN(Px + 2 * N_ - Qx));
        x_ = modNn(DWord(s) * s + 4 * N_ - Px - Qx);
        y_ = modNn(DWord(s) * (Px + 2 * N_ - x_) + 2 * N_ - Py);
        return *this;
    }
    
    ECPoint operator + (ECPoint const & o) const {
        ECPoint c = *this;
        c += o;
        return c;
    }
    
    ECPoint & operator *= (Word k) {
        auto const ok = k;
        ASSERT(k > 0);
        if (k == 1)
            return *this;
        k -= 1;
        auto r = *this, s = *this;
        while (true) {
            if (k & 1) {
                r += s;
                if (k == 1)
                    break;
            }
            k >>= 1;
            s += s;
        }
        if constexpr(0) {
            auto r2 = *this;
            for (u64 i = 1; i < ok; ++i)
                r2 += *this;
            ASSERT(r == r2);
        }
        *this = r;
        return *this;
    }
    
    ECPoint operator * (Word k) const {
        ECPoint r = *this;
        r *= k;
        return r;
    }
    
    bool operator == (ECPoint const & o) const {
        return A_ == o.A_ && B_ == o.B_ && N_ == o.N_ && q_ == o.q_ &&
            Adjust(x_, N_) == Adjust(o.x_, o.N_) && Adjust(y_, N_) == Adjust(o.y_, o.N_);
    }
    
    Word const & q() const { return q_; }
    std::vector<Word> const & q_ps() const { return q_ps_; }
    Word const & x() const { return x_; }
    
private:
    Word A_ = 0, B_ = 0, N_ = 0, q_ = 0, x_ = 0, y_ = 0, N_br_ = 0, Q_br_ = 0;
    size_t N_bs_ = 0, Q_bs_ = 0;
    std::vector<Word> q_ps_;
};

/// Pollard's rho (Floyd cycle detection) for the elliptic-curve discrete logarithm:
/// returns x with b == a * x, reduced modulo bp.q().
///
/// Every walk point is kept as ai*a + bi*b.  The step is chosen by a three-way partition
/// on (x-coordinate mod part_p) mod 3, with part_p a fresh random prime per attempt:
/// class 0 adds b, class 1 doubles, class 2 adds a.  When the tortoise x_i meets the hare
/// x_2i, x = (a2i - ai) / (bi - b2i) mod q.
///
/// Assumes a lies in the group generated by bp, so that its order divides bp.q().
/// Any ECPoint::InvError (during point arithmetic or in the final division) is logged
/// and the search restarts with a new partition; it retries indefinitely.
Word pollard_rho_ec_log(ECPoint const & a, ECPoint const & b, ECPoint const & bp) {
    // https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm_for_logarithms#Algorithm
    
    for (u64 itry = 0;; ++itry) {
        u64 i = 0;
        
        try {
            // Upper bound of the partition-prime range.  Equal to the original
            // bp.q() >> 4 once that exceeds 16.  For smaller orders [8, 16) is used
            // so the range still contains primes (11, 13): a bound of 9..10 leaves
            // only 9 and loops forever, and <= 8 trips rand_range's ASSERT.
            Word const part_hi = (bp.q() >> 4) > 16 ? (bp.q() >> 4) : Word(16);
            Word const part_p = bp.rand_prime_range(8, part_hi);
            
            auto ModQ = [&](Word n) {
                return n >= bp.q() ? n - bp.q() : n;
            };
            
            // Next point of the walk.
            auto f = [&](auto const & x) -> ECPoint {
                Word const mod3 = (x.x() % part_p) % 3;
                if (mod3 == 0)
                    return b + x;
                else if (mod3 == 1)
                    return x + x;
                else if (mod3 == 2)
                    return a + x;
                else
                    ASSERT(false);
            };
            
            // Matching update of the coefficient of `a`.
            auto const g = [&](auto const & x, Word n) -> Word {
                Word const mod3 = (x.x() % part_p) % 3;
                if (mod3 == 0)
                    return n;
                else if (mod3 == 1)
                    return ModQ(2 * n);
                else if (mod3 == 2)
                    return ModQ(n + 1);
                else
                    ASSERT(false);
            };
            
            // Matching update of the coefficient of `b`.
            auto const h = [&](auto const & x, Word n) -> Word {
                Word const mod3 = (x.x() % part_p) % 3;
                if (mod3 == 0)
                    return ModQ(n + 1);
                else if (mod3 == 1)
                    return ModQ(2 * n);
                else if (mod3 == 2)
                    return n;
                else
                    ASSERT(false);
            };
            
            Word aim1 = 0, bim1 = 0, a2im2 = 0, b2im2 = 0, ai = 0, bi = 0, a2i = 0, b2i = 0;
            ECPoint xim1 = bp.inf(), x2im2 = bp.inf(), xi = bp.inf(), x2i = bp.inf();
            
            for (i = 1;; ++i) {
                // Tortoise: one step.
                xi = f(xim1);
                ai = g(xim1, aim1);
                bi = h(xim1, bim1);
                
                // Hare: two steps.  The intermediate point is computed once instead of
                // three times (each f() call is a full curve addition).
                ECPoint const x2im1 = f(x2im2);
                x2i = f(x2im1);
                a2i = g(x2im1, g(x2im2, a2im2));
                b2i = h(x2im1, h(x2im2, b2im2));
                
                if (xi == x2i)
                    return bp.modQa(DWord(bp.invQ(bp.q() + bi - b2i)) * (bp.q() + a2i - ai));
                
                xim1 = xi;
                aim1 = ai;
                bim1 = bi;
                x2im2 = x2i;
                a2im2 = a2i;
                b2im2 = b2i;
            }
        } catch (ECPoint::InvError const & ex) {
            g_log << "Try " << std::setfill(' ') << std::setw(4) << itry << ", Pollard-Rho failed, invert err at iter "
                << std::setw(7) << i << ", " << ex.what() << std::endl;
        }
    }
}

/// End-to-end demo: generate a 36-bit curve whose order has only prime factors >= 50,
/// choose random k0 and x, and check that Pollard rho recovers x from a = bp*k0, b = a*x.
/// Prints timings and results to g_log.
void test() {
    auto const gtb = std::chrono::high_resolution_clock::now();
    auto Time = [&]() -> double {
        return std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::high_resolution_clock::now() - gtb).count() / 1'000.0;
    };
    double tb = 0;
    size_t constexpr bits = 36;
    g_log << "Generating base point, wait... " << std::flush;
    tb = Time();
    auto const bp = ECPoint::base_gen(bits, 50);
    g_log << "Time " << Time() - tb << " sec" << std::endl;
    // Print the factorisation as "p1 * p2 * ..." without a trailing separator.
    g_log << "order " << bp.q() << " = ";
    char const * sep = "";
    for (auto e: bp.q_ps()) {
        g_log << sep << e << std::flush;
        sep = " * ";
    }
    g_log << std::endl;
    Word const k0 = ECPoint::rand_range(1, bp.q()),
               x  = ECPoint::rand_range(1, bp.q());
    auto a = bp * k0;
    auto b = a * x;
    g_log << "Searching discrete logarithm... " << std::endl;
    tb = Time();
    Word const x_calc = pollard_rho_ec_log(a, b, bp);
    g_log << "Time " << Time() - tb << " sec" << std::endl;
    g_log << "our x " << x << ", found x " << x_calc << std::endl;
    g_log << "equal points: " << std::boolalpha << (a * x == a * x_calc) << std::endl;
}

/// Runs the demo.  An escaping exception is reported and turned into exit status 1
/// (previously the program exited with 0 even after a failure).
int main() {
    try {
        test();
    } catch (std::exception const & ex) {
        g_log << "Exception: " << ex.what() << std::endl;
        return 1;
    }
    return 0;
}

Output:

Generating base point, wait... Time 38.932 sec
order 195944962603297 = 401 * 4679 * 9433 * 11071 *
Searching discrete logarithm...
Time 69.791 sec
our x 15520105103514, found x 15520105103514
equal points: true
Arty
  • 14,883
  • 6
  • 36
  • 69