The Problem
-----------
We're given a text file with three numbers labeled n, e, and c.  This looks like an RSA problem where we need to recover the plaintext from the ciphertext c.  e is only 5, which suggests that a small exponent attack might work.  Normally, e would be something like 65537 = 0x10001.



The Small Exponent Attack
-------------------------
Most documentation online about small exponent attacks talks about e=3, but the same approach applies to e=5.  Note that a small e doesn't necessarily mean c will be easy to crack, just that it might be.  Because of this, when my original attempt during the competition didn't work, I assumed it must be something else and just moved on.  After seeing writeups say that it was, in fact, a small exponent attack, I assumed that I must have made a typo or mistake somewhere in the math and went back to solve it after the fact.

The idea behind a small exponent attack is that our ciphertext c is calculated from plaintext p with exponent e and modulus n

    c = (p**e)%n

which would mean that

    (p**e) = c + (k * n)

for some integer k >= 0, namely k = (p**e)//n. We don't know what that k is, and it could be very large if p is big enough. A higher e makes it much more likely that k is extremely large, but since e is so small here, there is a chance that (p**e) is only a few multiples of n above c, so we can find it by trying increasing values of k.
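
To make that relation concrete, here is a minimal sketch with made-up toy numbers.  The n and p below are arbitrary values I picked for illustration, not the challenge values, and a real modulus would be hundreds of digits long.

```python
# Toy values only (assumptions for illustration, not from the challenge).
n = 3233 * 10007   # arbitrary small modulus standing in for the real n
e = 5
p = 42             # stand-in "plaintext"

c = pow(p, e, n)   # same result as (p**e) % n
k = (p**e) // n    # how many multiples of n the reduction removed

# The relation the attack relies on: p**e sits exactly k multiples of n above c.
assert p**e == c + k * n
print(k)           # 4 for these toy values
```

With k this small, a loop over k = 0, 1, 2, ... reaches the right multiple almost immediately; the open question is how to recognize it.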

That alone isn't enough, though, as we have no way to know directly whether the (p**e) we calculate for a given k is the correct one.  For that, we need to take the e-th root of the candidate to get p and check a) that an exact integer root even exists and b) that the resulting p is our flag and not some other value that happens to work.  For my attack, I only ever checked that the root existed, as I figured that if I found a false positive I could just fix the code then.



e-th Root In Python
-------------------
Now the hard part is calculating the e-th root of (p**e).  There are a number of math tools that can do this for us, but I didn't feel like relearning SageMath or NumPy, so I tried to deal with it myself in pure Python.  As it turns out, there are tons of little hiccups that make this pretty annoying.

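First, for a t = (p**e), p = t**(1/e) will try to convert t into a float and fail with an OverflowError because the int is too large.  Even for values that do fit, a float only carries 53 bits of precision, so the root would not be exact.  So we can't just do that.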
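
Both failure modes are easy to reproduce.  This is a small sketch with values I made up for the demonstration (a 2001-bit perfect fifth power and a roughly 100-digit one), not the challenge numbers:

```python
# A perfect fifth power that is far too large for a float (about 2001 bits).
t = (2**400) ** 5

overflowed = False
try:
    root = t ** (1 / 5)          # int ** float converts t to a float first
except OverflowError as err:
    overflowed = True
    print("float conversion failed:", err)

# A smaller fifth power fits in a float, but the rounded root is not exact,
# because 10**20 + 1 cannot be represented with 53 bits of precision.
true_root = 10**20 + 1
small = true_root ** 5
approx = round(small ** (1 / 5))
print(approx == true_root)       # False
```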

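Instead, I decided to write a binary search over range(1,t) which calculates i**e for each probed i and checks it against t.  Since i**e only grows as i grows, every probe halves the remaining interval, so the search needs roughly log2(t) steps.  That makes it a reasonably fast way to search for an e-th root.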
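
For comparison, the same idea can be sketched directly on plain ints, without going through a range object.  This is not the code I used for the challenge; the name iroot and the bit_length-based upper bound are my own choices for this illustration.

```python
def iroot(t, e):
    """Return (r, exact): r is the largest int with r**e <= t,
    and exact is True when r**e == t (t is a perfect e-th power)."""
    if t < 0 or e < 1:
        raise ValueError("expected t >= 0 and e >= 1")
    lo = 0
    # t < 2**bits, so its e-th root is below 2**(bits//e + 1).
    hi = 1 << (t.bit_length() // e + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2     # round up so the interval always shrinks
        if mid ** e <= t:
            lo = mid                 # mid is still a valid lower bound
        else:
            hi = mid - 1             # mid overshoots
    return lo, lo ** e == t

print(iroot(243, 5))   # (3, True)
print(iroot(244, 5))   # (3, False)
```

Starting from a tight upper bound instead of t itself cuts the number of steps to roughly log2(t)/e, which matters little here but keeps the loop short.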

I wrote a generic binary search function which I figured I could reuse later for something else.  I did it recursively originally, but Python had a fit when my numbers were too large and it hit the recursion limit: the search needs roughly one level of recursion per bit of t, and t had more bits than the limit allows.  So I had to rewrite it as a loop instead.

I also ran into various issues with keeping it generic and dealing with large numbers.  For instance, len(range(0,really_big_number)) raises an OverflowError saying the int is too large to convert to a C ssize_t.  So I had to add the ability to pass custom start and end offsets into the iterable instead of relying on len().
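
The len() problem is quick to reproduce on CPython.  In this sketch the value of big is arbitrary, picked only to be wider than a C ssize_t; note that indexing into the range still works, which is what makes passing explicit offsets a workable fix.

```python
big = 2**200                 # arbitrary value, far wider than a C ssize_t
r = range(0, big)

try:
    len(r)
    len_failed = False
except OverflowError as err:
    len_failed = True
    print("len() failed:", err)

# Indexing does not go through len(), so huge offsets are fine.
middle = r[big // 2]
print(middle == big // 2)    # True
```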

The final code looks like this:

```
#binary search
#searches for an offset s into i that satisfies x == f(i[s])
#f(i[s]) must be non-decreasing in s for the search to be valid
#i = indexable sequence (e.g. a range)
#f = function to call on each element of i
#x = value to search for
#start = first offset into i to consider
#end = last offset into i to consider (inclusive); -1 means len(i)-1
#if it finds a match, it returns a tuple of (s,i[s],f(i[s]))
#if it does not find a match, it returns (-1,None,None)
def bsearch(i,f,x,start=0,end=-1):
    if end == -1:
        end = len(i)-1
    #s = _bsearch(i,f,start,end,x)
    s = _bsearch2(i,f,start,end,x)
    return (s,i[s],f(i[s])) if s != -1 else (s,None,None)

#recursive
def _bsearch(i,f,lo,hi,x):
    if hi >= lo:
        md = (hi+lo)//2
        a = f(i[md])
        if a == x:
            return md
        elif a > x:
            return _bsearch(i,f,lo,md-1,x)
        else:
            return _bsearch(i,f,md+1,hi,x)
    else:
        return -1

#loop
def _bsearch2(i,f,lo,hi,x):
    while hi >= lo:
        md = (hi+lo)//2
        a = f(i[md])
        if a == x:
            return md
        elif a > x:
            hi = md-1
        else:
            lo = md+1
    return -1

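def small_e_attack(c,e,n,kend=100):
    #try t = c + k*n for k = 0..kend-1 until t is a perfect e-th power
    #(k = 0 covers the case where p**e never wrapped around n)
    for k in range(kend):
        t = c + (k*n)
        #range(1,t) has t-1 elements, so its last offset is t-2
        p = bsearch(range(1,t),lambda r: r**e,t,end=t-2)[1]
        if p is not None:
            #size the buffer from p instead of hard-coding 40 bytes
            print(p.to_bytes((p.bit_length()+7)//8,'big').decode())
            break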
```

and then we can just call

    small_e_attack(c,e,n)

and it prints the flag.
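
Since the challenge values aren't reproduced here, the following is a self-contained sketch of the whole attack on a toy key.  Everything in it is an assumption made for illustration: the modulus is a product of two small Mersenne primes, the message b"CTF!" is made up, and the helper names iroot and recover are hypothetical stand-ins for the bsearch-based code above.

```python
def iroot(t, e):
    """Largest r with r**e <= t, plus whether r**e == t exactly."""
    lo, hi = 0, 1 << (t.bit_length() // e + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** e <= t:
            lo = mid
        else:
            hi = mid - 1
    return lo, lo ** e == t

def recover(c, e, n, kend=100):
    """Return (k, p) for the first k < kend where c + k*n is a perfect
    e-th power, or None if no such k is found."""
    for k in range(kend):
        p, exact = iroot(c + k * n, e)
        if exact:
            return k, p
    return None

# Toy key (assumption): two Mersenne primes, far too small for real use.
n = (2**19 - 1) * (2**127 - 1)
e = 5
secret = b"CTF!"                                  # made-up message
c = pow(int.from_bytes(secret, "big"), e, n)      # "encrypt" without padding

k, p = recover(c, e, n)
recovered = p.to_bytes((p.bit_length() + 7) // 8, "big")
print(k, recovered)
```

Here the message is short enough that p**e is only a few dozen multiples of n above c, so the loop over k finishes well inside the default kend.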