import enum
import json
import sys

from destruct import Arr, Capped, Data, Enum, Str, Struct, Switch, UInt, parse

# Registry of record parsers, keyed by record type byte. Filled in by the
# @record(...) decorator and used as the Switch options of Record.contents.
record_types = {}

def record(id):
    """Build a class decorator that registers a parser for record type ``id``.

    The decorated class is stored in ``record_types[id]`` (replacing any
    earlier entry for the same id) and returned unchanged.
    """
    def register(cls):
        record_types.update({id: cls})
        return cls
    return register


# Payload of record type 0x02 (module start). Fields are declared in the
# order they are parsed.
@record(0x02)
class ModuleStart(Struct):
    name = Str(kind='pascal')  # module name ('pascal'-kind, length-prefixed string)
    compiler_id  = UInt(8)     # one-byte id, presumably of the producing tool — confirm
    pad1 = Data(1)             # one unused byte

# Payload of record type 0x04 (module end): module name, two unused bytes,
# a one-byte register-bank mask and one more unused byte, in that order.
@record(0x04)
class ModuleEnd(Struct):
    name = Str(kind='pascal')
    pad1 = Data(2)
    regbank_mask = UInt(8)  # presumably a bitmask of register banks in use — confirm
    pad2 = Data(1)


class SegmentType(enum.Enum):
    """Segment type codes stored in a one-byte type field.

    Decoded by ``Segment.type`` and ``KeilSymbol.type``.
    """
    Code    = 0x0
    ExtData = 0x1
    Data    = 0x2
    IntData = 0x3
    Bit     = 0x4

# One segment entry of a SegmentDefinitions (0x0E) record. Fields are
# declared in the order they are parsed.
class Segment(Struct):
    id   = UInt(8)                     # segment id
    info = UInt(8)                     # raw info byte (not decoded here)
    type = Enum(SegmentType, UInt(8))  # one byte decoded as SegmentType
    pad1 = Data(1)                     # one unused byte
    base = UInt(16)                    # 16-bit base value
    size = UInt(16)                    # 16-bit size
    name = Str(kind='pascal')          # segment name ('pascal'-kind string)

# Payload of record type 0x0E: an array of Segment entries, presumably
# repeated until the (length-capped) record payload is exhausted.
@record(0x0E)
class SegmentDefinitions(Struct):
    segments = Arr(Segment)

# Keil variant of a segment entry (used by record 0x0F). Compared with
# Segment it has an extra unused byte (pad2) before `base`, and keeps `type`
# as a raw byte instead of decoding it as SegmentType.
class KeilSegment(Struct):
    id   = UInt(8)
    info = UInt(8)
    type = UInt(8) # raw byte; decoding as Enum(SegmentType, UInt(8)) is disabled — NOTE(review): presumably Keil values fall outside SegmentType, confirm
    pad1 = Data(1)
    pad2 = Data(1)
    base = UInt(16)
    size = UInt(16)
    name = Str(kind='pascal')

# Payload of record type 0x0F: an array of KeilSegment entries (Keil
# counterpart of the 0x0E SegmentDefinitions record).
@record(0x0F)
class KeilSegmentDefinitions(Struct):
    segments = Arr(KeilSegment)


# One symbol entry of a SymbolDefinitions (0x16) record. Fields are declared
# in the order they are parsed.
class Symbol(Struct):
    id   = UInt(8)              # id byte (presumably the owning segment's id — confirm)
    info = UInt(8)              # raw info byte (not decoded here)
    offset = UInt(16)           # 16-bit offset
    pad1   = Data(1)            # one unused byte
    name   = Str(kind='pascal') # symbol name ('pascal'-kind string)

# Payload of record type 0x16: an array of Symbol entries.
@record(0x16)
class SymbolDefinitions(Struct):
    symbols = Arr(Symbol)

# Keil variant of Symbol (used by record 0x17): identical except for an
# extra byte, decoded as SegmentType, between `info` and `offset`.
class KeilSymbol(Struct):
    id   = UInt(8)
    info = UInt(8)
    type = Enum(SegmentType, UInt(8))
    offset = UInt(16)
    pad1   = Data(1)
    name   = Str(kind='pascal')

