import json
import binascii
from os import path
from collections import namedtuple


def accept_file(file, name):
    """Decide whether this loader can handle the given input file.

    The input itself is not inspected.  Instead a JSON sidecar named
    ``name + '.json'`` is read, and the input is accepted when that sidecar
    contains at least one record whose ``type`` is 0x2C.

    :param file: input file object supplied by the caller (unused).
    :param name: path of the input file; the sidecar path is derived from it.
    :return: ``0`` when the file is not accepted (no sidecar, unreadable or
        malformed sidecar, no 0x2C record).  Otherwise the format string
        ``'8051 OMF library (N modules) # <sidecar path>'``; ``load_file``
        recovers the sidecar path from the text after ``' # '``.
    """
    metafile = name + '.json'
    if not path.isfile(metafile):
        return 0

    try:
        # Use a context manager so the handle is closed even on a parse error.
        with open(metafile) as handle:
            fmt = json.load(handle)
    except (IOError, OSError, ValueError):
        # Unreadable file or invalid JSON: simply not ours.
        return 0

    try:
        libraries = [s for s in fmt if s['type'] == 0x2C]
        if not libraries:
            return 0
        module_count = libraries[0]['contents']['module_count']
    except (TypeError, KeyError):
        # Valid JSON but not shaped like a record list; a probe must decline
        # rather than raise.
        return 0

    return '8051 OMF library ({} modules) # {}'.format(module_count, metafile)


# Working state for one module while the metadata records are read:
#   module    - contents of the type 0x2 record that opens the module
#   segments  - segment id -> IDASegment
#   externals - (block_id, ext_id) -> external definition entry
IDAModule = namedtuple('IDAModule', 'module segments externals')

# Working state for one segment of a module:
#   segment  - segment definition entry ('name' and 'size' are read from it)
#   symbols  - symbol entries located in this segment
#   contents - bytearray accumulating the segment's bytes
#   fixups   - fixup entries to apply once addresses are known
IDASegment = namedtuple('IDASegment', 'segment symbols contents fixups')


# Segment-name prefixes this loader understands and the kind each maps to.
_SEGMENT_PREFIXES = (
    ('?XD?', 'XTRN'),
    ('?PR?', 'CODE'),
    ('?DT?', 'DATA'),
)


def _segment_kind(name):
    """Return the kind for a segment name, or None for an unknown prefix."""
    for prefix, kind in _SEGMENT_PREFIXES:
        if name.startswith(prefix):
            return kind
    return None


def place_segments(offset, modules, type, offsets, externals):
    """Lay out all segments of one kind back to back, starting at ``offset``.

    For every segment whose name prefix maps to ``type`` the segment bytes
    (zero-padded up to the declared size) are copied into the database, the
    segment's symbols are named, and the chosen address is recorded.  If at
    least one segment was placed, a single IDA segment covering the whole
    range is created afterwards.

    Segments whose name has no recognised prefix are never placed and get no
    entry in ``offsets``.

    :param offset: first address to use.
    :param modules: iterable of module objects exposing ``module`` (dict with
        ``'name'``) and ``segments`` (segment id -> segment object).
    :param type: kind to place: ``'CODE'``, ``'DATA'`` or ``'XTRN'``.  Also
        used as both name and class of the created IDA segment.
    :param offsets: updated in place: module name -> {segment id -> address}.
    :param externals: updated in place: symbol name -> address.
    :return: the first address after the last placed segment.
    """
    start = offset

    for module in modules:
        seg_offsets = offsets.setdefault(module.module['name'], {})

        for seg_id, seg in module.segments.items():
            # Classify every segment afresh.  Reusing the previous
            # iteration's result would give an unrecognised segment the kind
            # of whichever segment preceded it (or fail outright when it
            # comes first).
            seg_type = _segment_kind(seg.segment['name'])
            if seg_type is None or seg_type != type:
                continue
            seg_offsets[seg_id] = offset

            # Pad with zeros when fewer bytes were supplied than declared.
            padding = max(0, seg.segment['size'] - len(seg.contents))
            contents = seg.contents + b'\x00' * padding
            size = len(contents)
            idaapi.mem2base(bytes(contents), offset)

            if seg_type == 'CODE':
                idaapi.create_insn(offset)
            elif seg_type == 'DATA':
                # create_data takes (ea, dataflag, size, tid); flag 0 with no
                # type id makes a plain byte array.
                idaapi.create_data(offset, 0, size, idaapi.BADNODE)

            for sym in seg.symbols:
                address = offset + sym['offset']
                idaapi.set_name(address, sym['name'].encode('ascii'))
                externals[sym['name']] = address

            offset += size

    if start != offset:
        idaapi.add_segm(0, start, offset, type, type)
    return offset

