#!/usr/bin/env python3
import argparse
import elftools.elf.elffile
import elftools.dwarf.descriptions
from collections import namedtuple
from struct import unpack

# Address-space bases used to filter DWARF variable locations: an address in
# [SRAM_OFFSET, EEPROM_OFFSET) is treated as SRAM and reported relative to
# SRAM_OFFSET. (Presumably the avr-gcc data/EEPROM section offsets — confirm.)
SRAM_OFFSET = 0x00800000
EEPROM_OFFSET = 0x00810000

# Padding byte for address ranges not covered by any line of a memory dump.
FILL_BYTE = b'\x00'


# One global variable: its name, SRAM-relative location and size in bytes.
Entry = namedtuple('Entry', 'name loc size')


# DWARF expression opcode DW_OP_addr: followed by a target address operand.
_DW_OP_ADDR = 0x03

# Attribute forms whose value is a raw DWARF expression (a list of byte values).
_LOCATION_BLOCK_FORMS = ('DW_FORM_block1', 'DW_FORM_block', 'DW_FORM_exprloc')

# Attribute forms that hold a plain compile-time constant.
_DWARF_CONST_FORMS = ('DW_FORM_data1', 'DW_FORM_data2', 'DW_FORM_data4',
                      'DW_FORM_data8', 'DW_FORM_sdata', 'DW_FORM_udata',
                      'DW_FORM_implicit_const')


def _static_address(at_loc):
    """Return the address of a DW_AT_location that starts with DW_OP_addr, else None.

    Location lists and any other expression shape yield None.
    """
    if at_loc.form not in _LOCATION_BLOCK_FORMS:
        return None
    expr = at_loc.value
    # DW_OP_addr followed by a 4-byte little-endian address operand
    if len(expr) < 5 or expr[0] != _DW_OP_ADDR:
        return None
    return int.from_bytes(bytes(expr[1:5]), 'little')


def _const_attr(DIE, name):
    """Return attribute `name` of DIE if present as a constant, else None."""
    attr = DIE.attributes.get(name)
    if attr is None or attr.form not in _DWARF_CONST_FORMS:
        return None
    return attr.value


def _dwarf_array_count(array_DIE):
    """Return the total element count of an array type DIE, or None if unknown."""
    count = None
    for child in array_DIE.iter_children():
        if child.tag != 'DW_TAG_subrange_type':
            continue
        length = _const_attr(child, 'DW_AT_count')
        if length is None:
            upper = _const_attr(child, 'DW_AT_upper_bound')
            if upper is None:
                return None  # unbounded or non-constant dimension
            lower = _const_attr(child, 'DW_AT_lower_bound') or 0
            length = upper - lower + 1
        count = length if count is None else count * length
    return count


def _dwarf_type_size(type_DIE):
    """Return the storage size in bytes of a DWARF type DIE, or None if unknown."""
    while True:
        attrs = type_DIE.attributes
        if 'DW_AT_byte_size' in attrs:
            # The outermost sized type wins: e.g. a pointer type carries its own
            # size, which must not be replaced by the size of its pointee.
            return attrs['DW_AT_byte_size'].value
        if 'DW_AT_type' not in attrs:
            return None
        if type_DIE.tag == 'DW_TAG_array_type':
            count = _dwarf_array_count(type_DIE)
            elem_size = _dwarf_type_size(type_DIE.get_DIE_from_attribute('DW_AT_type'))
            if count is None or elem_size is None:
                return None
            return count * elem_size
        # unsized wrapper (typedef, qualifier, ...): look at the underlying type
        type_DIE = type_DIE.get_DIE_from_attribute('DW_AT_type')


def _global_entry(DIE):
    """Return an Entry for a DW_OP_addr-located SRAM variable DIE, else None."""
    if DIE.tag != 'DW_TAG_variable' or 'DW_AT_location' not in DIE.attributes:
        return None

    # A definition of a variable declared elsewhere (GCC typically does this for
    # `extern` globals and C++ static members) may carry only
    # DW_AT_specification; take its name/type from the referenced declaration.
    decl = DIE
    if 'DW_AT_specification' in DIE.attributes:
        decl = DIE.get_DIE_from_attribute('DW_AT_specification')
    name_DIE = DIE if 'DW_AT_name' in DIE.attributes else decl
    type_owner = DIE if 'DW_AT_type' in DIE.attributes else decl
    if 'DW_AT_name' not in name_DIE.attributes or 'DW_AT_type' not in type_owner.attributes:
        return None

    loc = _static_address(DIE.attributes['DW_AT_location'])
    if loc is None or not SRAM_OFFSET <= loc < EEPROM_OFFSET:
        return None

    size = _dwarf_type_size(type_owner.get_DIE_from_attribute('DW_AT_type'))
    if size is None:
        return None

    name = name_DIE.attributes['DW_AT_name'].value.decode('ascii')
    return Entry(name, loc - SRAM_OFFSET, size)


def get_elf_globals(path):
    """Return the SRAM global variables described by the DWARF info of an ELF file.

    Only variables whose location is a single DW_OP_addr inside
    [SRAM_OFFSET, EEPROM_OFFSET) and whose type size can be determined are
    reported.

    path: path of the ELF file.
    Returns a list of Entry(name, loc, size) in DWARF traversal order, with loc
    relative to SRAM_OFFSET; an empty list if the file has no DWARF info.
    Raises OSError if the file cannot be opened; pyelftools parse errors
    propagate.
    """
    grefs = []
    # pyelftools reads the stream lazily, so keep the file open for the whole walk
    with open(path, "rb") as fd:
        elffile = elftools.elf.elffile.ELFFile(fd)
        if not elffile.has_dwarf_info():
            return grefs

        # probably not needed, since location expressions are decoded manually
        elftools.dwarf.descriptions.set_global_machine_arch(elffile.get_machine_arch())

        for CU in elffile.get_dwarf_info().iter_CUs():
            for DIE in CU.iter_DIEs():
                entry = _global_entry(DIE)
                if entry is not None:
                    grefs.append(entry)
    return grefs