# Payload of record type 0x17: an array of KeilSymbol entries (Keil
# counterpart of the 0x16 SymbolDefinitions record).
@record(0x17)
class KeilSymbolDefinitions(Struct):
    symbols = Arr(KeilSymbol)


# One entry of an ExternalDefinitions (0x18) record. Fields are declared in
# the order they are parsed.
class External(Struct):
    block_id = UInt(8)           # block id byte
    ext_id   = UInt(8)           # external id byte
    info     = UInt(8)           # raw info byte (not decoded here)
    pad1     = Data(1)           # one unused byte
    name     = Str(kind='pascal') # external symbol name ('pascal'-kind string)

# Payload of record type 0x18: an array of External entries.
@record(0x18)
class ExternalDefinitions(Struct):
    externals = Arr(External)


class ScopeType(enum.Enum):
    """Scope kind codes decoded by ``ScopeDefinition.type``.

    Each opening kind (Module, Do, Proc) has a matching *End code.
    """
    Module    = 0x0
    Do        = 0x1
    Proc      = 0x2
    ModuleEnd = 0x3
    DoEnd     = 0x4
    ProcEnd   = 0x5

# Payload of record type 0x10: a ScopeType byte followed by the scope name.
@record(0x10)
class ScopeDefinition(Struct):
    type = Enum(ScopeType, UInt(8))
    name = Str(kind='pascal')


class DebugType(enum.Enum):
    """Kind of entries carried by a DebugItems / KeilDebugItems record.

    ``LineNumber`` selects line-number entries; every other kind is parsed
    as symbol entries.
    """
    LocalSymbol   = 0x0
    PublicSymbol  = 0x1
    SegmentSymbol = 0x2
    LineNumber    = 0x3

# Line-number entry of a DebugItems (0x12) record, used when the record's
# type is DebugType.LineNumber.
class DebugLineNumber(Struct):
    segment_id  = UInt(8)   # segment id byte
    offset      = UInt(16)  # 16-bit offset
    line_number = UInt(16)  # 16-bit source line number

# Symbol entry of a DebugItems (0x12) record; the Switch fallback, used for
# every DebugType other than LineNumber.
class DebugSymbol(Struct):
    segment_id = UInt(8)           # segment id byte
    info       = UInt(8)           # raw info byte (not decoded here)
    offset     = UInt(16)          # 16-bit offset
    pad1       = Data(1)           # one unused byte
    name       = Str(kind='pascal') # symbol name ('pascal'-kind string)

# Payload of record type 0x12: a DebugType byte followed by an array of
# entries. The entries are DebugLineNumber when the type is LineNumber, and
# DebugSymbol otherwise (the Switch fallback).
@record(0x12)
class DebugItems(Struct):
    type  = Enum(DebugType, UInt(8))
    entries = Switch(options={
        DebugType.LineNumber: Arr(DebugLineNumber),
    }, fallback=Arr(DebugSymbol))

    def on_type(self, spec, context):
        """Select the `entries` layout from the just-parsed `type`.

        Presumably invoked by destruct right after `type` is parsed, before
        `entries` is read.
        """
        spec.entries.selector = self.type

# Keil variant of DebugLineNumber (used by record 0x23): one extra raw byte
# (val1, meaning unknown) after `segment_id`.
class KeilDebugLineNumber(Struct):
    segment_id  = UInt(8)
    val1        = Data(1)
    offset      = UInt(16)
    line_number = UInt(16)

# Keil variant of DebugSymbol (used by record 0x23): one extra byte
# (val1, meaning unknown) after `info`.
class KeilDebugSymbol(Struct):
    segment_id = UInt(8)
    info       = UInt(8)
    val1       = UInt(8)
    offset     = UInt(16)
    pad1       = Data(1)
    name       = Str(kind='pascal')

