Zukane CTF

Square Up (UMDCTF 2025)

RSA Legendre Symbol
Challenge overview

In this CTF challenge we are given the following encryption script:

from Crypto.Util.number import getPrime, bytes_to_long, long_to_bytes
from os import urandom

with open("flag.txt", "rb") as f:
        flag = f.read()
        m = bytes_to_long(flag)

p = q = 0
while p % 4 != 3: p = getPrime(384)
while q % 4 != 3: q = getPrime(384)

N = p * q
print(f"{N = }")
def encrypt(m):
        lp = (pow(m, (p-1)//2, p) + 1) % p - 1
        lq = (pow(m, (q-1)//2, q) + 1) % q - 1
        return m * m % N, lp, lq

def decrypt(c, lp, lq):
        yq = pow(q, -1, p)
        yp = (1 - yq * q) // p

        mp = pow(c, (p+1)//4, p)
        mq = pow(c, (q+1)//4, q)

        if (pow(mp, (p-1)//2, p) - lp) % p != 0: mp = p - mp
        if (pow(mq, (q-1)//2, q) - lq) % q != 0: mq = q - mp

        return (yp * p * mq + yq * q * mp) % N


c, lp, lq = encrypt(m)
print(f"{(c, lp, lq) = }")
print(f"{long_to_bytes(decrypt(c, lp, lq)) = }")   

as well as the output:

N = 1298690852855676717877172430649235439701166577296380685015744142960768447038281361897617173145966407353660262643273693068083328108519398663073368426744653753236312330497119252304579628565448615356293308415969827357877088267274695333
(c, lp, lq) = (162345251908758036296170413099695514860545515965805244415511843227313118622229046299657295062100889503276740904118647336251473821440423216697485906153356736210597508871299190718706584361947325513349221296586217139380060755033205077, 1, -1)
long_to_bytes(decrypt(c, lp, lq)) = b'\x1bR \xc4\xf0\x8f\xa7l\xa4\xdd\xbf\xf73\xf3\xe9(\xc8Q\xdd\xbd,\x08\xbd\x7f\xafm\x9b\xbf\xa0\xbe\xd4)t\xd4e\xc0,J\xb8H\x93i\xea\xbcy\x9a7AA\xeb]q\xae\x00\xebJ(Y\x8a\xa4B\xdc\t(\x8b\xcef&@b\x91\x06Y~\x88m\xaf\x9bl\\\x12\xf2\x9f\xe1\x1f\x18q\x16\xd8\xb4\x9f$\x88%8\x0f'

We are given the modulus $N$, the ciphertext $c$, the Legendre symbols $lp$ and $lq$, and a broken plaintext $x$ (the output of the faulty `decrypt` call).

The code implements an RSA-like system with public exponent $e=2$ (essentially the Rabin cryptosystem):

\[\large c \equiv m^{2} \pmod N\]

The Legendre symbols are used to distinguish the four square roots of $c$ modulo $N$.
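
As a small illustration, the sketch below enumerates the four square roots of $c$ modulo $N$ and tags each with its pair of Legendre symbols, computed the same way as in `encrypt`. The toy primes $p=7$, $q=11$ and the message $m=9$ are assumptions chosen for readability, not values from the challenge.

```python
def legendre(a, r):
    # Euler's criterion, mapped to {-1, 0, 1} exactly like encrypt() does
    return (pow(a, (r - 1) // 2, r) + 1) % r - 1

# Toy parameters (assumed for illustration only)
p, q = 7, 11          # both are 3 mod 4, as in the challenge
N = p * q
m = 9                 # toy message, coprime to N
c = m * m % N

# Brute-force all square roots of c modulo N (only feasible for a toy N)
roots = [r for r in range(N) if r * r % N == c]

# Pair each root with its Legendre symbols modulo p and q
symbols = {r: (legendre(r, p), legendre(r, q)) for r in roots}

for r, pair in symbols.items():
    print(r, pair)
```

Because both primes are $3 \bmod 4$, $-1$ is a non-residue modulo each of them, so a root and its negation always have opposite symbols. Each of the four roots therefore carries a different pair, and the pair $(lp, lq)$ pins down $m$.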

Source code analysis

There is a mistake in the implementation of the decryption function:

if (pow(mp, (p-1)//2, p) - lp) % p != 0: mp = p - mp
if (pow(mq, (q-1)//2, q) - lq) % q != 0: mq = q - mp # <-- !!!

This section checks whether the computed modular square roots $mp$ and $mq$ have the given Legendre symbols $lp$ and $lq$. If a root has the wrong symbol, the other square root is selected instead. However, $mq$ is updated to the wrong value: $q - mp$ instead of $q - mq$. The correct implementation would be:

mp = p - mp
mq = q - mq # <-- use mq instead of mp again

This mistake means that the decrypted output $x$, recovered via CRT:

return (yp * p * mq + yq * q * mp) % N

does not satisfy $x^{2} \equiv c \pmod N$. The congruence holds modulo $p$, but not modulo $q$.
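
To see why, one can trace $x$ through the CRT step, assuming the second branch fired and $mq$ was overwritten with $q - mp$. It does fire here: $lq = -1$, while $c^{(q+1)/4} \bmod q$ is always a quadratic residue. By construction $yp \cdot p + yq \cdot q = 1$, so:

\[\large \begin{align} \nonumber x &\equiv yq \cdot q \cdot mp \equiv mp \pmod p \\ \nonumber x &\equiv yp \cdot p \cdot mq \equiv q - mp \equiv -mp \pmod q \\ \nonumber x^{2} &\equiv mp^{2} \equiv c \pmod p \\ \nonumber x^{2} &\equiv mp^{2} \not\equiv c \pmod q \quad \text{(in general)} \end{align}\]

Here $mp$ is only a square root of $c$ modulo $p$; its square would agree with $c$ modulo $q$ only by coincidence.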

We can use this fact to recover the prime factors like so:

\[\large \begin{align} \nonumber x^{2} &\equiv c \pmod p \\ \nonumber x^{2} - c &\equiv 0 \pmod p \\ \nonumber N &\equiv 0 \pmod p \\ \nonumber \gcd(x^{2}-c, N) &= p \end{align}\]
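
The same recovery can be checked on a toy instance. This is a minimal sketch, assuming small primes $p=7$, $q=11$ and a message $m=8$ whose Legendre symbols are $(1, -1)$ like in the challenge output; `buggy_decrypt` mirrors the challenge's `decrypt`, bug included.

```python
from math import gcd

# Toy parameters (assumptions for illustration; the real primes are 384 bits)
p, q = 7, 11              # both are 3 mod 4, as in the challenge
N = p * q
m = 8                     # toy message whose Legendre symbols are (1, -1)
c = m * m % N

def buggy_decrypt(c, lp, lq):
    # Same logic as the challenge's decrypt(), including the mq bug
    yq = pow(q, -1, p)
    yp = (1 - yq * q) // p

    mp = pow(c, (p + 1) // 4, p)
    mq = pow(c, (q + 1) // 4, q)

    if (pow(mp, (p - 1) // 2, p) - lp) % p != 0: mp = p - mp
    if (pow(mq, (q - 1) // 2, q) - lq) % q != 0: mq = q - mp  # the bug

    return (yp * p * mq + yq * q * mp) % N

x = buggy_decrypt(c, 1, -1)
recovered_p = gcd(x * x - c, N)

print(x != m)                  # the decryption is wrong
print(x * x % p == c % p)      # x is still a square root of c modulo p
print(x * x % q == c % q)      # but not modulo q
print(recovered_p == p)        # so the gcd isolates p
```

The gcd only splits $N$ when the faulty branch was actually taken, which is the case whenever $lq = -1$.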

With the recovered prime factors, we can reimplement the decryption function with the bug fixed:

def decrypt(c, lp, lq):
        yq = pow(q, -1, p)
        yp = (1 - yq * q) // p

        mp = pow(c, (p+1)//4, p)
        mq = pow(c, (q+1)//4, q)

        if (pow(mp, (p-1)//2, p) - lp) % p != 0: mp = p - mp
        if (pow(mq, (q-1)//2, q) - lq) % q != 0: mq = q - mq # <- fix bug

        return (yp * p * mq + yq * q * mp) % N

and we can simply decrypt to get the flag!

UMDCTF{e=3_has_many_attacks_and_e=2_has_its_own_problems...maybe_we_should_try_e=1_next?}
solve.py
from Crypto.Util.number import bytes_to_long
from math import gcd

x = bytes_to_long(b'\x1bR \xc4\xf0\x8f\xa7l\xa4\xdd\xbf\xf73\xf3\xe9(\xc8Q\xdd\xbd,\x08\xbd\x7f\xafm\x9b\xbf\xa0\xbe\xd4)t\xd4e\xc0,J\xb8H\x93i\xea\xbcy\x9a7AA\xeb]q\xae\x00\xebJ(Y\x8a\xa4B\xdc\t(\x8b\xcef&@b\x91\x06Y~\x88m\xaf\x9bl\\\x12\xf2\x9f\xe1\x1f\x18q\x16\xd8\xb4\x9f$\x88%8\x0f')
N  = 1298690852855676717877172430649235439701166577296380685015744142960768447038281361897617173145966407353660262643273693068083328108519398663073368426744653753236312330497119252304579628565448615356293308415969827357877088267274695333
c  = 162345251908758036296170413099695514860545515965805244415511843227313118622229046299657295062100889503276740904118647336251473821440423216697485906153356736210597508871299190718706584361947325513349221296586217139380060755033205077
lp = 1
lq = -1

# x^2 - c is a multiple of p but not of q, so the gcd with N reveals p
p = gcd(x*x - c, N)
q = N // p
assert 1 < p < N and p*q == N

def decrypt(c, lp, lq):
        yq = pow(q, -1, p)
        yp = (1 - yq * q) // p

        mp = pow(c, (p+1)//4, p)
        mq = pow(c, (q+1)//4, q)

        if (pow(mp, (p-1)//2, p) - lp) % p != 0: mp = p - mp
        if (pow(mq, (q-1)//2, q) - lq) % q != 0: mq = q - mq # <- fix bug

        return (yp * p * mq + yq * q * mp) % N

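m = decrypt(c, lp, lq)
# int.to_bytes avoids the odd-length hex string that bytes.fromhex would reject
print(m.to_bytes((m.bit_length() + 7) // 8, "big").decode())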