#!/usr/bin/env python3
import argparse
import itertools
import sys
from collections import namedtuple
from struct import unpack

import elftools.dwarf.descriptions
import elftools.elf.elffile

# Debug-info addresses in [SRAM_OFFSET, EEPROM_OFFSET) are treated as SRAM;
# SRAM_OFFSET is subtracted so reported locations are SRAM-relative.
SRAM_OFFSET = 0x00800000
# End of the SRAM address window: locations at or above it are ignored.
EEPROM_OFFSET = 0x00810000
# Byte used to pad holes between the chunks of a memory dump.
FILL_BYTE = b'\x00'


# A global symbol: display name, SRAM-relative location, size in bytes.
Entry = namedtuple('Entry', 'name loc size')
# A structure member: display name, offset from the structure start, size in bytes.
Member = namedtuple('Member', 'name off size')


def array_inc(loc, dim, idx=0):
    """Advance the multi-dimensional index *loc* by one step, in place.

    Position *idx* is incremented first; whenever a position reaches its
    extent in *dim* it is reset to 0 and the carry moves to the next
    position (so index 0 varies fastest by default).

    :param loc: mutable list of current indices, same length as *dim*
    :param dim: extent of each dimension
    :param idx: first position to increment
    :return: True once every position from *idx* onwards has wrapped
             around (iteration finished), False otherwise.
    """
    for pos in range(idx, len(dim)):
        loc[pos] += 1
        if loc[pos] != dim[pos]:
            return False
        loc[pos] = 0  # wrap and carry into the next position
    return True

def get_type_size(type_DIE):
    """Find the storage size of *type_DIE* by following its DW_AT_type chain.

    The DIE itself is examined first, then each DIE referenced through
    DW_AT_type, until one carrying DW_AT_byte_size is found.

    :return: ``(sized_DIE, byte_size)`` for the first such DIE, or None
             when the chain ends without any DW_AT_byte_size.
    """
    node = type_DIE
    while 'DW_AT_byte_size' not in node.attributes:
        if 'DW_AT_type' not in node.attributes:
            return None
        node = node.get_DIE_from_attribute('DW_AT_type')
    return node, node.attributes['DW_AT_byte_size'].value

def get_type_arrsize(type_DIE):
    """Return the total storage size in bytes of a (possibly array) type, or None.

    Follows the DW_AT_type chain from *type_DIE* to the first DIE carrying
    DW_AT_byte_size, and multiplies that size by the extent of every array
    dimension crossed before reaching it.

    Arrays reachable only *past* the sized DIE (e.g. behind a pointer) are
    not counted. So a pointer to an array yields the pointer size, while an
    array of pointers yields the size of all of its elements.

    A subrange contributes ``DW_AT_upper_bound + 1`` or, failing that,
    ``DW_AT_count``; subranges with neither are ignored.

    Returns None when the chain ends without any DW_AT_byte_size.
    """
    count = 1
    die = type_DIE
    while 'DW_AT_byte_size' not in die.attributes:
        if die.tag == 'DW_TAG_array_type':
            # every array level crossed so far multiplies the element count
            for range_DIE in die.iter_children():
                if range_DIE.tag != 'DW_TAG_subrange_type':
                    continue
                bounds = range_DIE.attributes
                if 'DW_AT_upper_bound' in bounds:
                    count *= bounds['DW_AT_upper_bound'].value + 1
                elif 'DW_AT_count' in bounds:
                    count *= bounds['DW_AT_count'].value
        if 'DW_AT_type' not in die.attributes:
            return None
        die = die.get_DIE_from_attribute('DW_AT_type')
    return die.attributes['DW_AT_byte_size'].value * count

def get_type_def(type_DIE, type_tag):
    """Return the first DIE tagged *type_tag* along *type_DIE*'s DW_AT_type chain.

    The starting DIE itself is checked first. Returns None when the chain
    ends (a DIE without DW_AT_type) before a matching tag is found.
    """
    node = type_DIE
    while node.tag != type_tag:
        if 'DW_AT_type' not in node.attributes:
            return None
        node = node.get_DIE_from_attribute('DW_AT_type')
    return node