# Payload of record type 0x23 (Keil counterpart of 0x12 DebugItems): a
# DebugType byte followed by KeilDebugLineNumber entries when the type is
# LineNumber, or KeilDebugSymbol entries otherwise.
@record(0x23)
class KeilDebugItems(Struct):
    type  = Enum(DebugType, UInt(8))
    entries = Switch(options={
        DebugType.LineNumber: Arr(KeilDebugLineNumber),
    }, fallback=Arr(KeilDebugSymbol))

    def on_type(self, spec, context):
        """Select the `entries` layout from the just-parsed `type`.

        Presumably invoked by destruct right after `type` is parsed, before
        `entries` is read.
        """
        spec.entries.selector = self.type


# Payload of record type 0x06: segment id, 16-bit offset, then raw bytes.
@record(0x06)
class Content(Struct):
    segment_id = UInt(8)
    offset     = UInt(16)
    data       = Data(None)  # presumably the rest of the length-capped payload

# Payload of record type 0x07 (Keil counterpart of 0x06 Content): like
# Content, with one extra byte (val1, meaning unknown) before the data.
@record(0x07)
class KeilContent(Struct):
    segment_id = UInt(8)
    offset     = UInt(16)
    val1       = UInt(8)
    data       = Data(None)  # presumably the rest of the length-capped payload


class FixupType(enum.Enum):
    """Fixup kind codes decoded by ``Fixup.type`` / ``KeilFixup.type``."""
    Low     = 0x0
    Byte    = 0x1
    Rel     = 0x2
    High    = 0x3
    Word    = 0x4
    InBlock = 0x5
    Bit     = 0x6
    Conv    = 0x7

class OperandType(enum.Enum):
    """Kind of operand a fixup refers to (``Fixup.operand_type``)."""
    Segment     = 0x0
    Relocatable = 0x1
    External    = 0x2

# One fixup entry of a Fixups (0x08) record. Fields are declared in the
# order they are parsed.
class Fixup(Struct):
    offset = UInt(16)                                  # 16-bit location offset
    type   = Enum(FixupType, UInt(8))                  # fixup kind
    operand_type   = Enum(OperandType, UInt(8))        # kind of referenced operand
    operand_id     = UInt(8)                           # id of the referenced operand
    operand_offset = UInt(16)                          # 16-bit offset added to the operand

# Payload of record type 0x08: an array of Fixup entries.
@record(0x08)
class Fixups(Struct):
    fixups = Arr(Fixup)

# Keil variant of Fixup (used by record 0x09): one extra byte
# (operand_val1, meaning unknown) between `operand_id` and `operand_offset`.
class KeilFixup(Struct):
    offset = UInt(16)
    type   = Enum(FixupType, UInt(8))
    operand_type   = Enum(OperandType, UInt(8))
    operand_id     = UInt(8)
    operand_val1   = UInt(8)
    operand_offset = UInt(16)

# Payload of record type 0x09: an array of KeilFixup entries (Keil
# counterpart of the 0x08 Fixups record).
@record(0x09)
class KeilFixups(Struct):
    fixups = Arr(KeilFixup)


# Payload of record type 0x2C (library header): a 16-bit module count and a
# (block, byte) offset pair — presumably locating another part of the
# library file; confirm against the format spec.
@record(0x2C)
class Library(Struct):
    module_count = UInt(16)
    block_offset = UInt(16)
    byte_offset  = UInt(16)

# Payload of record type 0x28: a list of 'pascal'-kind module names.
@record(0x28)
class LibraryModuleNames(Struct):
    names = Arr(Str(kind='pascal'))

# One (block, byte) offset pair of a LibraryModuleLocations (0x26) record;
# presumably where one module starts inside the library — confirm.
class LibraryModuleLocation(Struct):
    block_offset = UInt(16)
    byte_offset  = UInt(16)

# Payload of record type 0x26: an array of LibraryModuleLocation entries.
# NOTE(review): the field is called `relocations` although it holds module
# locations; kept as-is because consumers/JSON output use this name.
@record(0x26)
class LibraryModuleLocations(Struct):
    relocations = Arr(LibraryModuleLocation)