def _fixup_target(module, fixup, seg_offsets, externals):
    """Return the base address a fixup refers to, or None if it is unknown."""
    operand_type = fixup['operand_type']
    if operand_type in ('0', '1'):
        # Operand is a segment of the same module.
        return seg_offsets.get(fixup['operand_id'])
    if operand_type == '2':
        # Operand is an external, keyed by (block id 2, external id).
        ext = module.externals.get((2, fixup['operand_id']))
        if ext is None:
            return None
        return externals.get(ext['name'])
    return None


def resolve_symbols(modules, offsets, externals):
    """Round 2: patch every fixup now that all segments have addresses.

    For each fixup the target address is ``base + operand_offset`` (or, for
    fixup type ``'7'``, ``(base - 0x20) * 8 + operand_offset``), where
    ``base`` is the address of the referenced segment or external.  What is
    written at the fixup location depends on the fixup type:

    * ``'0'``, ``'1'``, ``'6'``, ``'7'``: low byte of the address.
    * ``'2'``: low byte of ``address - location``.
    * ``'3'``: high byte of the address.
    * ``'4'``: high byte, then low byte.
    * ``'5'``: address bits 8..10 into the top three bits of the first byte,
      low byte into the second, and bits 3..7 merged into the third.

    Fixups whose target cannot be determined (unknown operand type, segment
    that was never placed, unknown external) are reported and skipped; they
    are not patched with a leftover value from an earlier fixup.

    :param modules: iterable of module objects exposing ``module``,
        ``segments`` and ``externals``.
    :param offsets: module name -> {segment id -> address}.
    :param externals: symbol name -> address.
    """
    for module in modules:
        seg_offsets = offsets.get(module.module['name'], {})
        for seg_id, seg in module.segments.items():
            if seg_id not in seg_offsets:
                # Segment was never placed, so there is nothing to patch.
                continue
            seg_base = seg_offsets[seg_id]

            for fixup in seg.fixups:
                ref = seg_base + fixup['offset']
                base = _fixup_target(module, fixup, seg_offsets, externals)
                if base is None:
                    print('8051 OMF loader: skipping unresolved fixup at {:#x} '
                          'in module {}'.format(ref, module.module['name']))
                    continue

                if fixup['type'] == '7':
                    addr = (base - 0x20) * 8 + fixup['operand_offset']
                else:
                    addr = base + fixup['operand_offset']

                # Mask every stored value to a byte so out-of-range
                # intermediate results cannot reach put_byte.
                low = addr & 0xff
                high = (addr >> 8) & 0xff

                if fixup['type'] in ('0', '1', '6', '7'):
                    idaapi.put_byte(ref, low)
                elif fixup['type'] == '2':
                    idaapi.put_byte(ref, (addr - ref) & 0xff)
                elif fixup['type'] == '3':
                    idaapi.put_byte(ref, high)
                elif fixup['type'] == '4':
                    idaapi.put_byte(ref, high)
                    idaapi.put_byte(ref + 1, low)
                elif fixup['type'] == '5':
                    orig = idaapi.get_byte(ref)
                    idaapi.put_byte(ref, (orig & ~0b11100000) | (((addr >> 8) & 0b111) << 5))
                    idaapi.put_byte(ref + 1, low)
                    # NOTE(review): the third-byte merge is kept exactly as
                    # it was; confirm it against the fixup specification.
                    orig = idaapi.get_byte(ref + 2)
                    idaapi.put_byte(ref + 2, (orig & ~0b11111000) | (addr & 0b11111000))

# Record types that belong to a module and therefore need a preceding
# module header (type 0x2).
_MODULE_BODY_RECORDS = frozenset((0x18, 0xE, 0xF, 0x16, 0x17, 0x6, 0x7, 0x8, 0x9))