def get_FORM_block1(attr):
    """Decode a single-operation DW_FORM_block1 expression into an integer.

    Two opcodes are understood:

    * DW_OP_addr (0x03): all remaining bytes are a little-endian address.
    * DW_OP_plus_uconst (0x23): the operand is an unsigned LEB128 value.

    :param attr: attribute exposing ``form`` and ``value`` (the expression
                 bytes as a sequence of ints)
    :return: the decoded integer, or None if the form is not DW_FORM_block1,
             the expression is empty, or it starts with another opcode.
    """
    if attr.form != 'DW_FORM_block1':
        return None
    expr = attr.value
    if not expr:
        # DWARF allows an empty location expression; it carries no address
        return None
    opcode, operand = expr[0], expr[1:]
    if opcode == 0x03:  # DW_OP_addr
        return int.from_bytes(operand, 'little')
    if opcode == 0x23:  # DW_OP_plus_uconst (ULEB128)
        value = 0
        shift = 0
        for byte in operand:
            value |= (byte & 0x7f) << shift
            if byte & 0x80 == 0:  # high bit clear: last LEB128 byte
                break
            shift += 7
        return value
    return None

def _subrange_dims(array_DIE):
    """Return the extent of each bounded dimension of a DW_TAG_array_type DIE.

    Dimensions are listed in the order of the subrange children. A subrange
    contributes ``DW_AT_upper_bound + 1`` or, failing that, ``DW_AT_count``;
    subranges with neither are skipped.
    """
    dims = []
    for range_DIE in array_DIE.iter_children():
        if range_DIE.tag != 'DW_TAG_subrange_type':
            continue
        attrs = range_DIE.attributes
        if 'DW_AT_upper_bound' in attrs:
            dims.append(attrs['DW_AT_upper_bound'].value + 1)
        elif 'DW_AT_count' in attrs:
            dims.append(attrs['DW_AT_count'].value)
    return dims


def _type_layout(DIE):
    """Resolve the storage layout of a variable or member DIE.

    Follows the DW_AT_type chain to the first DIE carrying DW_AT_byte_size,
    collecting the dimensions of every array type crossed on the way
    (outermost first). Stopping at the first sized DIE means a pointer to an
    array is a single pointer, while an array of pointers keeps its dimensions.

    :return: ``(sized_DIE, element_byte_size, dims)``, or None when the chain
             ends without any size information.
    """
    dims = []
    while True:
        attrs = DIE.attributes
        if 'DW_AT_byte_size' in attrs:
            return DIE, attrs['DW_AT_byte_size'].value, dims
        if DIE.tag == 'DW_TAG_array_type':
            dims.extend(_subrange_dims(DIE))
        if 'DW_AT_type' not in attrs:
            return None
        DIE = DIE.get_DIE_from_attribute('DW_AT_type')


def _member_offset(member_DIE):
    """Return the byte offset of a structure member, or None if it cannot be decoded."""
    attrs = member_DIE.attributes
    if 'DW_AT_data_member_location' in attrs:
        attr = attrs['DW_AT_data_member_location']
        if isinstance(attr.value, int):
            # DWARF 3+ may store the offset as a plain constant
            return attr.value
        return get_FORM_block1(attr)
    if 'DW_AT_data_bit_offset' in attrs:
        return attrs['DW_AT_data_bit_offset'].value // 8
    # newer DWARF versions may omit the location of a member at offset 0
    return 0