# Payload of record type 0x2A: a flat list of 'pascal'-kind symbol names.
@record(0x2A)
class LibrarySymbols(Struct):
    names = Arr(Str(kind='pascal'))


# Payload of record type 0x24 (Keil): three raw bytes of unknown meaning,
# then a 'pascal'-kind file name.
@record(0x24)
class KeilFilename(Struct):
    val   = Data(3)
    name  = Str(kind='pascal')

# Payload of record type 0x72 (Keil, not yet identified): one byte, one
# 16-bit value and a 'pascal'-kind name.
@record(0x72)
class KeilUnk114(Struct):
    val1 = UInt(8)
    val2 = UInt(16)
    name = Str(kind='pascal')

# Payload of record type 0x20 (Keil, not yet identified).
@record(0x20)
class KeilUnk32(Struct):
    # Layout not decoded yet; the payload is kept as raw bytes.
    data = Data(None)


# One top-level record: a type byte, a 16-bit length, the payload and a
# trailing checksum byte. The payload is parsed by the class registered for
# `type` in record_types (raw bytes when unregistered) and is capped at
# `length - 1` bytes, i.e. `length` counts the payload plus the checksum.
# The checksum is read but not validated here.
# Presumably the Intel OMF-51 record framing with Keil extensions — confirm.
class Record(Struct):
    type     = UInt(8)
    length   = UInt(16)
    contents = Capped(Switch(options=record_types, fallback=Data(None)), exact=True)
    checksum = UInt(8)

    def on_type(self, spec, context):
        """Select the payload parser from the just-parsed record type.

        Presumably invoked by destruct right after `type` is parsed.
        """
        spec.contents.child.selector = self.type

    def on_length(self, spec, context):
        """Cap the payload at `length - 1` bytes (excluding the checksum).

        NOTE(review): a length of 0 would give a negative limit; presumably
        never produced by well-formed files.
        """
        spec.contents.limit = self.length - 1

# Top-level spec for a whole object/library file: an array of Records,
# presumably read until the end of the input.
Object = Arr(Record)


class DestructEncoder(json.JSONEncoder):
    """JSON encoder for values produced by ``parse(Object, ...)``.

    - ``bytes`` are emitted as lowercase hex strings.
    - ``enum.Enum`` members are emitted as their ``.value``.
    - Any other iterable object is emitted as a dict mapping each string key
      it yields to ``getattr(obj, key)``. This is how parsed Struct instances
      become JSON objects; it assumes iterating a parsed Struct yields its
      field names (confirm against destruct).
    """

    def default(self, o):
        if isinstance(o, bytes):
            return o.hex()
        # Check enums before the generic iterable case: members of enums with
        # an iterable mixin (e.g. str) would otherwise be treated as structs.
        if isinstance(o, enum.Enum):
            # Return the raw value and let the encoder serialize it. Calling
            # self.encode() here would return an already-encoded JSON string,
            # which would then be emitted as a quoted string (e.g. "0").
            return o.value
        if hasattr(o, '__iter__'):
            return {k: getattr(o, k) for k in iter(o) if isinstance(k, str)}
        return super().default(o)


def main(argv=None):
    """Parse an object/library file and write it as JSON to ``<path>.json``.

    :param argv: argument vector, defaults to ``sys.argv``; ``argv[1]`` is the
                 path of the file to parse.
    :returns: process exit status: 0 on success, 1 on a usage error.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        print('usage: {} <module.lib/obj>'.format(argv[0]), file=sys.stderr)
        return 1

    path = argv[1]
    with open(path, 'rb') as f:
        obj = parse(Object, f)
        print(obj)
        # Create the output file only after parsing succeeded, so a parse
        # error does not leave an empty or truncated .json file behind. The
        # input stays open in case parsed values read from it lazily.
        with open(path + '.json', 'w') as out:
            json.dump(obj, out, cls=DestructEncoder)
    return 0


if __name__ == '__main__':
    sys.exit(main())