Opened 4 hours ago

Last modified 2 hours ago

#36874 assigned Cleanup/optimization

Speed up mask/unmask cipher secret functions

Reported by: Tim Lansen
Owned by: Tim Lansen
Component: CSRF
Version: 6.0
Severity: Normal
Keywords: CSRF cipher token mask unmask
Cc:
Triage Stage: Unreviewed
Has patch: yes
Needs documentation: no
Needs tests: no
Patch needs improvement: no
Easy pickings: no
UI/UX: no

Description

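The functions _mask_cipher_secret and _unmask_cipher_token call chars.index() to find the index of every character.
Each call therefore performs 64 linear scans of CSRF_ALLOWED_CHARS (32 for the secret or cipher and 32 for the mask).
The idea is to build a translation table once and look up each character's index directly, as the x86 XLAT instruction does.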
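
A minimal sketch of the idea, assuming a plain list indexed by code point. The names make_index_table and INDEX_TABLE are hypothetical and used only for illustration; the benchmark script in comment:2 uses its own names.

```python
CSRF_ALLOWED_CHARS = (
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)


def make_index_table(chars):
    """Hypothetical helper: table[ord(c)] is the position of c in chars."""
    table = [0] * (max(map(ord, chars)) + 1)
    for i, c in enumerate(chars):
        table[ord(c)] = i
    return table


# Built once at import time, then reused for every lookup.
INDEX_TABLE = make_index_table(CSRF_ALLOWED_CHARS)

# Linear scan: str.index walks the string until it finds a match.
slow = [CSRF_ALLOWED_CHARS.index(c) for c in "aZ9"]
# Table lookup: one list access per character.
fast = [INDEX_TABLE[ord(c)] for c in "aZ9"]
```

Both lists come out as [0, 51, 61]. The table only gives meaningful answers for characters in CSRF_ALLOWED_CHARS, which matches the assumption already stated in the docstrings of the existing functions.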

Change History (2)

comment:1 by Tim Lansen, 3 hours ago

Has patch: set
Owner: set to Tim Lansen
Status: new → assigned

comment:2 by Tim Lansen, 2 hours ago

Benchmarking the approach with Python 3.12.7 on an ASUS Vivobook (Intel Core Ultra 9):

$ python csrf_cipher_benchmark.py
Execution time 1: 2.173560 (2.7169501781463623e-05 sec per mask+unmask)
Execution time 2: 1.631568 (2.0394599437713623e-05 sec per mask+unmask)
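
For reference, the script calls benchmark(200, 400), so each variant runs 200 × 400 = 80,000 mask+unmask rounds, and the per-round figures and the relative gain follow from the totals. A small sketch of that arithmetic (variable names are illustrative only):

```python
secrets_count, iterations = 200, 400  # arguments passed to benchmark()
t_index, t_xlat = 2.173560, 1.631568  # reported totals, in seconds

rounds = secrets_count * iterations
per_round_index = t_index / rounds  # chars.index() variant
per_round_xlat = t_xlat / rounds    # translation table variant
speedup = t_index / t_xlat
saving = 1 - t_xlat / t_index
print(f"{rounds} rounds, speedup {speedup:.2f}x, {saving:.1%} less time")
```

This prints "80000 rounds, speedup 1.33x, 24.9% less time". Both variants generate a fresh random mask on every round, so that cost is included in both totals.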

The code

import secrets


def get_random_string(length, allowed_chars):
    return "".join(secrets.choice(allowed_chars) for i in range(length))


CSRF_SECRET_LENGTH = 32
CSRF_ALLOWED_CHARS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'


def _get_new_csrf_string():
    return get_random_string(CSRF_SECRET_LENGTH, allowed_chars=CSRF_ALLOWED_CHARS)


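def _make_xlat(chars: str):
    # Build a list indexed by code point: xlat[ord(c)] is the position of c
    # in chars. Slots for code points not in chars keep the placeholder 0.
    xlat = [0] * (1 + max(ord(x) for x in chars))
    for i, c in enumerate(chars):
        xlat[ord(c)] = i
    return xlat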


CSRF_XLAT = _make_xlat(CSRF_ALLOWED_CHARS)


def _mask_cipher_secret(secret):
    """
    Given a secret (assumed to be a string of CSRF_ALLOWED_CHARS), generate a
    token by adding a mask and applying it to the secret.
    """
    mask = _get_new_csrf_string()
    chars = CSRF_ALLOWED_CHARS
    pairs = zip((chars.index(x) for x in secret), (chars.index(x) for x in mask))
    cipher = "".join(chars[(x + y) % len(chars)] for x, y in pairs)
    return mask + cipher


def _unmask_cipher_token(token):
    """
    Given a token (assumed to be a string of CSRF_ALLOWED_CHARS, of length
    CSRF_TOKEN_LENGTH, and that its first half is a mask), use it to decrypt
    the second half to produce the original secret.
    """
    mask = token[:CSRF_SECRET_LENGTH]
    token = token[CSRF_SECRET_LENGTH:]
    chars = CSRF_ALLOWED_CHARS
    pairs = zip((chars.index(x) for x in token), (chars.index(x) for x in mask))
    return "".join(chars[x - y] for x, y in pairs)  # Note negative values are ok


def _mask_cipher_secret_xlat(secret):
    """
    Given a secret (assumed to be a string of CSRF_ALLOWED_CHARS), generate a
    token by adding a mask and applying it to the secret.
    """
    mask = _get_new_csrf_string()
    chars = CSRF_ALLOWED_CHARS
    pairs = zip((CSRF_XLAT[ord(x)] for x in secret), (CSRF_XLAT[ord(x)] for x in mask))
    cipher = "".join(chars[(x + y) % len(chars)] for x, y in pairs)
    return mask + cipher


def _unmask_cipher_token_xlat(token):
    """
    Given a token (assumed to be a string of CSRF_ALLOWED_CHARS, of length
    CSRF_TOKEN_LENGTH, and that its first half is a mask), use it to decrypt
    the second half to produce the original secret.
    """
    mask = token[:CSRF_SECRET_LENGTH]
    token = token[CSRF_SECRET_LENGTH:]
    chars = CSRF_ALLOWED_CHARS
    pairs = zip((CSRF_XLAT[ord(x)] for x in token), (CSRF_XLAT[ord(x)] for x in mask))
    return "".join(chars[x - y] for x, y in pairs)  # Note negative values are ok


def benchmark(num_secrets: int, iterations: int):
    # The parameter is named num_secrets so it does not shadow the secrets module.
    import time

    d1, d2 = 0.0, 0.0
    for i in range(num_secrets):
        secret = _get_new_csrf_string()
        # Baseline: chars.index() lookups.
        t0 = time.perf_counter()
        for _ in range(iterations):
            token = _mask_cipher_secret(secret)
            secret = _unmask_cipher_token(token)
        t1 = time.perf_counter()
        # Candidate: translation table lookups.
        for _ in range(iterations):
            token = _mask_cipher_secret_xlat(secret)
            secret = _unmask_cipher_token_xlat(token)
        t2 = time.perf_counter()
        d1 += t1 - t0
        d2 += t2 - t1
    print(f'Execution time 1: {d1:.6f} ({d1 / num_secrets / iterations} sec per mask+unmask)')
    print(f'Execution time 2: {d2:.6f} ({d2 / num_secrets / iterations} sec per mask+unmask)')


if __name__ == '__main__':
    benchmark(200, 400)
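
On the "negative values are ok" comment in the unmask functions: x - y can be negative, but it always lies between -61 and 61, and Python's negative indexing makes chars[x - y] equal to chars[(x - y) % len(chars)] over that range. A short worked example, with the characters "b" and "9" picked arbitrarily for illustration:

```python
chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
n = len(chars)  # 62

secret_idx = chars.index("b")             # 1
mask_idx = chars.index("9")               # 61
cipher_idx = (secret_idx + mask_idx) % n  # (1 + 61) % 62 == 0, i.e. "a"

# Unmasking subtracts the mask index: 0 - 61 == -61, and chars[-61] is chars[1].
recovered = chars[cipher_idx - mask_idx]

# The same holds for every difference two valid indexes can produce.
wraps = all(chars[d] == chars[d % n] for d in range(-(n - 1), n))
```

Here recovered is "b" and wraps is True, which is why the unmask step needs no explicit modulo.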