Overview
Although I solved this challenge during the competition about two years ago, it was the first time I played in person with my teammates, so I decided to make this writeup my first post for “sentimental reasons” :).
Description:
Descriptions are hard, but are decryptions?
Source:
#!/usr/bin/env python3
import struct, hashlib, random, os
from Crypto.Cipher import AES
n = 256    # number of coefficients in every polynomial
q = 11777  # modulus applied to each coefficient
w = 8      # random bits drawn per sampled coefficient
################################################################
def sample(rng):
    """Draw n small coefficients from rng.

    Each coefficient is the number of set bits among w random bits,
    shifted down by w//2, so it lies between -w//2 and w - w//2.
    """
    shift = w // 2
    coeffs = []
    for _ in range(n):
        ones = bin(rng.getrandbits(w)).count('1')
        coeffs.append(ones - shift)
    return coeffs


def add(f, g):
    """Add two coefficient lists element-wise, reducing every sum mod q.

    The result is as long as the shorter of the two inputs.
    """
    return [sum(pair) % q for pair in zip(f, g)]
def mul(f,g):
    """Multiply two polynomials modulo x^n + 1 with coefficients mod q.

    f and g are coefficient lists where index i holds the coefficient of
    x^i. Returns a list of n coefficients, each in range(q).
    """
    r = [0]*n
    for i,x in enumerate(f):
        for j,y in enumerate(g):
            # x^(i+j) = (x^n)^s * x^k and x^n = -1, so the product term
            # lands on x^k with sign (-1)^s.
            s,k = divmod(i+j, n)
            r[k] += (-1)**s * x*y
            r[k] %= q
    return r
################################################################
def genkey():
    """Generate a key pair.

    Returns:
        A tuple (s, (a, b)): s is the secret polynomial drawn with
        sample(); a has n coefficients drawn with random.randrange(q);
        b = a*s + e for a second sampled polynomial e.
    """
    a = [random.randrange(q) for _ in range(n)]
    # s and e come from the OS entropy source, unlike a.
    rng = random.SystemRandom()
    s,e = sample(rng), sample(rng)
    b = add(mul(a,s), e)
    return s, (a,b)
def center(v):
    """Return whichever of v % q and v % q - q has the smaller magnitude."""
    residue = v % q
    return min(residue, residue - q, key=abs)