def _collect_modules(records):
    """Group the metadata records into modules.

    Returns ``(modules, externals, symbols)``: the list of IDAModule objects
    and the sets of external and defined symbol names seen.  Raises
    ValueError when a module-level record has no module (or a fixup record
    has no content record) to attach to.
    """
    modules = []
    externals = set()
    symbols = set()
    cur_module = None
    last_seg_id = None

    for record in records:
        rec_type = record['type']

        if rec_type == 0x2:
            cur_module = IDAModule(record['contents'], {}, {})
            modules.append(cur_module)
            # Fixups must not attach to a segment of the previous module.
            last_seg_id = None
            continue

        if rec_type not in _MODULE_BODY_RECORDS:
            continue  # record kinds this loader does not use
        if cur_module is None:
            raise ValueError('record type {:#x} precedes the first module header'.format(rec_type))

        contents = record['contents']
        if rec_type == 0x18:
            for external in contents['externals']:
                cur_module.externals[external['block_id'], external['ext_id']] = external
                externals.add(external['name'])
        elif rec_type in (0xE, 0xF):
            for segment in contents['segments']:
                cur_module.segments[segment['id']] = IDASegment(segment, [], bytearray(), [])
        elif rec_type in (0x16, 0x17):
            for symbol in contents['symbols']:
                cur_module.segments[symbol['id']].symbols.append(symbol)
                symbols.add(symbol['name'])
        elif rec_type in (0x6, 0x7):
            data = binascii.unhexlify(contents['data'])
            begin = contents['offset']
            buf = cur_module.segments[contents['segment_id']].contents
            # Slice assignment clamps a start beyond the end, which would
            # append the data at the wrong offset; fill the gap with zeros.
            if len(buf) < begin:
                buf.extend(b'\x00' * (begin - len(buf)))
            buf[begin:begin + len(data)] = data
            last_seg_id = contents['segment_id']
        else:
            # Types 0x8 / 0x9: fixups for the most recent content record.
            if last_seg_id is None:
                raise ValueError('fixup record without a preceding content record')
            cur_module.segments[last_seg_id].fixups.extend(contents['fixups'])

    return modules, externals, symbols


def load_file(file, flags, format):
    """Load the library described by the JSON sidecar into the database.

    The sidecar path is taken from ``format`` (the text after the first
    ``' # '``).  Its records are grouped into modules, a two-byte slot is
    reserved at address 4096 onwards for every external that no module
    defines, all segments are placed kind by kind (CODE, DATA, XTRN), and
    finally the fixups are applied.

    :param file: input file object supplied by the caller (unused).
    :param flags: load flags supplied by the caller (unused).
    :param format: format string previously returned by ``accept_file``.
    :return: ``1`` on success, ``0`` when the sidecar is missing, unreadable
        or malformed.
    """
    # Split only once so a sidecar path that itself contains ' # ' survives.
    parts = format.split(' # ', 1)
    if len(parts) != 2:
        return 0
    metafile = parts[1]
    if not path.isfile(metafile):
        return 0

    try:
        with open(metafile) as handle:
            fmt = json.load(handle)
    except (IOError, OSError, ValueError):
        return 0

    # Parse everything before touching the database so malformed metadata
    # leaves it unmodified.
    try:
        modules, externals, symbols = _collect_modules(fmt)
    except (KeyError, TypeError, ValueError) as exc:
        print('8051 OMF loader: malformed metadata in {}: {!r}'.format(metafile, exc))
        return 0

    idaapi.set_processor_type('8051', idaapi.SETPROC_ALL)

    external_locs = {}
    # Only names that no module defines need a placeholder slot.
    externals -= symbols

    offset = 4096

    if externals:
        ext_size = 2
        idaapi.add_segm(0, offset, offset + ext_size * len(externals), 'EXT', 'XTRN')
        # Sorted so the layout is the same on every run.
        for e in sorted(externals):
            idaapi.set_name(offset, e.encode('ascii'))
            external_locs[e] = offset
            offset += ext_size

    module_offsets = {}

    # Round 1: place modules
    for seg_type in ['CODE', 'DATA', 'XTRN']:
        offset = place_segments(offset, modules, seg_type, module_offsets, external_locs)

    # Round 2: resolve externals
    resolve_symbols(modules, module_offsets, external_locs)

    return 1