def _array_elements(name, base, elem_size, dims):
    """Split a (possibly multi-dimensional) array into addressable elements.

    Byte arrays are treated as likely strings:

    * the innermost dimension of a multi-dimensional byte array is folded
      into the element size;
    * a one-dimensional byte array is kept as a single ``name[]`` blob.

    Other arrays are expanded in C row-major order (last index varies fastest).

    :return: ``(elements, is_blob)`` where *elements* is a list of
             ``(name, offset, size)`` tuples and *is_blob* marks the string case.
    """
    dims = list(dims)
    if elem_size == 1 and len(dims) > 1:
        elem_size *= dims.pop()
    if not dims or dims == [1]:
        # plain entry
        return [(name, base, elem_size)], False
    if len(dims) == 1 and elem_size == 1:
        # likely string, avoid expansion
        return [(name + '[]', base, dims[0])], True
    elements = []
    indices = itertools.product(*[range(d) for d in dims])
    for n, index in enumerate(indices):
        suffix = ''.join('[{}]'.format(i) for i in index)
        elements.append((name + suffix, base + n * elem_size, elem_size))
    return elements, False


def _fill_struct_gaps(members, struct_size):
    """Return *members* sorted by offset, with '*UNKNOWN*' fillers for uncovered bytes."""
    filled = []
    covered_end = 0
    for member in sorted(members, key=lambda m: m.off):
        if member.off > covered_end:
            filled.append(Member('*UNKNOWN*', covered_end, member.off - covered_end))
        filled.append(member)
        # max(): an overlapping member must not shrink the covered range
        covered_end = max(covered_end, member.off + member.size)
    if struct_size > covered_end:
        filled.append(Member('*UNKNOWN*', covered_end, struct_size - covered_end))
    return filled


def _struct_members(struct_DIE, struct_size, struct_gaps):
    """Return the Member list of a structure type (one level only, arrays expanded)."""
    members = []
    for member_DIE in struct_DIE.iter_children():
        if member_DIE.tag != 'DW_TAG_member' or 'DW_AT_name' not in member_DIE.attributes:
            continue
        layout = _type_layout(member_DIE)
        m_off = _member_offset(member_DIE)
        if layout is None or m_off is None:
            # size or offset cannot be decoded: the bytes end up in a gap filler
            continue
        m_name = member_DIE.attributes['DW_AT_name'].value.decode('ascii')
        elements, _ = _array_elements(m_name, m_off, layout[1], layout[2])
        members.extend(Member(*element) for element in elements)
    if struct_gaps and members:
        members = _fill_struct_gaps(members, struct_size)
    return members


def _variable_entries(DIE, expand_structs, struct_gaps):
    """Return the Entry list for one DIE, or [] if it is not an SRAM global variable."""
    attrs = DIE.attributes
    if DIE.tag != 'DW_TAG_variable' or 'DW_AT_location' not in attrs:
        return []
    if 'DW_AT_name' not in attrs and 'DW_AT_abstract_origin' not in attrs:
        return []

    # handle locations encoded directly as DW_OP_addr (leaf globals)
    loc = get_FORM_block1(attrs['DW_AT_location'])
    if loc is None or not SRAM_OFFSET <= loc < EEPROM_OFFSET:
        return []
    loc -= SRAM_OFFSET

    if 'DW_AT_name' not in attrs:
        # name/type live on the abstract origin (present, checked above)
        DIE = DIE.get_DIE_from_attribute('DW_AT_abstract_origin')
        attrs = DIE.attributes
        if 'DW_AT_location' in attrs:
            # duplicate reference (handled directly), skip
            return []
    if 'DW_AT_name' not in attrs or 'DW_AT_type' not in attrs:
        return []
    name = attrs['DW_AT_name'].value.decode('ascii')

    layout = _type_layout(DIE)
    if layout is None:
        return []
    sized_DIE, byte_size, dims = layout

    members = []
    if expand_structs and sized_DIE.tag == 'DW_TAG_structure_type':
        members = _struct_members(sized_DIE, byte_size, struct_gaps)

    elements, is_blob = _array_elements(name, loc, byte_size, dims)
    if is_blob or not members:
        return [Entry(*element) for element in elements]
    return [Entry(e_name + '.' + m.name, e_loc + m.off, m.size)
            for e_name, e_loc, _ in elements for m in members]