def extract(r, d):
    """Return 2*t//q for every coefficient t of d whose flag in r is truthy."""
    return [2 * coeff // q for keep, coeff in zip(r, d) if keep]


def ppoly(g):
    """Hex-encode g as n little-endian unsigned 16-bit integers."""
    packed = struct.pack('<%dH' % n, *g)
    return packed.hex()


def pbits(g):
    """Render an iterable of 0/1 values as a string of '0' and '1' characters."""
    return ''.join(map(str, map(int, g)))


def hbits(g):
    """Return the SHA-256 digest of the pbits() string for g."""
    text = pbits(g)
    return hashlib.sha256(text.encode()).digest()


def mkaes(bits):
    """Build an AES-CTR cipher keyed by hbits(bits), with an empty nonce."""
    key = hbits(bits)
    return AES.new(key, AES.MODE_CTR, nonce=b'')
def encaps(pk):
    """Encapsulate a fresh shared bit string under the public key pk.

    Args:
        pk: the pair (a, b) produced by genkey().

    Returns:
        A tuple (bits, (c, r)): bits is the shared bit list, c = a*s + e
        for freshly sampled s and e, and r is a 0/1 list marking which
        coefficients contribute a bit.
    """
    # Ephemeral randomness: a random.Random seeded with 32 bytes from the OS.
    seed = os.urandom(32)
    rng = random.Random(seed)
    a,b = pk
    s,e = sample(rng), sample(rng)
    c = add(mul(a,s), e)
    d = add(mul(b,s), e)
    # Flag a coefficient only when its doubled, centered value is larger
    # than q//7 in magnitude, i.e. it is not close to 0 or q/2, the two
    # points where extract()'s 2*t//q changes value.
    r = [int(abs(center(2*v)) > q//7) for v in d]
    bits = extract(r,d)
    return bits, (c,r)
def decaps(sk, ct):
    """Recover the bit list from ciphertext ct using secret key sk.

    Args:
        sk: the secret polynomial s.
        ct: the pair (c, r) of polynomial and 0/1 selection list.

    Returns:
        extract(r, c*s): one bit per position flagged in r.
    """
    s = sk
    c,r = ct
    d = mul(c,s)
    return extract(r,d)
################################################################
if __name__ == '__main__':
    # Regenerate until decapsulation reproduces the encapsulated bits.
    while True:
        sk, pk = genkey()
        dh, ct = encaps(pk)
        if decaps(sk, ct) == dh:
            break
    print('pk[0]:', ppoly(pk[0]))
    print('pk[1]:', ppoly(pk[1]))
    print('ct[0]:', ppoly(ct[0]))
    print('ct[1]:', pbits(ct[1]))
    # Close the flag file as soon as it has been read.
    with open('flag.txt') as fh:
        flag = fh.read().strip()
    # The flag key is derived from [0]+dh; oracle keys below from [1]+bits.
    print('flag: ', mkaes([0]+dh).encrypt(flag.encode()).hex())
    # Decapsulation oracle: answer up to 2048 caller-supplied (c, r) pairs.
    for _ in range(2048):
        c = list(struct.unpack(f'<{n}H', bytes.fromhex(input())))
        r = list(map('01'.index, input()))
        # r must have n entries with at least n//2 of them set.
        if len(r) != n or sum(r) < n//2:
            exit('!!!')
        bits = decaps(sk, (c,r))
        print(mkaes([1]+bits).encrypt(b'hxp<3you').hex())
SOLUTION
Notice the 3 main values:
- `n`, the degree of the polynomial ring we’re working over, namely $\mathbb{Z}_q[x]/(x^n+1)$
- `q`, the order of the prime field the polynomials are defined over
- `w`, the width of the binomial distribution we’ll sample coefficients from
On line 11, we have:
sample = lambda rng: [bin(rng.getrandbits(w)).count('1') - w//2 for _ in range(n)]
This is the centered binomial distribution sampler. It samples w random bits and counts the number of ‘1’s, then subtracts w/2 to center the distribution around zero.
It’s also important to note that polynomials are represented here as lists of 256 coefficients, since in $\mathbb{Z}_q[x]/(x^{256}+1)$ we have that $x^{256} = -1$.
My attack consisted of 2 steps:
1: collecting data for each `c`
Collect and save “side choices” for values of `c`, where `c` is a constant polynomial. Let `d = cs`. Assuming `d[0] < q/2`, we do all the encrypted message comparisons to figure out which side of `q/2` each coefficient is on.
# Each scale factor is used twice: first as the coefficient of x^128,
# then as the constant term.
# scale = 1
c = ppoly([0]*128+[1]+[0]*127)
c = ppoly([1]+[0]*255)
# scale = q//4 + 1
c = ppoly([0]*128+[q//4+1]+[0]*127)
c = ppoly([q//4+1]+[0]*255)
# scale = q//2 + 1
c = ppoly([0]*128+[q//2+1]+[0]*127)
c = ppoly([q//2+1]+[0]*255)
# scale = q//8 + 1
c = ppoly([0]*128+[q//8+1]+[0]*127)
c = ppoly([q//8+1]+[0]*255)
2: enumerating the possibilities for `d[0]` and getting `s`
Since `d[0] < q/2` may be false for some values of `c`, we iterate over every possible combination of side choices that need to be flipped versus which don’t and use these guesses to try to determine each coefficient of `s`.
Solve Script
My exploit during the competition was really messy, but I don’t wanna mess it up by fixing it so much time after the competition, so here goes nothing:
import struct, hashlib, random, os
from Crypto.Cipher import AES
from pwn import *
n = 256    # coefficients per polynomial
q = 11777  # coefficient modulus
w = 8      # random bits consumed per sampled coefficient


def sample(rng):
    """Return n coefficients, each the popcount of w bits from rng minus w//2."""
    shift = w // 2
    coeffs = []
    for _ in range(n):
        ones = bin(rng.getrandbits(w)).count('1')
        coeffs.append(ones - shift)
    return coeffs


def add(f, g):
    """Return the element-wise sum of f and g with every entry reduced mod q."""
    return [sum(pair) % q for pair in zip(f, g)]
def mul(f,g):
    """Return the product of f and g modulo x^n + 1, coefficients mod q.

    Index i of each list is the coefficient of x^i; the result always has
    n entries.
    """
    r = [0]*n
    for i,x in enumerate(f):
        for j,y in enumerate(g):
            # Exponents wrap at n, and each wrap flips the sign (x^n = -1).
            s,k = divmod(i+j, n)
            r[k] += (-1)**s * x*y
            r[k] %= q
    return r
def genkey():
    """Generate a key pair the same way the challenge server does.

    Returns:
        (s, (a, b)) with a drawn via random.randrange(q), s and e drawn
        via sample() from random.SystemRandom, and b = a*s + e.
    """
    a = [random.randrange(q) for _ in range(n)]
    rng = random.SystemRandom()
    s,e = sample(rng), sample(rng)
    b = add(mul(a,s), e)
    return s, (a,b)
def center(v):
    """Return whichever of v % q and v % q - q is smaller in magnitude."""
    residue = v % q
    return min(residue, residue - q, key=abs)


def extract(r, d):
    """Return 2*t//q for each coefficient t of d whose flag in r is truthy."""
    return [2 * coeff // q for keep, coeff in zip(r, d) if keep]


def ppoly(g):
    """Hex-encode g as n little-endian unsigned 16-bit integers."""
    packed = struct.pack('<%dH' % n, *g)
    return packed.hex()


def ppolyrev(g):
    """Inverse of ppoly(): decode a hex string into a list of n integers."""
    raw = bytes.fromhex(g)
    return list(struct.unpack('<%dH' % n, raw))


def pbits(g):
    """Render an iterable of 0/1 values as a '0'/'1' string."""
    return ''.join(map(str, map(int, g)))


def pbitsinv(g):
    """Inverse of pbits(): turn a '0'/'1' string into a list of ints."""
    return ['01'.index(ch) for ch in g]


def hbits(g):
    """Return the SHA-256 digest of the pbits() string for g."""
    text = pbits(g)
    return hashlib.sha256(text.encode()).digest()


def mkaes(bits):
    """Build an AES-CTR cipher keyed by hbits(bits), with an empty nonce."""
    key = hbits(bits)
    return AES.new(key, AES.MODE_CTR, nonce=b'')
def encaps(pk):
    """Encapsulate a shared bit string under pk = (a, b).

    Returns:
        (bits, (c, r)): the shared bits, the polynomial c = a*s + e for
        freshly sampled s and e, and the 0/1 list r marking which
        coefficients were turned into bits.
    """
    seed = os.urandom(32)
    rng = random.Random(seed)
    a,b = pk
    s,e = sample(rng), sample(rng)
    c = add(mul(a,s), e)
    d = add(mul(b,s), e)
    # Only coefficients whose doubled, centered value exceeds q//7 in
    # magnitude are flagged for extraction.
    r = [int(abs(center(2*v)) > q//7) for v in d]
    bits = extract(r,d)
    return bits, (c,r)
def decaps(sk, ct):
    """Return extract(r, c*s) for secret key sk = s and ciphertext ct = (c, r)."""
    s = sk
    c,r = ct
    d = mul(c,s)
    return extract(r,d)
HOST, PORT = '116.203.41.47', 4421


def read_field(io):
    """Read one ``label: value`` line from io and return the value as str.

    The value is everything after the first colon with surrounding
    whitespace removed, so labels of different widths ('pk[0]:' and
    'flag: ') are handled the same way.
    """
    return io.recvline().decode().partition(':')[2].strip()


def ask(io, c_hex, r_bits):
    """Send one (c, r) oracle query and return the raw reply line."""
    io.sendline(c_hex.encode())
    io.sendline(r_bits.encode())
    return io.recvline()


def collect_sides(io, c_hex, offset=0, size=256):
    """Compare the second-half coefficients of c*s with coefficient size//2 - 1.

    The reference query selects positions 0 .. size//2 - 1. Each further
    query swaps position size//2 - 1 for one position i in the second
    half. The oracle reply equals the reference exactly when the
    extracted bits are equal, i.e. when position i yields the same bit
    as position size//2 - 1.

    Args:
        io: connection object providing sendline() and recvline().
        c_hex: hex-encoded polynomial to submit as c.
        offset: value subtracted from each index before it is recorded.
        size: number of coefficients (length of the r string).

    Returns:
        (same, different): two lists of recorded indices.
    """
    half = size // 2
    reference = ask(io, c_hex, '1' * half + '0' * half)
    same, different = [], []
    for i in range(half, size):
        mask = '1' * (half - 1) + '0' * (i - half + 1) + '1' + '0' * (size - i - 1)
        bucket = same if ask(io, c_hex, mask) == reference else different
        bucket.append(i - offset)
    return same, different


def main():
    """Collect the side-choice lists from the server and print them.

    Prints, in order: the (equal, not-equal) index lists for each of the
    eight query polynomials, then the ciphertext, the public key and the
    encrypted flag.
    """
    io = remote(HOST, PORT)
    try:
        pk = [ppolyrev(read_field(io)) for _ in range(2)]
        ct = [read_field(io) for _ in range(2)]
        flag = read_field(io)

        half = n // 2
        results = []
        for scale in (1, q // 4 + 1, q // 2 + 1, q // 8 + 1):
            # scale * x^(n/2): indices are recorded relative to n/2.
            shifted = ppoly([0] * half + [scale] + [0] * (half - 1))
            results.extend(collect_sides(io, shifted, offset=half, size=n))
            # scale * x^0: indices are recorded as-is.
            constant = ppoly([scale] + [0] * (n - 1))
            results.extend(collect_sides(io, constant, offset=0, size=n))
    finally:
        io.close()

    for indices in results:
        print(indices)
    print(ct)
    print(pk)
    print(flag)


if __name__ == '__main__':
    main()
For part 2, I wrote a simple script using itertools to test the combinations.
Flag: hxp{e4zy_p34zY_p34nuT_Bu7t3r}
