In this CTF challenge we are given the following encryption script:
from Crypto.Util.number import getPrime, bytes_to_long, long_to_bytes
from os import urandom
# Read the secret flag and interpret its bytes as one big integer m.
with open("flag.txt", "rb") as f:
    flag = f.read()
m = bytes_to_long(flag)
# Keep drawing 384-bit primes until each is congruent to 3 mod 4; decrypt()
# relies on this to take square roots with the exponents (p+1)//4 and (q+1)//4.
p = q = 0
while p % 4 != 3: p = getPrime(384)
while q % 4 != 3: q = getPrime(384)
# Only the modulus is printed; p and q themselves are never output.
N = p * q
print(f"{N = }")
def encrypt(m):
    """Square m modulo N and record hints identifying which root m was.

    Returns the tuple (c, lp, lq): c is m*m % N, and lp / lq are the values
    pow(m, (p-1)//2, p) and pow(m, (q-1)//2, q) remapped to -1, 0 or 1,
    i.e. the Legendre symbols of m modulo p and modulo q.
    Reads the module-level p, q and N.
    """
    # "(value + 1) % p - 1" turns the residue p-1 into -1 and leaves 0 and 1
    # unchanged.
    lp = (pow(m, (p-1)//2, p) + 1) % p - 1
    lq = (pow(m, (q-1)//2, q) + 1) % q - 1
    return m * m % N, lp, lq
def decrypt(c, lp, lq):
    """Try to recover m from c = m*m % N using the hints lp and lq.

    Takes a square root of c modulo p and modulo q, replaces a root by its
    negative when its Legendre symbol does not match lp / lq, and combines
    the two roots modulo N.  Reads the module-level p, q and N.

    NOTE: this is the challenge's code as given.  The q-branch below contains
    the mistake that the write-up exploits, so it is kept verbatim.
    """
    # Bezout coefficients with yq*q + yp*p == 1 (the floor division is exact).
    yq = pow(q, -1, p)
    yp = (1 - yq * q) // p
    # p and q are 3 mod 4, so c**((p+1)//4) is a square root of c modulo p
    # (and likewise for q).
    mp = pow(c, (p+1)//4, p)
    mq = pow(c, (q+1)//4, q)
    if (pow(mp, (p-1)//2, p) - lp) % p != 0: mp = p - mp
    # BUG (as shipped): subtracts mp instead of mq, so when this branch runs
    # the value stored in mq is no longer a square root of c modulo q.
    if (pow(mq, (q-1)//2, q) - lq) % q != 0: mq = q - mp
    # Combine the residues: the result is mp modulo p and mq modulo q.
    return (yp * p * mq + yq * q * mp) % N
# Encrypt the flag, print the ciphertext together with both hints, then print
# whatever decrypt() returns for them, converted back to bytes.
c, lp, lq = encrypt(m)
print(f"{(c, lp, lq) = }")
print(f"{long_to_bytes(decrypt(c, lp, lq)) = }")
as well as the output:
N = 1298690852855676717877172430649235439701166577296380685015744142960768447038281361897617173145966407353660262643273693068083328108519398663073368426744653753236312330497119252304579628565448615356293308415969827357877088267274695333
(c, lp, lq) = (162345251908758036296170413099695514860545515965805244415511843227313118622229046299657295062100889503276740904118647336251473821440423216697485906153356736210597508871299190718706584361947325513349221296586217139380060755033205077, 1, -1)
long_to_bytes(decrypt(c, lp, lq)) = b'\x1bR \xc4\xf0\x8f\xa7l\xa4\xdd\xbf\xf73\xf3\xe9(\xc8Q\xdd\xbd,\x08\xbd\x7f\xafm\x9b\xbf\xa0\xbe\xd4)t\xd4e\xc0,J\xb8H\x93i\xea\xbcy\x9a7AA\xeb]q\xae\x00\xebJ(Y\x8a\xa4B\xdc\t(\x8b\xcef&@b\x91\x06Y~\x88m\xaf\x9bl\\\x12\xf2\x9f\xe1\x1f\x18q\x16\xd8\xb4\x9f$\x88%8\x0f'
We are given the modulus $N$, the ciphertext $c$, the Legendre symbols $lp$ and $lq$, and a broken plaintext $x$ (the value printed by the buggy `decrypt`).
The code implements an RSA-like system with public exponent $e=2$ (essentially the Rabin cryptosystem):
\[\large c \equiv m^{2} \pmod{N}\]
The Legendre symbols are used to distinguish the four square roots of $c$ modulo $N$.
There is a mistake in the implementation of the decryption function:
if (pow(mp, (p-1)//2, p) - lp) % p != 0: mp = p - mp
if (pow(mq, (q-1)//2, q) - lq) % q != 0: mq = q - mp # <-- !!!
This section checks whether the computed modular square roots $mp$ and $mq$ have the given Legendre symbols $lp$ and $lq$. If a root has the wrong symbol, it is replaced by the other square root, i.e. its negative. However, $mq$ is updated to an incorrect value: the code assigns $q - mp$ instead of $q - mq$. The correct implementation would be:
mp = p - mp
mq = q - mq # <-- use mq instead of mp again
This mistake means that the decrypted output $x$, recovered via CRT:
return (yp * p * mq + yq * q * mp) % N
no longer satisfies $x^{2} \equiv c \pmod{N}$. The congruence still holds modulo $p$, but not modulo $q$.
We can use this fact to recover the prime factors like so:
\[\large \begin{aligned} x^{2} &\equiv c \pmod{p} \\ \Rightarrow\quad x^{2} - c &\equiv 0 \pmod{p} \\ N &\equiv 0 \pmod{p} \\ \Rightarrow\quad \gcd(x^{2} - c,\ N) &= p \end{aligned}\]
The gcd is exactly $p$, and not $N$, because $x^{2} \not\equiv c \pmod{q}$. With the recovered prime factors, we can reimplement the decryption function with the bug fixed:
def decrypt(c, lp, lq):
    """Recover m from c = m*m % N using the hints lp and lq (bug fixed).

    Takes a square root of c modulo p and modulo q, replaces a root by its
    negative when its Legendre symbol does not match lp / lq, and combines
    the two roots modulo N.  Reads the module-level p, q and N.
    """
    # Bezout coefficients with yq*q + yp*p == 1 (the floor division is exact).
    yq = pow(q, -1, p)
    yp = (1 - yq * q) // p
    # Square roots of c modulo p and q via the exponents (p+1)//4 and (q+1)//4.
    mp = pow(c, (p+1)//4, p)
    mq = pow(c, (q+1)//4, q)
    if (pow(mp, (p-1)//2, p) - lp) % p != 0: mp = p - mp
    if (pow(mq, (q-1)//2, q) - lq) % q != 0: mq = q - mq # <- the fix: negate mq, not mp
    # Combine the residues: the result is mp modulo p and mq modulo q.
    return (yp * p * mq + yq * q * mp) % N
and we can simply decrypt to get the flag!
UMDCTF{e=3_has_many_attacks_and_e=2_has_its_own_problems...maybe_we_should_try_e=1_next?}
from Crypto.Util.number import bytes_to_long
from math import gcd
# Broken plaintext printed by the challenge's buggy decrypt(), as an integer.
# int.from_bytes(..., "big") is the standard-library equivalent of
# bytes_to_long, so this step needs no third-party package.
x = int.from_bytes(b'\x1bR \xc4\xf0\x8f\xa7l\xa4\xdd\xbf\xf73\xf3\xe9(\xc8Q\xdd\xbd,\x08\xbd\x7f\xafm\x9b\xbf\xa0\xbe\xd4)t\xd4e\xc0,J\xb8H\x93i\xea\xbcy\x9a7AA\xeb]q\xae\x00\xebJ(Y\x8a\xa4B\xdc\t(\x8b\xcef&@b\x91\x06Y~\x88m\xaf\x9bl\\\x12\xf2\x9f\xe1\x1f\x18q\x16\xd8\xb4\x9f$\x88%8\x0f', "big")
N = 1298690852855676717877172430649235439701166577296380685015744142960768447038281361897617173145966407353660262643273693068083328108519398663073368426744653753236312330497119252304579628565448615356293308415969827357877088267274695333
c = 162345251908758036296170413099695514860545515965805244415511843227313118622229046299657295062100889503276740904118647336251473821440423216697485906153356736210597508871299190718706584361947325513349221296586217139380060755033205077
lp = 1
lq = -1
# x is still a square root of c modulo p but not modulo q, so x*x - c shares
# exactly the factor p with N.
p = gcd(x*x - c, N)
q = N // p
# Because q is defined as N // p, p*q == N also holds for the useless results
# p == 1 and p == N; additionally require p to be a proper factor.
assert 1 < p < N and p * q == N, "gcd(x*x - c, N) did not yield a proper factor of N"
def decrypt(c, lp, lq):
    """Return the square root of c modulo N selected by lp and lq.

    c is the ciphertext m*m % N; lp and lq are the Legendre symbols of the
    wanted root modulo p and modulo q.  Uses the module-level p, q and N and
    assumes p and q are both congruent to 3 mod 4.
    """
    # Bezout coefficients with yq*q + yp*p == 1 (the floor division is exact).
    yq = pow(q, -1, p)
    yp = (1 - yq * q) // p
    # For a prime p = 3 (mod 4), c**((p+1)//4) is a square root of c modulo p.
    mp = pow(c, (p + 1) // 4, p)
    mq = pow(c, (q + 1) // 4, q)
    # Replace each root by its negative when its Legendre symbol is not the
    # requested one.
    if (pow(mp, (p - 1) // 2, p) - lp) % p != 0:
        mp = p - mp
    if (pow(mq, (q - 1) // 2, q) - lq) % q != 0:
        mq = q - mq  # the fix: the challenge's decrypt() subtracted mp here
    # Combine the residues: the result is mp modulo p and mq modulo q.
    return (yp * p * mq + yq * q * mp) % N
# Convert the recovered integer back to bytes.  Sizing the buffer from
# bit_length() works for any value, whereas bytes.fromhex(f"{n:x}") raises
# ValueError whenever the hex string has an odd number of digits.
flag_int = decrypt(c, lp, lq)
print(flag_int.to_bytes((flag_int.bit_length() + 7) // 8, "big").decode())