def get_elf_globals(path, expand_structs, struct_gaps=True):
    """Collect the SRAM global variables described by an ELF file's DWARF info.

    :param path: path of the ELF file
    :param expand_structs: split structure variables into one entry per member
    :param struct_gaps: when expanding structures, add '*UNKNOWN*' entries for
                        bytes not covered by any named member
    :return: list of Entry(name, loc, size) with ``loc`` relative to
             SRAM_OFFSET, in DWARF order (not sorted); None when the file
             has no DWARF information.
    """
    grefs = []
    with open(path, 'rb') as fd:
        elffile = elftools.elf.elffile.ELFFile(fd)
        if not elffile.has_dwarf_info():
            return None

        # probably not needed, since we're decoding expressions manually
        elftools.dwarf.descriptions.set_global_machine_arch(elffile.get_machine_arch())
        dwarfinfo = elffile.get_dwarf_info()

        # ELFFile reads from the open stream, so iterate inside the with-block
        for CU in dwarfinfo.iter_CUs():
            for DIE in CU.iter_DIEs():
                grefs.extend(_variable_entries(DIE, expand_structs, struct_gaps))
    return grefs


def decode_dump(path):
    """Parse a text memory dump as produced by the D2 g-code.

    Each data line is ``<hex address> <hex bytes>``. Lines whose first token
    is ``D2`` or that carry no data are skipped. Reading stops at a blank
    line or at a line starting with ``ok``.

    Chunks may come in any order and may overlap (later lines win). Holes
    between chunks are padded with FILL_BYTE.

    :return: ``(start_address, data)``, or ``(None, None)`` if the dump
             contained no data lines.
    """
    buf_addr = None  # address of buf[0]
    buf = bytearray()

    with open(path, 'r') as fd:
        for line in fd:
            tokens = line.split(maxsplit=1)
            if not tokens or tokens[0] == 'ok':
                break
            if len(tokens) < 2 or tokens[0] == 'D2':
                continue

            addr = int.from_bytes(bytes.fromhex(tokens[0]), 'big')
            data = bytes.fromhex(tokens[1])

            if buf_addr is None:
                buf_addr = addr
            elif addr < buf_addr:
                # grow at the front while keeping the bytes decoded so far
                buf[:0] = FILL_BYTE * (buf_addr - addr)
                buf_addr = addr

            start = addr - buf_addr
            end = start + len(data)
            if end > len(buf):
                # grow at the back, padding any hole
                buf.extend(FILL_BYTE * (end - len(buf)))
            buf[start:end] = data

    if buf_addr is None:
        return (None, None)
    return (buf_addr, bytes(buf))


