Introduction

This year, I finished Flare-On for the fourth time. I’m hoping the rumors about this possibly being the last Flare-On aren’t true, since I fell just short of my goal of top 50 and I want a second chance. I’m not sure if this is the hardest Flare-On challenge I’ve ever done (last year’s challenge serpentine was pretty insane too), but it’s easily the hardest one I’ve ever done a writeup of.

This challenge has a lot of different steps to it, many of which take a lot of time to compute and are easy to get wrong. For the most part, I handled this by writing a lot of different scripts to perform each intermediate step and saving the results to a JSON file. This was very messy and not at all optimal, but it did manage to get me the flag in the end. I’d be interested to know how other people did or didn’t optimize their solve scripts for this challenge.

Initial Observations

The main function begins with a check of a license file called license.bin:

uint64_t main()
    sub_140010090()
    sub_140082940(
        f_print_to_stdout(&stdout_ostream, string_val: "checking license file..."), 
        sub_1400c58c0)
    void ifstream
    int512_t zmm1 = f_get_file(&ifstream, "license.bin", sub_1400c7d60(4, 2))

We can see that the license file has to be exactly 340000 bytes long (buf is an int16_t*, so &buf[0x29810] points 340000 bytes past the start), and that its SHA256 hash is used as a key to decrypt the flag.

if (license_length == 340000)
    f_do_smth_ifstream(&ifstream, 0, 0, zmm1)
    int16_t* buf = sub_1400c9ed0(sx.q(license_length))
    f_read_file(&ifstream, buf, sx.q(license_length))
    void hash_maybe
    f_sha256_license_file(buf, &buf[0x29810], &hash_maybe, &ifstream, 0x100000)

Additionally, there’s a validation function that’s run on the license file to check whether decryption of the flag should be performed. There are two conditions that need to be met: the loop performing the validation code must reach 10000 iterations, and a buffer must contain a specific expected value. I’ll refer to this buffer as the “checksum buffer” from here on out.

while (true)
    if (counter s> 9999)
        if (memcmp(_Buf1: &check_buf, _Buf2: &expected_buf, _Size: 0x9c40)
                == 0)
            sub_140082940(
                f_print_to_stdout(&stdout_ostream, 
                    string_val: "license valid!"), 

10000 Executables

The challenge binary contains 10000 resources starting with the string M8Z. Each of them looks kind of like a PE file, containing a mangled version of the string This program cannot be run in DOS mode.

00000000: 4d 38 5a 90 38 03 66 02 04 09 71 ff 81 b8 c2 91  M8Z.8.f...q.....
00000010: 01 40 c2 15 c6 80 0b 1c 0e 1f ba f8 00 b4 09 cd  .@..............
00000020: 21 b8 01 4c 80 54 68 01 69 73 20 70 72 6f 67 cc  !..L.Th.is prog.
00000030: 61 6d f0 63 e8 6e e3 e9 74 dc 62 65 e7 f9 75 e7  am.c.n..t.be..u.
00000040: a3 69 0e 06 44 4f 53 80 6d 6f 64 65 2e 71 0d 29  .i..DOS.mode.q.)

Googling the string M8Z, I found that the resources were in fact PE files compressed with aPlib. This explained the 10000-iteration loop in the license validation code: the license file would somehow be checked using functions from all 10000 of these DLLs.

Sure enough, sub_140001482 was responsible for retrieving the DLL at a given index and calling its entry point, passing in the checksum buffer as an argument to the entry function.

HRSRC hResInfo =
    FindResourceA(hModule: nullptr, lpName: zx.q(resource_index), lpType: 0xa)
char* var_68_1 =
    LockResource(hResData: LoadResource(hModule: nullptr, hResInfo))
uint32_t compressed_size = SizeofResource(hModule: nullptr, hResInfo)
int32_t decompressed_size = sub_140002690(var_68_1, zx.q(compressed_size))
void var_1c8
sub_1400c0c30(&var_1c8, sx.q(decompressed_size))
struct pe_headers* pe_file = sub_1400297d0(&var_1c8, 0)
int32_t var_1d0_1 = 0
f_decompress_pe(var_68_1, pe_file, zx.q(compressed_size), 
    sx.q(decompressed_size), nullptr)

//[...]

int32_t var_cc_1 = (dll_addr
    + zx.q(rax_21->OptionalHeader.AddressOfEntryPoint))(&check_buf, 1, 0)

The License File Format

My first step was to get an idea of how the license file format was structured so that I could generate a test license file to use in a debugger. I found that on each run of the loop, the validation function reads two fields from the license. The first field of each entry is a 16-bit integer that must be between 0 and 9999:

uint16_t license_index = license_buf->index

if (license_index u> 9999)
    sub_140082940(
        f_print_to_stdout(&stdout_ostream, 
            string_val: "invalid license file"), 
        sub_1400c58c0)
    rbx_1 = 1
    break

The second field is a sequence of 32 bytes that gets passed to a function called check. If the check function succeeds, a value gets written to the checksum buffer, but if it fails, the license is invalid.

f_get_func_by_name(&var_78, "_Z5checkPh", rsi)
char check_result =  // check function called here
    (*f_get_next_export_addr(rbx_2, &var_78))(license_data) ^ 1
sub_1400b4c10(&var_78)

if (check_result != 0)
    sub_140082940(
        f_print_to_stdout(&stdout_ostream, 
            string_val: "invalid license file"), 
        sub_1400c58c0)
    rbx_1 = 1
    break

license_buf = license_data + 0x20
f_write_to_check_buf(counter)
counter += 1

After figuring this out, the license length of 340000 made sense: each of the 10000 DLLs has a 2-byte index and a 32-byte data sequence associated with it, for 34 bytes per entry. The index field specifies which DLL to run, and the data field is an input that the DLL validates.
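To build a test license for the debugger, the entries can be packed with Python's struct module. Here is a minimal sketch, assuming each 34-byte entry is a little-endian 16-bit DLL index followed by the 32-byte data field (little-endian is an assumption based on this being an x86-64 binary); build_license and parse_license are hypothetical helper names.

```python
import struct

# Assumed entry layout: little-endian uint16 DLL index, then 32 data bytes
ENTRY_FMT = "<H32s"
ENTRY_SIZE = struct.calcsize(ENTRY_FMT)  # 34 bytes, no padding with "<"


def build_license(entries):
    """Pack (dll_index, data) pairs, in run order, into a license blob."""
    out = bytearray()
    for idx, data in entries:
        if not 0 <= idx <= 9999:
            raise ValueError(f"DLL index out of range: {idx}")
        if len(data) != 32:
            raise ValueError("each data field must be exactly 32 bytes")
        out += struct.pack(ENTRY_FMT, idx, data)
    return bytes(out)


def parse_license(blob):
    """Split a license blob back into (dll_index, data) pairs."""
    if len(blob) % ENTRY_SIZE:
        raise ValueError("license length is not a multiple of 34")
    return [struct.unpack_from(ENTRY_FMT, blob, off)
            for off in range(0, len(blob), ENTRY_SIZE)]
```

A license with all 10000 entries comes out to 10000 * 34 = 340000 bytes, which is the length main checks for.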

The Executable Contents

The Transformation Functions

Each DLL contains many functions with names starting with f followed by a series of numbers. Many of these functions are almost identical except for some constant values that they reference, and it turns out they fall into 3 main categories.

The first is a substitution of each byte of the input:

uint64_t f73179180583603935578(uint8_t* arg1)
    *arg1 ^= *data_786020
    uint8_t sbox[0x100]
    __builtin_memcpy(dest: &sbox, 
        src: "\x80\x99\x10\xb8\x5b\x23\x97\xf8\x12\x18\x65\x67\xb5\x02\xc3\x8e\x44\xd7\x01\x"
    "d6\x7a\x28\x15\x63\x76\x98\x13\x05\x34\x60\xec\x61\xfa\xb4\x4d\xe4\xf2\xb3\xe1\xab"
    "ea\xb2\xd0\x94\x30\x4f\x0d\x33\xfc\x4a\x5c\xb0\x4c\xd1\x2f\x41\xee\x62\x31\x7b\x54"
    "91\xf6\xdc\xf7\x1a\xe3\xaa\x71\x3a\x08\xce\xaf\x3b\x86\x82\x3c\x56\x35\xc8\x29\x57"
    "3d\x43\x64\xff\x9f\xbe\xbd\xa9\x45\xd9\x7f\x3e\xe2\xf5\xb6\x88\x87\xe0\x81\x03\xe9"
    "25\xcb\xdf\x9d\xdb\xf0\x48\xd5\x49\x68\xc2\x6f\x2c\xbf\x47\x1e\xed\x24\x5e\xe8\x16"
    "c9\x9b\x09\xd3\x"
        count: 0x100)
    uint64_t result = 0x27f36e9c748f9017
    
    for (int32_t i = 0; i s<= 0x1f; i += 1)
        result = zx.q(sbox[sx.q(zx.d(arg1[sx.q(i)]))])
        arg1[sx.q(i)] = result.b
    
    return result
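Undoing the substitution loop only requires the inverse permutation of the sbox. A minimal sketch, assuming the 256-byte table has already been extracted; it covers only the lookup loop, so the XOR at the start of the function still has to be undone afterwards. The helper names are hypothetical.

```python
def invert_sbox(sbox):
    """Return the inverse permutation of a 256-entry sbox."""
    inv = [0] * 256
    for i, v in enumerate(sbox):
        inv[v] = i
    return inv


def apply_sbox(sbox, buf):
    # Mirrors the lookup loop: every byte is replaced by sbox[byte]
    return bytes(sbox[b] for b in buf)


def undo_sbox(sbox, buf):
    # Looking bytes up through the inverse table undoes the substitution
    return apply_sbox(invert_sbox(sbox), buf)
```

With a toy table such as [(7 * i + 3) % 256 for i in range(256)] (a permutation because 7 is odd), undo_sbox(table, apply_sbox(table, data)) gives back data.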

The second is a “shuffle” function that permutes the bytes of the input:

uint64_t f49843883851507363229(uint8_t* arg1)
    *arg1 ^= *data_786020
    uint8_t shuffle[0x20]
    __builtin_memcpy(dest: &shuffle, 
        src: "\x18\x03\x09\x1e\x08\x0b\x07\x1f\x1a\x17\x1b\x1c\x19\x1d\x0c\x13\x15\x11\x02\x"
    "0d\x0f\x16\x05\x14\x0e\x01\x00\x04\x10\x06\x12\x0a", 
        count: 0x20)
    uint64_t result = 0x1405160f0d021115
    int64_t var_58
    __builtin_memset(dest: &var_58, ch: 0, count: 0x20)
    
    for (int32_t i = 0; i s<= 0x1f; i += 1)
        result = sx.q(i)
        *(&var_58 + result) = arg1[zx.q(shuffle[sx.q(i)])]
    
    for (int32_t i_1 = 0; i_1 s<= 0x1f; i_1 += 1)
        result = zx.q(*(&var_58 + sx.q(i_1)))
        arg1[sx.q(i_1)] = result.b
    
    return result
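The loop above computes out[i] = in[shuffle[i]], so undoing it means writing each output byte back to the position it was read from. A minimal sketch using the shuffle table from f49843883851507363229 (the helper names are hypothetical, and the XOR at the start of the function is again left out):

```python
# Shuffle table copied from f49843883851507363229 above
SHUFFLE = bytes.fromhex("1803091e080b071f1a171b1c191d0c13"
                        "1511020d0f1605140e0100041006120a")


def apply_shuffle(shuffle, buf):
    # Mirrors the decompiled loop: out[i] = in[shuffle[i]]
    return bytes(buf[shuffle[i]] for i in range(32))


def undo_shuffle(shuffle, buf):
    # Scatter each output byte back to the index it was gathered from
    orig = bytearray(32)
    for i, src in enumerate(shuffle):
        orig[src] = buf[i]
    return bytes(orig)
```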

The third type of function does something more complicated. First, it sets the lowest bit of the input to 1, while keeping track of what the original value of that bit was. Then it does something involving a constant value (the memcpy copies 0x1f bytes of it into a 32-byte buffer), and then the result is XORed with the saved bit and with 1 in order to ensure the relationship between input and output remains one-to-one.

uint8_t* f92961177136248183669(uint8_t* arg1)
    *arg1 ^= *data_786020
    char rax_6 = *arg1 & 1
    *arg1 |= 1
    uint8_t exp[0x20]
    __builtin_memcpy(dest: &exp, 
        src: "\x5f\xc3\x2c\xb6\x7b\xe3\x24\x44\x19\xfd\x97\x84\x57\x82\xbc\xc1\x80\xad\x21\x"
    "43\x21\x83\xd1\x6e\xb5\xd0\xe6\xc5\x88\x17\xfb", 
        count: 0x1f)
    uint8_t result[0x20]
    __builtin_memcpy(dest: &result, 
        src: "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x"
    "00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 

// [...]

            if ((zx.d(exp[sx.q(i_1)]) s>> j.b & 1) != 0)
                int32_t var_18_1 = 0
                
                for (int32_t k = 0; k s<= 0x1f; k += 1)
                    for (int32_t var_20_1 = 0; var_20_1 s<= k; var_20_1 += 1)
                        var_18_1 += zx.d(input[sx.q(k - var_20_1)])
                            * zx.d(result[sx.q(var_20_1)])

// [...]

    *arg1 = *arg1 ^ rax_6 ^ 1
    return arg1

To figure out what the mathematical operation was, I first noticed that 1) each bit of the constant was checked to see whether it was 1 or 0, and 2) if the bit was 1, a multiplication was performed and saved in an accumulator value. I was already familiar with the repeated-squaring method of modular exponentiation, so I guessed that was probably what was happening here. That also explained setting the low bit of the base: the multiplication works on 32-byte values, so the modulus is a power of 2 (2**256), and performing the modular exponentiation on an even number would almost certainly have a result of 0. In order to ensure the exponentiation operation could be inverted, it had to be performed on an odd number.
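If that guess is right, the map can be inverted much like RSA: the odd residues mod 2**256 form a group of order 2**255, so for an odd exponent e, raising to d = e**-1 mod 2**255 undoes raising to e. A minimal sketch, assuming the 32 bytes are read as a little-endian integer (byte 0 holds the saved low bit, which fits the multiplication loop treating index 0 as the least significant digit) and that the exponent has already been parsed correctly; the function names are hypothetical and the XOR at the start of the function is left out.

```python
MOD = 1 << 256  # assumed: arithmetic on 32-byte little-endian integers


def modexp_forward(buf, e):
    """Model of the transform: set the low bit, exponentiate, restore the bit."""
    x = int.from_bytes(buf, "little")
    low = x & 1
    r = pow(x | 1, e, MOD)  # odd base, so r is odd as well
    r ^= low ^ 1            # low bit of the result becomes the original low bit
    return r.to_bytes(32, "little")


def modexp_inverse(buf, e):
    """Undo modexp_forward for an odd exponent e."""
    if e % 2 == 0:
        raise ValueError("exponent must be odd for the map to be invertible")
    d = pow(e, -1, 1 << 255)  # inverse of e modulo the group order 2**255
    y = int.from_bytes(buf, "little")
    low = y & 1
    x = pow(y | 1, d, MOD)    # y | 1 is the odd power computed in the forward step
    return ((x & ~1) | low).to_bytes(32, "little")
```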

For all three types of functions, the user input is first XORed with a 32-bit integer pulled from the checksum buffer, at the index equal to the DLL’s own number.

For example, from 1756.dll:

uint64_t f92027960352701313593(uint8_t* arg1)
    *arg1 ^= xor_data[1756]
    int64_t var_38

And from 3463.dll:

uint64_t f88730141134197849752(uint8_t* arg1)
    *arg1 ^= xor_data[3463]
    int64_t var_118

The Check Function

The numbered functions are called in a function called check. Each executable imports numbered functions from other executables and calls them in addition to its own.

uint64_t check(uint8_t* arg1)
    f38236877289593244403(arg1)
    f32422423688401106395(arg1)
    f14945866699034032513(arg1)
    f42785923630423812381(arg1)
    f01804799128886574724(arg1)
    f63971385288077461058(arg1)
    f32043095044683465902(arg1)
    f35008263310466305120(arg1)
    f39008422158736727341(arg1) // Import
    f77204538293621348673(arg1)
    f63086630871216239504(arg1)
    f62306332011382063629(arg1)
    f16456010376118152083(arg1)
    f09449763057560931840(arg1) // Import
    f52479682207617065503(arg1) // Import

After all of these transformations are applied to the input of check, we get to the actual checking step. First, the 32-byte transformed input gets XORed with the values in an array of 16 128-bit integers. Binary Ninja had a hard time with the decompilation here for some reason, but this loop is just XORing one array with another: each 64-bit word of the input is XORed with 4 of the table values, and every result has to be less than compare_val.

uint64_t compare_val = 0xdc37c0e304978087
int64_t var_80 = 0
int64_t var_90 = 0x594b7f91f11228e5
int128_t xor_vals[0x10]
xor_vals[0].q = 0x264f1c2a310e43aa
xor_vals[0]:8.q = 0
xor_vals[1].q = 0x6f62577ddb8f7c8
// [...]
xor_vals[0xe]:8.q = 0
xor_vals[0xf].q = 0xf3a55fbbc4837e78
xor_vals[0xf]:8.q = 0
int128_t _Buf1[0x10]

for (int32_t i = 0; i s<= 0xf; i += 1)
    int64_t* rax_595 = (sx.q(i) << 4) + 0x530 + &_Buf1[4] - 0x170
    uint32_t rax_598 = i s>> 0x1f u>> 0x1e
    int64_t rdx_6 = rax_595[1]
    int64_t* rcx_595 = (sx.q(i) << 4) + 0x530 + &_Buf1[4] - 0x170
    // XOR the input data with value from XOR table, and write
    // it back to the XOR table. Each input value gets xored
    // with 4 table values.
    *rcx_595 = *(arg1 + (sx.q(((i + rax_598) & 3) - rax_598) << 3)) ^ *rax_595
    rcx_595[1] = rdx_6
    int64_t* rax_611 = (sx.q(i) << 4) + 0x530 + &_Buf1[4] - 0x170
    rax_611[1]
    *rax_611
    bool c_2 = unimplemented  {sbb rax, qword [rbp+0x548]}
    
    // All the XORs need to be less than compare_val
    if (not(c_2))
        return 0
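In that loop, ((i + rax_598) & 3) - rax_598 is just i % 4 for non-negative i, so entry i of the matrix is 64-bit word i % 4 of the input XORed with xor_vals[i]. A minimal sketch of that construction, plus the reverse direction needed later to turn a matrix back into a 32-byte input; it assumes little-endian 64-bit words and the row-major layout (entry 4 * r + c) used by build_matrix in the extraction script, and both function names are hypothetical.

```python
def input_to_matrix(buf, xor_table, modulus):
    """Build the 4x4 matrix check() works on, or None if an entry is too big."""
    words = [int.from_bytes(buf[8 * k:8 * k + 8], "little") for k in range(4)]
    flat = []
    for i in range(16):
        v = words[i % 4] ^ xor_table[i]
        if v >= modulus:
            return None  # check() returns 0 in this case
        flat.append(v)
    return [flat[4 * r:4 * r + 4] for r in range(4)]


def matrix_to_input(mat, xor_table):
    """Recover the 32-byte input from a matrix, or None if no input maps to it."""
    words = []
    for c in range(4):
        # Every row of column c must decode to the same input word
        cands = {mat[r][c] ^ xor_table[4 * r + c] for r in range(4)}
        if len(cands) != 1:
            return None
        w = cands.pop()
        if w >= 1 << 64:
            return None
        words.append(w)
    return b"".join(w.to_bytes(8, "little") for w in words)
```

The consistency check in matrix_to_input matters because the matrix has 16 entries but the input only has 4 words: a matrix whose columns don’t decode to a single word each can’t come from any input.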

Then we have a really long unrolled loop with a lot of modular additions and multiplications:

// [...]
int64_t rax_686 = out_33.q
val = rax_686 + out_22.q
int64_t var_5c0_34 = adc.q(out_33:8.q, out_22:8.q, rax_686 + out_22.q u< rax_686)
modulus = compare_val
int64_t var_5d0_35 = var_80
int128_t out_34 = modulo_int128(&val, &modulus, out: out_33)
// [...]

And then, finally, the result of that is another array of 16 128-bit integers, which is compared against a target value. If they match, we’ve passed the check.

// [...]
_Buf1[0xd]:8.q = 0
_Buf1[0xe].q = 0xa5b7c08151fface8
_Buf1[0xe]:8.q = 0
_Buf1[0xf].q = 0xc7b8d0a6d71a6e00
_Buf1[0xf]:8.q = 0
int32_t rax_884
rax_884.b = memcmp(&_Buf1, &_Buf2, _Size: 0x100) == 0

I guessed pretty early on that the operations that were happening in the unrolled loop could be some kind of operations on a matrix, with the array of 16 values being treated as a 4x4 matrix. (It helped that I’d seen a few people complaining that the challenge was too hard because there was math in it.) This was confirmed when I looked at which indices of the array were being multiplied together at each step and realized that it was consistent with the calculation of a determinant.

The check function returns failure if the determinant it calculates is 0, which told me that whatever step came next required the 4x4 matrix to be invertible. Rather than manually reverse all the math, I tried patching in a bunch of test values for the matrix in x64dbg and checked to see what the final result was. For example, passing in the identity matrix as an input gave me the identity matrix as an output, and passing in a diagonal matrix other than the identity gave me a different diagonal matrix as an output. After trying out a couple of different common operations that might produce this result, it turned out that the operation was matrix exponentiation modulo a prime, with each check function using a different prime modulus.
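The operation that matches those experiments can be sketched as square-and-multiply over 4x4 matrices, with every product reduced mod p. This is an illustration with small hypothetical values, not a reimplementation of the unrolled code in check; it shows why the identity maps to itself and a diagonal matrix maps to another diagonal matrix.

```python
def mat_mul(a, b, p):
    """Multiply two square matrices, reducing every entry mod p."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) % p for j in range(n)]
            for i in range(n)]


def mat_pow(a, e, p):
    """Compute a**e mod p by square-and-multiply over the exponent's bits."""
    n = len(a)
    result = [[int(i == j) for j in range(n)] for i in range(n)]  # identity
    base = [row[:] for row in a]
    while e:
        if e & 1:
            result = mat_mul(result, base, p)
        base = mat_mul(base, base, p)
        e >>= 1
    return result
```

For example, mat_pow applied to diag(2, 3, 4, 5) with e = 7 and p = 101 gives the diagonal matrix of pow(2, 7, 101) through pow(5, 7, 101).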

(This kind of educated guessing is almost always how I approach reverse engineering mathematical functions. Often, something that looks like a very complicated algorithm turns out to be something simple like multiplication or exponentiation over big integers, so it’s a good idea to test for that first. It’s often possible to make a very good guess about what a function does just by looking at what happens when one of the inputs is 0 or 1. In this case, I was already on the lookout for modular exponentiation because it was one of the three types of transformation functions called before check.)

The Dependency Graph

The Checksum Buffer

After each DLL finishes running, the checksum buffer is updated. The checksum buffer is an array of 10000 32-bit integers, one for each DLL. The program iterates through the currently loaded DLLs and, for each one, adds the current counter value to that DLL’s entry in the buffer before freeing the DLL’s memory.

int64_t f_write_to_check_buf(int32_t counter)
    void* var_10 = &resources_vec
    int64_t resources_start = f_get_vec_start(&resources_vec)
    int64_t resources_end = f_get_vec_end(var_10)
    
    while (f_compare_args(&resources_start, &resources_end) != 1)
        struct resource_struct* rax_4 = *f_dereference_ptr(&resources_start)
        *((sx.q(zx.d(rax_4->index.w)) << 2) + &check_buf) += counter
        VirtualFree(lpAddress: rax_4->addr, dwSize: 0, dwFreeType: MEM_RELEASE)
        sub_140020560(&resources_start)
    
    return sub_1400af7b0(&resources_vec)

If only a single DLL was loaded each time, it would be easy to determine the order in which the DLLs were supposed to run based on the expected final value of the checksum buffer. For example, if index 5402 contained the number 1, we would know 5402.dll was supposed to run when the counter equaled 1.

However, more than one DLL is loaded each time: the DLLs depend on each other, and when a DLL is loaded, its dependencies need to be loaded (and the dependencies of those dependencies, etc.). That means that the value contained in each index is the sum of all counter values where the corresponding DLL was loaded as a dependency somewhere along the chain.
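A small forward model of this bookkeeping helps when reasoning about the recovery step. This sketch assumes imports maps each DLL to its direct imports, and that a DLL reachable along several import paths is still only loaded (and counted) once per iteration; loaded_set and simulate_checksum are hypothetical names.

```python
def loaded_set(idx, imports):
    """idx plus every DLL it imports, directly or through other DLLs."""
    seen, stack = set(), [idx]
    while stack:
        cur = stack.pop()
        if cur not in seen:
            seen.add(cur)
            stack.extend(imports.get(cur, []))
    return seen


def simulate_checksum(run_order, imports, n):
    """Checksum buffer after running run_order[counter] for each counter."""
    buf = [0] * n
    for counter, idx in enumerate(run_order):
        # Each loaded DLL gets the current counter added exactly once
        for dll in loaded_set(idx, imports):
            buf[dll] += counter
    return buf
```

For a toy graph where 0.dll and 1.dll both import 2.dll, running 2, 0, 1 in that order gives [1, 2, 3]: 2.dll collects 0 + 1 + 2 because it is loaded on every iteration.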

Constructing the Import Graph

I started by traversing the imports section of each of the DLLs to determine the direct imports of each DLL, and wrote the result to a big JSON file.

import json
import pefile

# exe_paths maps each DLL index to its extracted filename (built elsewhere)
def get_imports(idx):
    imports = []
    pe = pefile.PE('exes/' + exe_paths[idx])
    pe.parse_data_directories(
        directories=[pefile.DIRECTORY_ENTRY['IMAGE_DIRECTORY_ENTRY_IMPORT']])
    # DLLs without an import table have no DIRECTORY_ENTRY_IMPORT attribute
    for iid in getattr(pe, 'DIRECTORY_ENTRY_IMPORT', []):
        dll_name = iid.dll.decode('ascii')
        num = dll_name.split('.')[0]
        if num.isdecimal():  # skip system DLLs such as KERNEL32.dll
            imports.append(int(num))
    return imports

imports = {i: get_imports(i) for i in range(10000)}

with open('imports.json', 'w') as f:
    json.dump(imports, f)

I then recursively traversed the parsed list of imports for each of the DLLs, giving me a full list of the imports that would be loaded for each DLL (i.e., a list that included not just direct dependencies, but also dependencies of those dependencies, and so on), and wrote that to an even bigger JSON file. (I’m sure there are much nicer ways to create and store large graphs in Python, but to be honest, I was well into “extremely cursed CTF code” territory at this point and I didn’t feel like learning about any of those.)

import json

# JSON turns the integer keys into strings, so convert them back
with open('imports.json') as f:
    imports_json = {int(k): v for k, v in json.load(f).items()}

def get_imports_recursive(idx, visited):
    # Post-order walk: returns every DLL loaded when idx runs, idx included
    imports = []
    for num in imports_json[idx]:
        if num not in visited:
            imports += get_imports_recursive(num, visited)
    imports.append(idx)
    visited.add(idx)
    return imports

with open('imports_chain.jsonl', 'w') as f:
    for i in range(10000):
        print('Getting imports for:', i)
        imps_i = get_imports_recursive(i, set())
        f.write(json.dumps({i: imps_i}) + '\n')

One thing that I noticed during this process was that there were no circular dependencies (for example, there are no DLLs A.dll and B.dll such that A.dll imports B.dll and B.dll imports A.dll). That meant that following the chain of importers upward from any DLL eventually reaches DLLs that aren’t imported by any other DLL. It’s easy to figure out where those DLLs fall in the run order: such a DLL is only loaded when it runs itself, so if the checksum buffer contains the value x at index A, A.dll is supposed to be run when the counter is at value x.

Now consider a slightly more complicated case: suppose A.dll and B.dll aren’t loaded by any other DLLs, and that their checksum buffer values are x and y respectively. Now suppose C.dll is loaded only by A.dll and B.dll, and its checksum buffer value is z. Then z is equal to the sum of the counter values when A.dll, B.dll, and C.dll are run, since those are all the times when C.dll is loaded. Since A.dll and B.dll run at counter values x and y, C.dll must run when the counter is at z - x - y.

Note that the dependency graph I initially constructed tracks the DLLs that each DLL imports, but what we actually need for this calculation is a graph of the DLLs that each DLL is imported by:

import json

imported_by = {i: [] for i in range(10000)}

with open('imports_chain.jsonl') as f:
    for line in f:
        if not line.strip():  # skip the trailing empty line
            continue
        for k, v in json.loads(line).items():
            for val in v:
                # Each chain lists the DLL itself, which isn't an importer
                if val != int(k):
                    imported_by[val].append(int(k))

We can then recursively walk the “imported by” graph until we get to a DLL that isn’t imported by anything, and calculate the counter value for each of the DLLs.

# targets[i] is the expected checksum buffer value for DLL i
done = {}
def walk_graph(idx):
    if idx in done:
        return done[idx]
    counter_val = targets[idx]
    for num in imported_by[idx]:
        if num != idx:  # a DLL is not its own importer
            counter_val -= walk_graph(num)
    print('Got:', idx, counter_val)
    done[idx] = counter_val
    return counter_val

for i in range(10000):
    if i not in done:
        walk_graph(i)
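The recursion in walk_graph can also be written iteratively, which sidesteps Python’s recursion limit on long import chains. This is a swapped-in technique (Kahn’s topological sort), not the script above: DLLs are processed so that each one comes after all of its importers, at which point every counter value it needs to subtract is already known. It assumes imports holds the direct imports as integer lists and targets the expected checksum buffer; the function names are hypothetical.

```python
from collections import deque


def transitive_imports(idx, imports):
    """Every DLL loaded (directly or indirectly) by idx, excluding idx itself."""
    seen, stack = set(), list(imports.get(idx, []))
    while stack:
        cur = stack.pop()
        if cur not in seen:
            seen.add(cur)
            stack.extend(imports.get(cur, []))
    seen.discard(idx)
    return seen


def recover_run_order(imports, targets):
    """Return the DLL index that should run at each counter value."""
    n = len(targets)
    importers = {i: set() for i in range(n)}
    for i in range(n):
        for dep in transitive_imports(i, imports):
            importers[dep].add(i)

    # Kahn's algorithm on the direct-import edges: a DLL becomes ready once
    # every DLL that directly imports it has been processed.
    indegree = [0] * n
    for i in range(n):
        for dep in set(imports.get(i, [])):
            indegree[dep] += 1
    queue = deque(i for i in range(n) if indegree[i] == 0)

    counter = {}
    while queue:
        node = queue.popleft()
        # All transitive importers precede node in topological order
        counter[node] = targets[node] - sum(counter[j] for j in importers[node])
        for dep in set(imports.get(node, [])):
            indegree[dep] -= 1
            if indegree[dep] == 0:
                queue.append(dep)

    if len(counter) != n:
        raise ValueError("import graph contains a cycle")
    if sorted(counter.values()) != list(range(n)):
        raise ValueError("counter values are not a permutation of 0..n-1")
    order = [0] * n
    for dll, c in counter.items():
        order[c] = dll
    return order
```

On the toy graph where 0.dll and 1.dll both import 2.dll and the checksum buffer is [1, 2, 3], this returns [2, 0, 1]: 2.dll runs at counter 3 - 1 - 2 = 0.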

As expected, the values obtained from the script were each of the numbers 0 through 9999, which told me I had the right order for the DLLs.

Passing the Check

Inverting the Matrix Exponentiation

First, I needed to extract the target matrix to compare to, along with the modulus, the exponent, and the table of XOR values that are used to generate the matrix to exponentiate. For something like this I would normally use a disassembler like capstone, but since there are so many functions to process, it would’ve taken a long time. Luckily, there was very little variation in the disassembly of the functions, so it was possible to extract everything using regular expressions. The constant values were always loaded using the register rax, and they were always loaded in the same order: the modulus, then the exponent, then the XOR table, then the target matrix.

import pefile
import glob
import json
import re

# 4881c4c0000000     add     rsp, 0xc0            |   4883c450           add     rsp, 0x50  | 415f               pop     r15 {__saved_r15}
# 5d                 pop     rbp {__saved_rbp}
# c3                 retn     {__return_addr}
cleanup_expr = rb'(\x48\x81\xc4....|\x48\x83\xc4.|\x41\x5f)\x5d\xc3'

# 55                 push    rbp {__saved_rbp}
# 4889e5             mov     rbp, rsp {__saved_rbp}  | 4881ec10010000     sub     rsp, 0x110 | 4157               push    r15 {__saved_r15}
start_expr = rb'\x55(\x48\x89\xe5|\x48\x81\xec....|\x41\x57)'

function_expr = start_expr + rb'.*?' + cleanup_expr

def build_matrix(l):
    m = []
    for i in range(4):
        m.append([0,0,0,0])
    for i in range(4):
        for j in range(4):
            m[i][j] = l[4*i+j]
    return m

# 48b8????????????????   mov     rax, ????????????????    
# ba00000000         mov     edx, 0x0   
mov_rax_expr = rb'\x48\xb8(?P<constval>.{8})' + rb'(\xba\x00\x00\x00\x00|\x48\x89)'

def extract_matrix(mapped, startaddr, endaddr):
    data = mapped[startaddr:endaddr]
    consts = []
    for i in re.finditer(mov_rax_expr, data, re.DOTALL | re.MULTILINE):
        constval = i.group("constval")
        consts.append(int.from_bytes(constval, 'little'))
    if len(consts) != 34:
        raise ValueError('extraction failed')
    else:
        n = consts[0]
        e = consts[1]
        xors = consts[2:18]
        matrix = consts[18:34]
        return n, e, xors, build_matrix(matrix)

def try_extract(filepath, out):
    r = {}

    pe = pefile.PE(filepath)
    mapped = pe.get_memory_mapped_image()

    d = [pefile.DIRECTORY_ENTRY["IMAGE_DIRECTORY_ENTRY_EXPORT"]]
    pe.parse_data_directories(directories=d)

    for sym in pe.DIRECTORY_ENTRY_EXPORT.symbols:
        r[sym.name.decode()] = sym.address

    # Find each function by matching the function prologue and epilogue
    functions = {}
    for i in re.finditer(function_expr, mapped, re.DOTALL | re.MULTILINE):
        functions[i.start()] = (i.start(), i.end())

    # Get the name of each function
    ranges = {}
    for k, v in r.items():
        if v in functions:
            ranges[k] = functions[v]
        else:
            print('fail:', filepath, k)

    try:
        start, end = ranges['_Z5checkPh']
        n, e, xors, matrix = extract_matrix(mapped, start, end)
    except (KeyError, ValueError):
        # Skip this DLL rather than writing undefined values
        print('fail:', filepath)
        return
    j = {filepath: {'n': n, 'e': e, 'x': xors, 'm': matrix}}
    out.write(json.dumps(j) + '\n')

with open('matrices.json', 'w') as out:
    for filepath in glob.glob('exes/exe_0x*.exe'):
        try_extract(filepath, out)

Once I had the target matrix mat, modulus n, and exponent e, I needed to invert the exponentiation operation. As it turned out, the way I ended up doing this was needlessly complicated, so I’ll cover both my way of solving it and the way that was given in the official writeup.

My solve script

My first thought was that there’s one case where it’s obvious how to approach the problem: if mat is diagonal, then we can find the eth root of mat just by taking the eth root of each diagonal entry. My linear algebra knowledge is a little rusty at this point, but I at least remembered enough to know that many matrices can be diagonalized. From the Wikipedia page:

In linear algebra, a square matrix A is called diagonalizable or non-defective if it is similar to a diagonal matrix. That is, if there exists an invertible matrix P and a diagonal matrix D such that P**-1 * A * P = D. This is equivalent to A = P * D * P**-1.

Note that if A = P * D * P**-1, then A**e = (P * D * P**-1)**e = P * D**e * P**-1. That means that if A is diagonalizable, then the problem of finding the eth root of A reduces to the easier problem of finding the eth root of D.

Unfortunately, not all the given matrices were diagonalizable. However, sometimes a matrix can be diagonalizable over one field but not another. For example, the matrix [[0, -1], [1, 0]] is not diagonalizable over the real numbers, but is diagonalizable over the complex numbers:

([[1, 1], [-i, i]])**-1 * [[0, -1], [1, 0]] * [[1, 1], [-i, i]] = [[i, 0], [0, -i]]
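A quick numerical check of this diagonalization with plain Python complex numbers, using the eigenvectors (1, -i) and (1, i) of [[0, -1], [1, 0]] as the columns of P (the helper matmul2 is a hypothetical name):

```python
def matmul2(a, b):
    """Multiply two 2x2 matrices of Python numbers."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


A = [[0, -1], [1, 0]]
P = [[1, 1], [-1j, 1j]]            # columns: eigenvectors for i and -i
P_inv = [[0.5, 0.5j], [0.5, -0.5j]]

D = matmul2(matmul2(P_inv, A), P)  # expected: [[i, 0], [0, -i]]
print(D)
```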

The case we’re looking at is a little different, as the elements of the matrices we’re given are not real numbers, but members of a finite field. However, it’s possible to extend other fields in a way that’s similar to going from the real numbers to the complex numbers, in a process called taking the algebraic closure. It turned out that even though the given matrices weren’t diagonalizable over the given finite fields, they were diagonalizable over the algebraic closures of those fields.

The final Sage script I ended up with is as follows:

import json

def get_eth_root(n, e, mat):
    A = Matrix(GF(n).algebraic_closure(), mat)

    if not A.is_diagonalizable():
        return None

    (D,P)=A.right_eigenmatrix()

    prod = P.inverse() * A * P

    x0 = prod[0][0].nth_root(e)
    x1 = prod[1][1].nth_root(e)
    x2 = prod[2][2].nth_root(e)
    x3 = prod[3][3].nth_root(e)

    X = Matrix(GF(n).algebraic_closure(), [
    [x0, 0, 0, 0],
    [0, x1, 0, 0],
    [0, 0, x2, 0],
    [0, 0, 0, x3]
    ])

    X2 = P * X * P.inverse()

    vals = []
    for row in X2:
        for i in row:
            try:
                vals.append(int(str(i)))
            except:
                return None
    return vals

with open('matrices.json') as f, open('roots.json', 'w') as out:
    for line in f:
        if not line.strip():  # skip the trailing empty line
            continue
        j = json.loads(line)
        for filename, data in j.items():
            root_mat = get_eth_root(data['n'], data['e'], data['m'])
            out.write(json.dumps({filename: root_mat}) + '\n')

The simpler way to solve it

My solution was overly complicated because I got way too hung up on the idea of diagonalizing the matrices, when I really should’ve realized that the matrices didn’t need to be diagonalized at all. In fact, finding the eth root of a matrix isn’t any harder than, or fundamentally any different from, finding the eth root of one of its elements.

To review, we can find the eth root of a number m modulo a prime number p by taking advantage of Fermat’s little theorem, which states that m**(p-1) = 1 mod p whenever p does not divide m. If e is coprime to p - 1, we can find an integer d such that d * e = k * (p-1) + 1 for some integer k (that is, d is the inverse of e mod p - 1). Then m**(d * e) = m mod p, so (m**d)**e = m mod p. Thus m**d is the desired eth root of m mod p.
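In Python this takes only a couple of lines, since pow accepts an exponent of -1 together with a modulus (Python 3.8 and later) to compute the inverse d. A minimal sketch with a hypothetical function name:

```python
from math import gcd


def eth_root_mod_p(m, e, p):
    """Return r with r**e == m (mod p), for prime p and gcd(e, p - 1) == 1."""
    if gcd(e, p - 1) != 1:
        raise ValueError("e must be coprime to p - 1")
    d = pow(e, -1, p - 1)  # d * e = k * (p - 1) + 1
    return pow(m, d, p)
```

For instance, with p = 101 and e = 7 the inverse is d = 43 (7 * 43 = 301 = 3 * 100 + 1).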

Finding the eth root of a 4x4 matrix modulo p is a very similar process. The invertible 4x4 matrices modulo p form a group called the general linear group GL(4, p). Let g be the order of this group. Then by Lagrange’s theorem, for a matrix A in this group, A**g = I, where I is the identity matrix. If e is coprime to g, we can find an integer d such that d * e = k * g + 1 for some integer k, and then A**(d * e) = A, so (A**d)**e = A. Thus A**d is the desired eth root of A. The only extra step here is finding the order of the group, which is a well-known result: g = (p**4 - 1)(p**4 - p)(p**4 - p**2)(p**4 - p**3).
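A minimal sketch of this approach in plain Python (no Sage), assuming A is invertible mod p and gcd(e, g) = 1; matrix_eth_root, mat_pow, mat_mul and gl_order are hypothetical names, and the small prime in the usage note is only for illustration.

```python
from math import gcd


def mat_mul(a, b, p):
    """Multiply two square matrices, reducing every entry mod p."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) % p for j in range(n)]
            for i in range(n)]


def mat_pow(a, e, p):
    """Compute a**e mod p by square-and-multiply."""
    n = len(a)
    result = [[int(i == j) for j in range(n)] for i in range(n)]
    base = [row[:] for row in a]
    while e:
        if e & 1:
            result = mat_mul(result, base, p)
        base = mat_mul(base, base, p)
        e >>= 1
    return result


def gl_order(n, p):
    """|GL(n, p)| = (p**n - 1)(p**n - p)...(p**n - p**(n-1))."""
    g = 1
    for k in range(n):
        g *= p**n - p**k
    return g


def matrix_eth_root(a, e, p):
    """Return R with R**e == a (mod p), for invertible a and gcd(e, g) == 1."""
    g = gl_order(len(a), p)
    if gcd(e, g) != 1:
        raise ValueError("e must be coprime to the order of GL(n, p)")
    d = pow(e, -1, g)  # d * e = k * g + 1
    return mat_pow(a, d, p)
```

With p = 101 and e = 7, for example, matrix_eth_root(A, 7, 101) returns a matrix whose 7th power mod 101 is A for any invertible 4x4 A.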

See the official writeup, as well as unofficial writeups from jro and SuperFashi, for more explanations of this.

Inverting the Transformation Functions

Detecting the Function Types

After calculating the 32-byte transformed input that produced the correct matrix, the next step was to invert each of the transformation functions in order to determine the input from the license file that would be required.

For any given transformation, I needed a way to determine if it was an sbox, a shuffle, or an exponentiation. I would also need to extract the relevant operand from each one (the sbox, the indices of the shuffle, or the exponent). Again, I used binary regular expressions to extract the data.

For example, all three transformation types use rax and rdx to initialize immediate values, so the following regex could be used to extract the operand:

# 48b8????????????????  mov     rax, ????????????????
# 48ba????????????????  mov     rdx, ????????????????
mov_expr = rb'\x48\xb8(?P<rax>.{8})' + rb'\x48\xba(?P<rdx>.{8})'
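Each match of this regex yields 16 bytes (8 from rax and 8 from rdx), and re.DOTALL is needed because without it . does not match a 0x0a byte inside an immediate. A minimal sketch of the extraction, with extract_immediates as a hypothetical helper name:

```python
import re

# 48b8 <imm64>  mov rax, imm64
# 48ba <imm64>  mov rdx, imm64
mov_expr = rb'\x48\xb8(?P<rax>.{8})' + rb'\x48\xba(?P<rdx>.{8})'


def extract_immediates(code):
    """Concatenate the rax/rdx immediates of every mov pair in code."""
    out = b""
    # DOTALL lets . match 0x0a bytes inside the immediates
    for m in re.finditer(mov_expr, code, re.DOTALL):
        out += m.group("rax") + m.group("rdx")
    return out
```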

For the sbox operation, a total of 256 operand bytes would be extracted with this regex, and for the other two transformations, a total of 32 bytes would be extracted. Additionally, the exponentiation operation is the only one of the three types that uses multiplication, which could be detected with another regular expression:

# 0fb6c0             movzx   eax, al
# 0fafc2             imul    eax, edx
mul_expr = rb'\x0f\xb6\xc0\x0f\xaf\xc2'

This on its own is enough to distinguish the three types of functions, but since I knew it would be an absolute nightmare to debug any false positives, I added a couple of extra verification functions to check whether the sbox really contained every value from 0 to 255 and whether the shuffle order really contained every value from 0 to 31. I ended up with the following extraction function:

# 4881c4c0000000     add     rsp, 0xc0            |   4883c450           add     rsp, 0x50  | 415f               pop     r15 {__saved_r15}
# 5d                 pop     rbp {__saved_rbp}
# c3                 retn     {__return_addr}
cleanup_expr = rb'(\x48\x81\xc4....|\x48\x83\xc4.|\x41\x5f)\x5d\xc3'

# 55                 push    rbp {__saved_rbp}
# 4889e5             mov     rbp, rsp {__saved_rbp}  | 4881ec10010000     sub     rsp, 0x110 | 4157               push    r15 {__saved_r15}
start_expr = rb'\x55(\x48\x89\xe5|\x48\x81\xec....|\x41\x57)'

function_expr = start_expr + rb'.*?' + cleanup_expr

# 48b8????????????????  mov     rax, ????????????????
# 48ba????????????????  mov     rdx, ????????????????
mov_expr = rb'\x48\xb8(?P<rax>.{8})' + rb'\x48\xba(?P<rdx>.{8})'

# 0fb6c0             movzx   eax, al
# 0fafc2             imul    eax, edx
mul_expr = rb'\x0f\xb6\xc0\x0f\xaf\xc2'

mul_mov = rb'H\x89E\xa0H\x89U\xa8' + mov_expr + rb'H\x89E\xafH\x89U\xb7'

shuffle_mov = b'H\x89E\xd0H\x89U\xd8' + mov_expr + b'H\x89E\xe0H\x89U\xe8'

def verify_sbox(data):
    s = set()
    for i in data:
        if i < 0 or i >= 0x100:
            return False
        s.add(i)
    return len(s) == 0x100

def verify_shuffle(data):
    s = set()
    for i in data:
        if i < 0 or i >= 0x20:
            return False
        s.add(i)
    return len(s) == 0x20

def detect_type(data):
    constant_bytes = b''
    insn_type = None

    for i in re.finditer(mov_expr, data, re.DOTALL | re.MULTILINE):
        constant_bytes += i.group("rax")
        constant_bytes += i.group("rdx")
    if len(constant_bytes) == 0x100:
        if verify_sbox(constant_bytes):
            insn_type = 'sbox'
    elif len(constant_bytes) == 0x20:
        if len(re.findall(mul_expr, data, re.DOTALL | re.MULTILINE)) > 0 and len(re.findall(mul_mov, data, re.DOTALL | re.MULTILINE)) > 0:
            insn_type = 'modexp'
        elif len(re.findall(shuffle_mov, data, re.DOTALL | re.MULTILINE)) > 0:
            if verify_shuffle(constant_bytes):
                insn_type = 'shuffle'
    if insn_type is None:
        print(constant_bytes.hex(), len(constant_bytes), len(data))
        
    return insn_type, constant_bytes

Running this function on all the exports (other than check) of all of the DLLs, I compiled a list of every function type and operand in every DLL.
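The constants are stored as consecutive movabs rax / movabs rdx pairs, so concatenating the immediates in order rebuilds the S-box, shuffle table, or exponent. The sketch below shows that step on a synthetic function body. build_fake_body and the 0x90 filler byte are hypothetical and only imitate the layout. Note that re.DOTALL is required here, because an immediate can contain the byte 0x0a.

```python
import re

# Same pattern as mov_expr: movabs rax, imm64 ; movabs rdx, imm64
mov_expr = rb'\x48\xb8(?P<rax>.{8})\x48\xba(?P<rdx>.{8})'

def extract_constants(code: bytes) -> bytes:
    """Concatenate the rax/rdx immediates of every movabs pair, in order."""
    out = b''
    # re.DOTALL matters: without it, '.' would not match an immediate byte of 0x0a
    for m in re.finditer(mov_expr, code, re.DOTALL):
        out += m.group('rax') + m.group('rdx')
    return out

def build_fake_body(table: bytes) -> bytes:
    """Hypothetical helper: encode a table as movabs pairs, 16 bytes per pair."""
    body = b''
    for off in range(0, len(table), 16):
        body += b'\x48\xb8' + table[off:off + 8]
        body += b'\x48\xba' + table[off + 8:off + 16]
        body += b'\x90'  # stand-in for the stores between pairs
    return body

identity_sbox = bytes(range(256))
consts = extract_constants(build_fake_body(identity_sbox))
print(len(consts), consts == identity_sbox)  # prints: 256 True
```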

Creating the License

For each check function, I then had to extract the order of function calls. I again used regular expressions here, which turned out to be a little bit messy because both direct and indirect function calls were used:

# 48 89 c1               mov     rcx, rax
# e8 d2 ce fb ff         call    f41689142231683650251
mov_rax_rcx = rb'\x48\x89\xc1\xe8(?P<direct>....)'

#48 8b 05 6a 54 01 00     mov     rax, qword [rel f89290994951878036061]
#ff d0                    call    rax
mov_rax_addr = rb'\x48\x8b.(?P<indirect>...).\xff\xd0'

call_regex = mov_rax_rcx + rb'|' + mov_rax_addr

The other tricky part is that there are two different places to look up the name of the function being called: each callee is either one of the DLL’s imports or one of its own exports. It turned out that direct calls are always used for exports and indirect calls are always used for imports, which at least made the two easy to tell apart.
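The offset arithmetic in the loop below can be checked on synthetic bytes. This sketch is a variation on the regex above rather than the one used for the solve. It pins the ModRM byte to 0x05 (the RIP-relative form) and captures the full 4-byte displacement instead of skipping the top byte and padding with \x00, so negative displacements are also handled. call_targets and the byte string are made up for illustration.

```python
import re
import struct

# 48 89 c1 / e8 rel32        mov rcx, rax ; call <export>
direct_call = rb'\x48\x89\xc1\xe8(?P<direct>.{4})'
# 48 8b 05 disp32 / ff d0    mov rax, [rip+disp32] ; call rax  (imported pointer)
indirect_call = rb'\x48\x8b\x05(?P<indirect>.{4})\xff\xd0'
call_regex = direct_call + rb'|' + indirect_call

def call_targets(code: bytes, base: int):
    """Yield ('export' | 'import', target) for each call matched in code.

    base plays the role of check_start: the address of code[0] in whatever
    address space the imports/exports lookup tables are keyed by.
    """
    for m in re.finditer(call_regex, code, re.DOTALL):
        if m.group('direct') is not None:
            rel = struct.unpack('<i', m.group('direct'))[0]
            # rel32 is relative to the end of the call: 3-byte mov + 5-byte call
            yield 'export', base + m.start() + 8 + rel
        else:
            rel = struct.unpack('<i', m.group('indirect'))[0]
            # disp32 is relative to the end of the 7-byte mov
            yield 'import', base + m.start() + 7 + rel

# Synthetic bytes: one direct call, then one indirect call
code = (b'\x48\x89\xc1\xe8' + struct.pack('<i', -0x40)
        + b'\x48\x8b\x05' + struct.pack('<i', 0x1000) + b'\xff\xd0')
for kind, target in call_targets(code, 0x2000):
    print(kind, hex(target))  # export 0x1fc8, then import 0x300f
```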

for i in re.finditer(call_regex, check_func, re.DOTALL | re.MULTILINE):
    direct = i.group('direct')
    indirect = i.group('indirect')
    if indirect is not None:
        relative_addr = struct.unpack('<i', indirect+b'\x00')[0]
        offset = i.start() + check_start + relative_addr + 7
        if offset in imports:
            call_seq.append(imports[offset])
    else:
        relative_addr = struct.unpack('<i', direct)[0]
        offset = i.start() + check_start + relative_addr + 8
        if offset in exports:
            call_seq.append(exports[offset])

I then wrote the corresponding inverse function for each transformation. The one issue I ran into was that the exponent I had parsed for the modular exponentiation operation was wrong: it is actually only 31 bytes, not 32, so one byte had to be removed from the middle of the parsed value. This follows from the stores matched by mul_mov: the third qword is written to [rbp-0x51], which overlaps the qword at [rbp-0x58] by one byte, so byte 15 of the parsed constant is overwritten and never used:

def undo_modexp(enc, e):
    #print('modexp')
    e = e[0:15] + e[16:] # the exponent has 1 unused byte for some reason
    xor_val = enc[0] & 1
    enc[0] = enc[0] | 1
    enc = int.from_bytes(enc, 'little')
    e = int.from_bytes(e, 'little')
    d = pow(e, -1, (1 << 255))
    res = pow(enc, d, (1 << 256))
    res = bytearray(res.to_bytes(32, 'little'))
    res[0] = res[0] ^ xor_val ^ 1
    return res
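The odd residues modulo 2^256 form a multiplicative group of order 2^255. For an odd exponent e, d = e^-1 mod 2^255 therefore satisfies (x^e)^d ≡ x (mod 2^256) for every odd x, which is why the inverse is taken modulo 2^255. Even values have no inverse, so the low bit is forced to 1 and tracked separately. The round-trip check below reuses the steps of undo_modexp in invert_modexp. Its forward counterpart, apply_modexp, is inferred from the inverse rather than taken from the DLLs, and a random odd value stands in for the 31-byte exponent.

```python
import os

MOD = 1 << 256    # values are 32-byte little-endian integers
ORDER = 1 << 255  # size of the group of odd residues mod 2**256

def apply_modexp(block: bytes, e: int) -> bytes:
    # Hypothetical forward transform, reconstructed from undo_modexp
    x = int.from_bytes(block, 'little')
    low = x & 1
    y = pow(x | 1, e, MOD)   # odd base, so the result is odd and invertible
    y ^= low ^ 1             # put the input's original low bit back
    return y.to_bytes(32, 'little')

def invert_modexp(block: bytes, e: int) -> bytes:
    # Same steps as undo_modexp, but on an integer exponent
    enc = int.from_bytes(block, 'little')
    low = enc & 1
    d = pow(e, -1, ORDER)    # raises ValueError if e is even
    x = pow(enc | 1, d, MOD)
    x ^= low ^ 1
    return x.to_bytes(32, 'little')

e = int.from_bytes(os.urandom(31), 'little') | 1  # stand-in 31-byte odd exponent
block = os.urandom(32)
assert invert_modexp(apply_modexp(block, e), e) == block
```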

At that point, I could call the inverse of each transformation function to generate the expected initial input for the license file. I processed the DLLs in the order in which they were called, undid each check function’s calls in reverse order, and kept the state of the checksum buffer up to date so that the XOR with the value from the buffer would be correct.

for i in check_calls[::-1]:
    opcode, hexval, source = func_types[i]
    operand = bytearray.fromhex(hexval)
    if opcode == 'shuffle':
        data = undo_shuffle(data, operand)
    elif opcode == 'modexp':
        data = undo_modexp(data, operand)
    elif opcode == 'sbox':
        data = undo_sbox(data, operand)
    #print(data.hex())

    # undo the XOR with the checksum buffer value
    first_int = int.from_bytes(data[0:4], 'little')
    first_int ^= counter_table[source]
    data[0:4] = first_int.to_bytes(4, 'little')

# update the checksum buffer
for imp in imports[idx]:
    counter_table[imp] += counter

license_part = idx.to_bytes(2, 'little') + bytes(data)
#print(license_part.hex())
license.write(license_part)
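The loop above also relies on undo_shuffle and undo_sbox, which aren't shown. Here is a minimal sketch of both. It assumes that the forward S-box maps each byte b to sbox[b] and that the forward shuffle builds output byte i from input byte perm[i]. If the DLLs use the opposite convention, the apply/undo pairs swap roles. These are hypothetical stand-ins that share the loop's function names. Both inverses exist only because verify_sbox and verify_shuffle already confirmed that the tables are permutations.

```python
import random

def apply_sbox(block: bytes, sbox: bytes) -> bytearray:
    # Assumed forward direction: every byte b becomes sbox[b]
    return bytearray(sbox[b] for b in block)

def undo_sbox(block: bytes, sbox: bytes) -> bytearray:
    # Invert the table once; this only works because sbox is a permutation of 0..255
    inverse = bytearray(256)
    for i, v in enumerate(sbox):
        inverse[v] = i
    return bytearray(inverse[b] for b in block)

def apply_shuffle(block: bytes, perm: bytes) -> bytearray:
    # Assumed forward direction: output byte i is taken from input byte perm[i]
    return bytearray(block[p] for p in perm)

def undo_shuffle(block: bytes, perm: bytes) -> bytearray:
    # Send each byte back to the position it was taken from
    out = bytearray(len(perm))
    for i, p in enumerate(perm):
        out[p] = block[i]
    return out

rng = random.Random(0)
sbox = bytes(rng.sample(range(256), 256))
perm = bytes(rng.sample(range(32), 32))
block = bytes(range(32))
assert undo_sbox(apply_sbox(block, sbox), sbox) == block
assert undo_shuffle(apply_shuffle(block, perm), perm) == block
```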

I now had a license file to check, but I couldn’t just let the validation code run: DLL loading was slow enough that the validation function took several minutes per DLL. Luckily, the only thing used in the decryption of the flag is the SHA256 hash of the license file, so I patched out the license check so that the hashing and decryption ran immediately. This finally gets us the flag:

Its_l1ke_10000_spooO0o0O0oOo0o0O0O0OoOoOOO00o0o0Ooons@flare-on.com
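When reproducing this, you can sanity-check the generated file before running the patched binary. The sketch below assumes one 34-byte record per validation iteration: a 2-byte little-endian idx followed by a 32-byte block, as written by the license-generation loop. This is consistent with the numbers from main(), since 10000 × 34 = 340000. split_records and license_digest are hypothetical helper names.

```python
import hashlib

RECORD_SIZE = 2 + 32                       # 2-byte idx + 32-byte transformed block
RECORD_COUNT = 10000                       # iterations the validation loop must reach
LICENSE_SIZE = RECORD_SIZE * RECORD_COUNT  # 340000, the length main() expects

def split_records(blob: bytes):
    """Yield (idx, block) for each 34-byte record in a generated license."""
    for off in range(0, len(blob), RECORD_SIZE):
        rec = blob[off:off + RECORD_SIZE]
        yield int.from_bytes(rec[:2], 'little'), rec[2:]

def license_digest(blob: bytes) -> str:
    """Check the overall length, then return the SHA256 used as the flag key."""
    if len(blob) != LICENSE_SIZE:
        raise ValueError(f'expected {LICENSE_SIZE} bytes, got {len(blob)}')
    return hashlib.sha256(blob).hexdigest()
```

Usage would be along the lines of license_digest(open('license.bin', 'rb').read()).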