lizardqueen
1011 words
5 minutes
Not New PRNG, SECCON finals 2022

Description:

Recently, I learned that this random number generator is called "MRG".

Source:

import os
import random
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad
from Crypto.Util.number import getPrime


# 128-bit prime modulus (printed below, so it is public).
p = getPrime(128)

# Four secret seeds drawn from [1, 2**64] (randint is inclusive at both ends).
# They are never printed; they are packed into the AES key further down.
xs = [random.randint(1, 2**64) for _ in range(4)]

# Recurrence coefficients: a..d are printed below, e is never printed.
a = random.randint(1, p)
b = random.randint(1, p)
c = random.randint(1, p)
d = random.randint(1, p)
e = random.randint(1, p)  # unknown

# Three steps of the order-4 linear recurrence with additive constant e:
#   x_{k+4} = a*x_k + b*x_{k+1} + c*x_{k+2} + d*x_{k+3} + e  (mod p)
# Each append extends xs, so xs[-4:] always holds the four latest values.
xs.append((a*xs[-4] + b*xs[-3] + c*xs[-2] + d*xs[-1] + e) % p)
xs.append((a*xs[-4] + b*xs[-3] + c*xs[-2] + d*xs[-1] + e) % p)
xs.append((a*xs[-4] + b*xs[-3] + c*xs[-2] + d*xs[-1] + e) % p)

# Only the three generated values x4, x5, x6 are published.
outs = xs[-3:]


# encryption
FLAG = os.getenv("FLAG", "fake{the_flag_is_a_lie}")
# Pack the four seeds into one integer:
#   key = x0*2**192 + x1*2**128 + x2*2**64 + x3
# The packing is exact only if every seed is below 2**64; randint can return
# 2**64 itself, which would carry into the neighbouring 64-bit word.
key = 0
for x in xs[:4]:
    key <<= 64
    key += x
# Serialize as 32 little-endian bytes (a 256-bit AES key); to_bytes raises
# OverflowError if the packed value does not fit in 32 bytes.
key = int(key).to_bytes(32, "little")
iv = get_random_bytes(16)  # public
cipher = AES.new(key, AES.MODE_CBC, iv)
ct = cipher.encrypt(pad(FLAG.encode(), 16))  # public

# output
# Published: p, a, b, c, d, the three outputs, the IV and the ciphertext.
# Neither e nor the seeds xs[:4] are printed.
print(f"p = {p}")
print(f"a = {a}")
print(f"b = {b}")
print(f"c = {c}")
print(f"d = {d}")
print(f"outs = {outs}")
print(f"iv = 0x{iv.hex()}")
print(f"ct = 0x{ct.hex()}")

SOLUTION

Finding the lattice for this challenge was not hard, but I found it a good way to practice lattice enumeration.

From the source, I extracted the following equations:

  • $x_4 = a x_0 + b x_1 + c x_2 + d x_3 + e \pmod p$

  • $x_5 = a x_1 + b x_2 + c x_3 + d x_4 + e \pmod p$

  • $x_6 = a x_2 + b x_3 + c x_4 + d x_5 + e \pmod p$

where $a, b, c, d, x_4, x_5, x_6, p$ are known, while $x_0, x_1, x_2, x_3$ and $e$ are unknown.

Substituting $x_4$ into $x_5$, and then $x_4$ and $x_5$ into $x_6$, so that only the unknowns $x_0, x_1, x_2, x_3$ (and $e$) remain, we have:

  • $x_4 = a x_0 + b x_1 + c x_2 + d x_3 + e \pmod p$

  • $x_5 = ad\, x_0 + (a + bd)\, x_1 + (b + cd)\, x_2 + (c + d^2)\, x_3 + (1 + d)\, e \pmod p$

  • $x_6 = (ca + ad^2)\, x_0 + (cb + d(a + bd))\, x_1 + (a + c^2 + d(b + cd))\, x_2 + (b + cd + d(c + d^2))\, x_3 + (1 + c + d(1 + d))\, e \pmod p$

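Now we can remove $e$ from these expressions. The coefficient of $e$ in $x_4$, $x_5$ and $x_6$ is $1$, $1+d$ and $1+c+d+d^2$ respectively, so it cancels in each of the following combinations: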

  • $x_5 - (d+1)\, x_4 \pmod p$

  • $x_6 - (d^2+c+d+1)\, x_4 \pmod p$

  • $x_6 - x_5 - (d^2+c)\, x_4 \pmod p$

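Since we know $a, b, c, d, x_4, x_5, x_6, p$, we can calculate the value of each expression. This gives a system of 3 equations in the 4 unknowns $x_0, x_1, x_2, x_3$. In fact only 2 of the equations are independent, because the second expression equals the sum of the first and the third. Such a system has many solutions modulo $p$. However, all our unknowns lie between 1 and $2^{64}$, which is tiny compared to the 128-bit modulus $p$. The solution we want is therefore an unusually short one, which is exactly what a lattice can find.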

We can then use these values and the coefficients of $x_0, x_1, x_2, x_3$ in each equation to build the lattice for our problem.

My first attempt used LLL alone, but that was not enough to get the solution I was looking for.

After spending some time reading the fpylll documentation, I finally managed to enumerate the lattice vectors and get the one I was looking for.

Solve Script


import random
import numpy as np
from Crypto.Util.number import isPrime, getPrime
from itertools import product
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from fpylll import IntegerMatrix, LLL
from fpylll.fplll.gso import MatGSO
from fpylll.fplll.enumeration import Enumeration

# Public values printed by the challenge: the modulus, the four recurrence
# coefficients, the three generated outputs, and the AES-CBC IV / ciphertext.
p = 234687789984662131107323206406195107369
a, b, c, d = (
    35686285754866388325178539790367732387,
    36011211474181220344603698726947017489,
    84664322357902232989540976252462702046,
    154807718022294938130158404283942212610,
)
outs = [
    222378874028969090293268624578715626424,
    42182082074667038745014860626841402403,
    217744703567906139265663577111207633608,
]
iv = bytes.fromhex("f2dd287ca870eb9908bf52c44dfd9d2b")
# Ciphertext split into 16-byte (32 hex digit) AES blocks for readability.
ct = bytes.fromhex(
    "236a6aca059ae29056a23f5458c644ab"
    "b74640d672dba1ee049eb956e629b7af"
    "b03ae33b2b2b419c24197d33baf6d88e"
    "2f0eedfa90c06e1a2be18b2fae2270f0"
    "5ce39de5e0d59bb9a442d1b3eb392658"
    "e45cf721094543b13d35df8cf9ce420c"
)

# Scratchpad, disabled: the whole snippet is a bare string literal, so it never
# runs. It expanded the recurrence symbolically (Sage `var`) to write x4, x5, x6
# as linear forms in xs0..xs3 and e. The bracketed list on its last line shows
# those forms with the large coefficients reduced mod p; the hard-coded rows of
# A and the vector E right after this string are copied from that list.
"""
#F = Zmod(p)
kxs = [var(f"xs{i}") for i in range(4)]
e = var(f'e')

kxs += [(a*kxs[-4] + b*kxs[-3] + c*kxs[-2] + d*kxs[-1] + e)]
kxs += [(a*kxs[-4] + b*kxs[-3] + c*kxs[-2] + d*kxs[-1] + e)]
kxs += [(a*kxs[-4] + b*kxs[-3] + c*kxs[-2] + d*kxs[-1] + e)]

xs = [kxs[0],
 kxs[1],
 kxs[2],
 kxs[3],
 e + (35686285754866388325178539790367732387%p)*kxs[0] + (36011211474181220344603698726947017489%p)*kxs[1] + (84664322357902232989540976252462702046%p)*kxs[2] + (154807718022294938130158404283942212610%p)*kxs[3],
 (154807718022294938130158404283942212611%p)*e + (5524512462402396504522631022186993689941902586904182387956969847513136800070%p)*kxs[0] + (5574813471536278371812192944436468893673412103510013549372682471857394068677%p)*kxs[1] + (13106690542130809785016373593585159951283185043634670521563223863167361017549%p)*kxs[2] + (23965429559270380990259083357664627080454574539545585116705269958114905714146%p)*kxs[3],
 (23965429559270380990259083357664627080609382257567880054835428362398847926757%p)*e + (855237167490244464144875675573572867629762831370918736738696909377997562042125797695788728304700411167882848246502%p)*kxs[0] + (863024151928479330388076030989990627721702706923244759964245940344901793434930274712611324824222143408391433499464%p)*kxs[1] + (2029016853651666375038394107112295102356624289908454784624904137323194978538243016927915014578940870109681839411393%p)*kxs[2] + (3710033461494701195379191462199896884427637415170375363625339494855060676724077241776127864987795159943314777598609%p)*kxs[3]]


#[xs0, xs1, xs2, xs3, e + 35686285754866388325178539790367732387*xs0 + 36011211474181220344603698726947017489*xs1 + 84664322357902232989540976252462702046*xs2 + 154807718022294938130158404283942212610*xs3, 154807718022294938130158404283942212611*e + 24233268721794315913299373990028841403*xs0 + 83801899324939637851561928647683080672*xs1 + 124789883491551059250886060798635202617*xs2 + 117261698239628161615415951256878674368*xs3, 37381626277260968638251149134625779610*e + 27700972286058499906845172507055534594*xs0 + 174414886418714913984767195293981001560*xs1 + 68655552639519520906718520656801969746*xs2 + 125049678834442576080997917331138548268*xs3]
"""

# Hard-coded linear forms copied from the scratchpad output: row k of A holds
# the coefficients of x0..x3 in x_{4+k} (mod p), and E holds the matching
# coefficients of e. Reference only: A is rebuilt from the symbolic derivation
# further down, and E is never read again.
A = Matrix([
    [35686285754866388325178539790367732387, 36011211474181220344603698726947017489, 84664322357902232989540976252462702046, 154807718022294938130158404283942212610],
    [24233268721794315913299373990028841403, 83801899324939637851561928647683080672, 124789883491551059250886060798635202617, 117261698239628161615415951256878674368],
    [27700972286058499906845172507055534594, 174414886418714913984767195293981001560, 68655552639519520906718520656801969746, 125049678834442576080997917331138548268],
])
E = vector([1, 154807718022294938130158404283942212611, 37381626277260968638251149134625779610])


# Here we have A*x + e*E = B (mod p); "remove" e by taking combinations of the
# three outputs in which the coefficient of e cancels.

F = Zmod(p)
xs = [var(f"xs{i}") for i in range(4)]
e = var('e')

# Symbolic recurrence: kxs[4], kxs[5], kxs[6] are x4, x5, x6 written as linear
# forms in xs0..xs3 and e (no reduction mod p at this stage).
kxs = list(xs)
for _ in range(3):
    kxs.append(a*kxs[-4] + b*kxs[-3] + c*kxs[-2] + d*kxs[-1] + e)

x4, x5, x6 = kxs[4:7]

# The e-coefficients of x4, x5, x6 are 1, d+1 and d**2+c+d+1, so each of the
# combinations below is free of e. out3..out5 apply the same combination to the
# published outputs. `**` is used instead of Sage's `^` (identical under the
# Sage preparser, but `^` would silently mean XOR in plain Python).
f1 = x5 - (d + 1)*x4
out3 = (outs[1] - (d + 1)*outs[0]) % p

f2 = x6 - (d**2 + c + d + 1)*x4
out4 = (outs[2] - (d**2 + c + d + 1)*outs[0]) % p

f3 = x6 - x5 - (d**2 + c)*x4
out5 = (outs[2] - outs[1] - (d**2 + c)*outs[0]) % p

# Note: f2 == f1 + f3, so only two of these three equations are independent;
# the third carries no extra information.

# Coefficients of xs0..xs3 in each e-free combination, as Zmod(p) elements.
line1, line2, line3 = [[F(f.coefficient(v)) for v in xs] for f in (f1, f2, f3)]

# Matrix form A*x = B (mod p); kept for inspection, not used by the enumeration.
A = Matrix([line1, line2, line3])
B = vector([out3, out4, out5])



def enumerator(B, matrix, p, bound, sol_nr=1000):
    """Return small solutions of a linear system mod ``p`` found by lattice enumeration.

    Equation ``i`` reads ``sum(matrix[i][k] * y_k for k in range(m)) == B[i] (mod p)``.
    Each equation is solved for ``y_0``::

        y_0 = B[i]/c_i0 - (c_i1*y_1 + ... + c_i(m-1)*y_(m-1)) / c_i0   (mod p)

    The lattice below is built so that, for the true solution, it contains the
    vector ``(y_0, ..., y_0, y_1, ..., y_(m-1), bound)``, with ``y_0`` repeated
    once per equation. When every unknown satisfies ``|y| <= bound``, its squared
    norm is at most ``(n + m) * bound**2``, which is used as the enumeration radius.

    Args:
        B: right-hand sides, one per equation (``n`` values).
        matrix: ``n`` coefficient rows of ``m`` entries each; ``row[0]`` must be
            invertible mod ``p``. The caller passes ``Zmod(p)`` elements, so
            ``** -1`` below is a modular inverse (Sage semantics).
        p: the modulus.
        bound: embedding constant for the last coordinate; also the assumed
            bound on the absolute value of the unknowns.
        sol_nr: number of solutions fpylll's enumeration keeps (default 1000).

    Returns:
        A list of candidates ``[y_0, y_1, ..., y_(m-1)]``. Each comes from a
        lattice vector whose last coordinate is ``+-bound`` (sign normalised to
        ``+bound``), which has no negative entries, and whose ``n`` equation
        coordinates all agree.
    """
    n = len(B)           # number of equations
    m = len(matrix[0])   # number of unknowns
    dim = n + m

    L = [[0] * dim for _ in range(dim)]
    # Rows 0..n-1: p in each equation column, so those columns are free mod p.
    for i in range(n):
        L[i][i] = p
    # Rows n..n+m-2: unit vectors carrying y_1..y_(m-1) into their own columns.
    for i in range(m - 1):
        L[n + i][n + i] = 1
    # Last row: embedding constant, so the constant term is used exactly once
    # in any vector whose last coordinate is +-bound.
    L[-1][-1] = bound

    for i, (y, coeff) in enumerate(zip(B, matrix)):
        a_inv = coeff[0] ** -1  # modular inverse for Zmod(p) elements
        constant = y * a_inv
        _coeff = [-v * a_inv for v in coeff[1:]] + [constant]
        # Column i expresses y_0 through y_1..y_(m-1) and the constant term.
        for j, x in enumerate(_coeff):
            L[j + n][i] = int(x)

    basis = IntegerMatrix.from_matrix(L)
    LLL.reduction(basis)
    gso = MatGSO(basis)
    gso.update_gso()

    enum = Enumeration(gso, sol_nr)
    # Squared radius is (n + m) * bound**2 -- mind the parentheses: the target
    # has n + m coordinates, each at most `bound` in absolute value, whereas
    # n + m * bound**2 would be too small to guarantee it is reached.
    answers = enum.enumerate(0, dim, dim * bound**2, 0, pruning=None)

    sols = []
    for _, s in answers:
        v = IntegerMatrix.from_iterable(1, basis.nrows, map(int, s))
        vec = v * basis
        last = vec[0, dim - 1]
        if abs(last) != bound:
            continue
        sig = 1 if last == bound else -1
        row = [sig * vec[0, j] for j in range(dim)]
        if any(x < 0 for x in row):
            continue
        # Every equation column must recover the same y_0.
        if len(set(row[:n])) != 1:
            continue
        sols.append([row[0]] + row[n:-1])

    return sols

candidates = enumerator([out3, out4, out5], [line1, line2, line3], p, 2**64)

# Rebuild the AES key from each candidate (x0, x1, x2, x3) exactly as the
# challenge does, and stop at the first plaintext containing the flag prefix.
for v in candidates:
    key = 0
    for x in v:
        key = (key << 64) + x
    try:
        key = int(key).to_bytes(32, "little")
    except OverflowError:
        # The packed value needs more than 256 bits. The challenge's own
        # to_bytes(32, ...) would have failed on it, so it is not the seed set;
        # skip it instead of aborting the search.
        continue
    cipher = AES.new(key, AES.MODE_CBC, iv)
    pt = cipher.decrypt(ct)

    if b'SECCON' in pt:
        print(pt)
        break
else:
    print("no candidate produced a SECCON plaintext")

Flag: SECCON{My_challenges_tend_to_be_solved_by_lattice_'reduction'. How_did_you_do_this_time?}

Not New PRNG, SECCON finals 2022
https://ctf.l1z4rdq.com/posts/notnewprng_seccon22/
Author
lizardqueen
Published at
2024-03-30