def decode_dump(path):
    """Parse a textual memory dump (as obtained from the D2 g-code) into one buffer.

    Each data line is ``<hex address> <hex bytes...>``. All lines are merged into
    a single contiguous buffer: bytes not covered by any line are FILL_BYTE, and
    later lines overwrite earlier ones where they overlap. Parsing stops at the
    first blank line or at a line whose first token is ``ok``; lines whose first
    token is ``D2`` or that have a single token are skipped.

    Returns (start_address, data_bytes), or (None, None) if no data line was found.
    Raises ValueError on malformed hex.
    """
    buf_addr = None     # address of buf[0]
    buf = bytearray()   # mutable accumulator: avoids re-copying on every line

    with open(path, 'r') as fd:
        for line in fd:
            tokens = line.strip().split(maxsplit=1)
            if not tokens or tokens[0] == 'ok':
                break
            if len(tokens) < 2 or tokens[0] == 'D2':
                continue

            addr = int(tokens[0], 16)
            data = bytes.fromhex(tokens[1])

            if buf_addr is None:
                buf_addr = addr
            elif addr < buf_addr:
                # grow at the front while keeping the data already collected
                buf[0:0] = FILL_BYTE * (buf_addr - addr)
                buf_addr = addr

            start = addr - buf_addr
            end = start + len(data)
            if end > len(buf):
                # grow at the back
                buf.extend(FILL_BYTE * (end - len(buf)))
            buf[start:end] = data

    if buf_addr is None:
        return (None, None)
    return (buf_addr, bytes(buf))


def annotate_refs(grefs, addr, data, width=45, gaps=True, byteorder='little'):
    """Print every entry of `grefs` that lies entirely inside a memory dump.

    grefs: entries with .name, .loc and .size, sorted by .loc; .loc must be in
        the same address space as `addr`.
    addr, data: start address and bytes of the dump (as returned by
        decode_dump); nothing is printed if either is None.
    width: minimum width of the name column.
    gaps: also print the bytes between consecutive printed entries as *UNKNOWN*.
    byteorder: 'little' or 'big', the byte order of the dumped target memory.

    Each line shows the address, name and size, then an unsigned integer view
    (I:) for 1/2/4-byte values, a float view (F:) for 4/8-byte values, and the
    raw bytes in hex (R:).
    """
    if addr is None or data is None:
        return
    struct_prefix = '<' if byteorder == 'little' else '>'

    last_end = None
    for entry in grefs:
        pos = entry.loc - addr
        end_pos = pos + entry.size
        if pos < 0 or end_pos > len(data):
            continue  # not fully covered by the dump
        buf = data[pos:end_pos]

        buf_repr = ''
        if len(buf) in (1, 2, 4):
            # attempt to decode as integers
            buf_repr += ' I:' + str(int.from_bytes(buf, byteorder)).rjust(10)
        if len(buf) in (4, 8):
            # attempt to decode as floats, in the same byte order as the integers
            typ = struct_prefix + ('f' if len(buf) == 4 else 'd')
            buf_repr += ' F:' + '{:10.3f}'.format(unpack(typ, buf)[0])

        if gaps and last_end is not None and last_end < pos:
            # decode gaps
            gap_buf = data[last_end:pos]
            print('{:04x} {} {:4} R:{}'.format(addr + last_end, "*UNKNOWN*".ljust(width),
                                               len(gap_buf), gap_buf.hex()))

        print('{:04x} {} {:4}{} R:{}'.format(entry.loc, entry.name.ljust(width),
                                             entry.size, buf_repr, buf.hex()))
        # if entries overlap, never move the end backwards: covered bytes are not gaps
        last_end = end_pos if last_end is None else max(last_end, end_pos)


def print_map(grefs):
    """Print a tab-separated table of symbols: hex offset, size and name.

    A header line is always printed, followed by one line per entry in the
    order given.
    """
    print('OFFSET\tSIZE\tNAME')
    for sym in grefs:
        offset = format(sym.loc, 'x')
        print('{}\t{}\t{}'.format(offset, sym.size, sym.name))


def main():
    """Command-line entry point; returns the process exit status.

    Prints the global memory map (--map), or annotates each global found in
    the given D2 RAM dump with its value.
    """
    ap = argparse.ArgumentParser(description="""
        Generate a symbol table map starting directly from an ELF
        firmware with DWARF2 debugging information.
        When used along with a memory dump obtained from the D2 g-code,
        show the value of each symbol which is within the address range.
    """)
    ap.add_argument('elf', help='ELF file containing DWARF2 debugging information')
    g = ap.add_mutually_exclusive_group(required=True)
    g.add_argument('dump', nargs='?', help='RAM dump obtained from D2 g-code')
    g.add_argument('--map', action='store_true', help='dump global memory map')
    args = ap.parse_args()

    # Treat a None result (e.g. an ELF without DWARF info) as an empty symbol list.
    grefs = sorted(get_elf_globals(args.elf) or [], key=lambda entry: entry.loc)
    if args.dump is None:
        print_map(grefs)
        return 0

    addr, data = decode_dump(args.dump)
    if addr is None:
        # Nothing to annotate: report it instead of failing on a None address.
        ap.exit(1, 'error: no memory data found in {}\n'.format(args.dump))
    annotate_refs(grefs, addr, data)
    return 0

if __name__ == '__main__':
    # The `exit` builtin is added by the site module for interactive use and is
    # missing under `python -S`; SystemExit is always available.
    raise SystemExit(main())