def annotate_refs(grefs, addr, data, width=46, gaps=True, overlaps=True):
    """Print every symbol fully contained in a memory dump, with its contents.

    :param grefs: iterable of entries with ``name``, ``loc`` and ``size``.
                  Gap/overlap detection compares consecutive entries, so it
                  is meant to be sorted by ``loc``.
    :param addr: address of the first byte of *data*
    :param data: dump contents (bytes)
    :param width: column width of the name field
    :param gaps: also print bytes not covered by any entry as '*UNKNOWN*'
    :param overlaps: print an '*OVERLAP*' line when an entry starts more than
                     one byte before the end of the memory already covered

    Entries of 1, 2 or 4 bytes are also shown as unsigned little-endian
    integers (``I:``). Entries of 4 or 8 bytes are also shown as little-endian
    float/double (``F:``). ``R:`` is always the raw hex.
    """
    covered_end = None  # end (relative to addr) of the bytes covered so far
    for entry in grefs:
        if entry.loc < addr or entry.loc + entry.size > addr + len(data):
            # entry not fully inside the dump
            continue

        pos = entry.loc - addr
        end_pos = pos + entry.size
        buf = data[pos:end_pos]

        buf_repr = ''
        if len(buf) in (1, 2, 4):
            # attempt to decode as integers
            buf_repr += ' I:' + str(int.from_bytes(buf, 'little')).rjust(10)
        if len(buf) in (4, 8):
            # explicit '<': decode little-endian like the integers above,
            # independent of the host byte order
            fmt = '<f' if len(buf) == 4 else '<d'
            buf_repr += ' F:' + '{:10.3f}'.format(unpack(fmt, buf)[0])

        if covered_end is not None:
            if gaps and covered_end < pos:
                # decode gaps
                gap_size = pos - covered_end
                gap_buf = data[covered_end:pos]
                print('{:04x} {} {:4} R:{}'.format(addr+covered_end, "*UNKNOWN*".ljust(width),
                                                   gap_size, gap_buf.hex()))
            if overlaps and covered_end > pos + 1:
                # the printed size is negative: how far this entry reaches
                # back into memory that is already covered
                gap_size = pos - covered_end
                print('{:04x} {} {:4}'.format(addr+covered_end, "*OVERLAP*".ljust(width), gap_size))

        print('{:04x} {} {:4}{} R:{}'.format(entry.loc, entry.name.ljust(width),
                                             entry.size, buf_repr, buf.hex()))
        # keep the furthest end seen: an entry nested inside a larger one
        # must not make the rest of the larger one look like a gap
        covered_end = end_pos if covered_end is None else max(covered_end, end_pos)


def print_map(grefs):
    """Print the symbol table as tab-separated columns.

    The columns are OFFSET (hex, no prefix), SIZE (decimal) and NAME, one
    line per entry in the given order, preceded by a header line.
    """
    rows = ['OFFSET\tSIZE\tNAME']
    rows.extend('{:x}\t{}\t{}'.format(item.loc, item.size, item.name) for item in grefs)
    print('\n'.join(rows))


def main():
    """Command-line entry point: print the symbol map or annotate a RAM dump.

    Input problems (ELF without DWARF info, dump without data) are reported
    through ``ArgumentParser.error``, which exits with status 2.
    """
    ap = argparse.ArgumentParser(description="""
        Generate a symbol table map starting directly from an ELF
        firmware with DWARF2 debugging information.
        When used along with a memory dump obtained from the D2 g-code,
        show the value of each symbol which is within the address range.
    """)
    ap.add_argument('elf', help='ELF file containing DWARF2 debugging information')
    ap.add_argument('--no-gaps', action='store_true',
                    help='do not dump memory inbetween known symbols')
    ap.add_argument('--no-expand-structs', action='store_true',
                    help='do not decode structure data')
    ap.add_argument('--overlaps', action='store_true',
                    help='annotate overlaps greater than 1 byte')
    ap.add_argument('--name-width', type=int, default=46,
                    help='set name column width')
    g = ap.add_mutually_exclusive_group(required=True)
    g.add_argument('dump', nargs='?', help='RAM dump obtained from D2 g-code')
    g.add_argument('--map', action='store_true', help='dump global memory map')
    args = ap.parse_args()

    grefs = get_elf_globals(args.elf, expand_structs=not args.no_expand_structs)
    if grefs is None:
        # get_elf_globals returns None when the ELF has no DWARF information
        ap.error('{}: no DWARF debugging information'.format(args.elf))
    grefs = sorted(grefs, key=lambda x: x.loc)

    if args.dump is None:
        print_map(grefs)
    else:
        addr, data = decode_dump(args.dump)
        if addr is None:
            ap.error('{}: no memory data found in dump'.format(args.dump))
        annotate_refs(grefs, addr, data,
                      width=args.name_width,
                      gaps=not args.no_gaps,
                      overlaps=args.overlaps)

if __name__ == '__main__':
    # sys.exit rather than the site-provided exit(), which is absent under `python -S`
    sys.exit(main())
