Newer
Older
Tardis / lang / compile.py
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2023 John Watts and the LuminaSensum contributors

import sys

# This is a fairly basic compiler for a simple programming language
# It is split into these phases:
# Parsing (AST and parse sections)
# IR (IR and IR generator sections)
# Assign classes
# Register allocation
# Output (Bytecode section)

## AST

class ASTNumber:
	"""AST leaf holding a numeric literal."""

	def __init__(self, number):
		# Value of the literal as parsed from the source
		self.number = number

	def __repr__(self):
		template = 'ASTNumber(number=%d)'
		return template % self.number

class ASTBoolean:
	"""AST leaf holding a boolean literal."""

	def __init__(self, value):
		# True or False, as decided by the parser
		self.value = value

	def __repr__(self):
		return f'ASTBoolean(value={self.value})'

class ASTNone:
	"""AST leaf for the ``None`` literal; it carries no data."""

	def __init__(self):
		# Nothing to store
		return

	def __repr__(self):
		text = 'ASTNone()'
		return text

class ASTReference:
	"""AST leaf naming a variable (or other named value) to look up."""

	def __init__(self, name):
		# Identifier text exactly as written in the source
		self.name = name

	def __repr__(self):
		return f'ASTReference(name="{self.name}")'

class ASTCall:
	"""AST node for ``subject verb args...``.

	``verb`` may be None, in which case the node stands for the bare subject value.
	"""

	def __init__(self, subject, verb, args):
		self.subject = subject # value node the verb is applied to
		self.verb = verb # verb name, or None
		self.args = args # list of value nodes

	def __repr__(self):
		# The leading newline and tabs pretty-print nested AST dumps
		return f'\n\t\t\t\tASTCall(subject={self.subject}, verb="{self.verb}", args={self.args})'

class ASTJump:
	"""AST node for a ``Jump`` command wrapping the call to jump to."""

	def __init__(self, call):
		# The ASTCall describing the jump target
		self.call = call

	def __repr__(self):
		return f'\n\t\t\tASTJump(call={self.call})'

class ASTSet:
	"""AST node for ``Set <name> To <call>``."""

	def __init__(self, name, command):
		self.name = name # variable being assigned
		self.command = command # call whose result is stored

	def __repr__(self):
		return f'\n\t\t\tASTSet(name="{self.name}", command={self.command})'

class ASTReturn:
	"""AST node for ``Return <value>``."""

	def __init__(self, value):
		# Value node to be returned
		self.value = value

	def __repr__(self):
		return f'\n\t\t\tASTReturn(value={self.value})'

# Contains: Set, Return or Jump
class ASTCommand:
	"""Wrapper around a single parsed command node."""

	def __init__(self, value):
		# The wrapped command node
		self.value = value

	def __repr__(self):
		return f'\n\t\tASTCommand(value={self.value})'

# Contains: Call and Command
class ASTClause:
	"""One branch of a conditional: an optional test plus the command run on success."""

	def __init__(self, test, success):
		self.test = test # test call, or None for an unconditional branch
		self.success = success # command executed when the branch is taken

	def __repr__(self):
		return f'\n\t\tASTClause(test={self.test}, success={self.success})'

# Contains: a list of Clauses
class ASTConditional:
	"""AST node holding the ordered clauses of a conditional."""

	def __init__(self, clauses):
		# Clauses in source order
		self.clauses = clauses

	def __repr__(self):
		return f'\n\t\tASTConditional(clauses={self.clauses})'

# Contains: Command or Conditional
class ASTStatement:
	"""Wrapper around one function-body statement (a command or a conditional)."""

	def __init__(self, value):
		# The wrapped command or conditional node
		self.value = value

	def __repr__(self):
		return f'\n\tASTStatement(value={self.value})'

class ASTFunction:
	"""AST node for one parsed function definition."""

	def __init__(self, public, class_, name, args, statements):
		self.public = public # True for Public, False for Private
		self.class_ = class_ # name of the class the function belongs to
		self.name = name # function name
		self.args = args # list of argument names
		self.statements = statements # parsed body statements

	def __repr__(self):
		return (
			f'ASTFunction(public={self.public}, class="{self.class_}", '
			f'name="{self.name}", args="{self.args}", statements={self.statements})'
		)

## Parser

def tokenize(code):
	"""Split source text into lines of tokens.

	Tokens are separated by spaces and tabs, lines by newlines. Leading and
	trailing whitespace of the whole text is stripped first, and lines that
	contain no tokens are skipped.

	Returns a list of token lists. Empty or whitespace-only input yields an
	empty list rather than a line holding one empty token.
	"""
	lines = []
	for raw_line in code.strip().split('\n'):
		# Tabs and spaces are the only token separators
		tokens = [tok for tok in raw_line.replace('\t', ' ').split(' ') if tok]
		if tokens:
			lines.append(tokens)
	return lines

def parse_value(val):
	"""Turn a single token into a value AST node.

	Tokens starting with a digit become ASTNumber, the literals ``True``,
	``False`` and ``None`` become their AST nodes, and anything else becomes
	an ASTReference.

	Returns None for an empty token or for a token that starts with a digit
	but is not a valid integer (e.g. ``1abc``). Callers already treat None as
	a parse failure.
	"""
	if not val:
		return None
	if val[0].isdigit():
		try:
			number = int(val)
		except ValueError:
			# Malformed number: report failure instead of crashing
			return None
		return ASTNumber(number)
	elif val == "True":
		return ASTBoolean(True)
	elif val == "False":
		return ASTBoolean(False)
	elif val == "None":
		return ASTNone()
	else:
		return ASTReference(val)

def parse_call(line, offset):
	"""Parse ``subject [verb [args...]]`` starting at the 1-based token ``offset``.

	Returns an ASTCall (verb is None when only a subject is present), or None
	after printing a diagnostic when the tokens cannot be parsed.
	"""
	tokens = line[offset-1:]
	if not tokens:
		print("not a call? line: %s" % (' '.join(line)))
		return None
	subject = parse_value(tokens[0])
	if subject is None:
		print("not a subject? line: %s" % (' '.join(line)))
		return None
	verb = tokens[1] if len(tokens) > 1 else None
	parsed_args = []
	for token in tokens[2:]:
		parsed = parse_value(token)
		if parsed is None:
			print("not a arg value? line: %s" % (' '.join(line)))
			return None
		parsed_args.append(parsed)
	return ASTCall(subject, verb, parsed_args)

def parse_set(line):
	"""Parse ``Set <name> To <call>`` into an ASTSet, or return None on failure."""
	well_formed = len(line) >= 4 and line[0] == "Set" and line[2] == "To"
	if not well_formed:
		print("not a set? line: %s" % (' '.join(line)))
		return None
	# The call starts at the fourth token
	value_call = parse_call(line, 4)
	if value_call is None:
		return None
	return ASTSet(line[1], value_call)

def parse_return(line):
	"""Parse ``Return <value>`` into an ASTReturn, or return None on failure."""
	well_formed = len(line) == 2 and line[0] == "Return"
	if not well_formed:
		print("not a return? line: %s" % (' '.join(line)))
		return None
	result = parse_value(line[1])
	if result is None:
		print("not a return value? line: %s" % (' '.join(line)))
		return None
	return ASTReturn(result)

def parse_jump(line):
	"""Parse ``Jump <subject> <verb> [args...]`` into an ASTJump, or return None on failure."""
	well_formed = line[0] == "Jump" and len(line) >= 3
	if not well_formed:
		print("not a jump? line: %s" % (' '.join(line)))
		return None
	# The call starts right after the keyword
	target = parse_call(line, 2)
	if target is None:
		print("not a jump call? line: %s" % (' '.join(line)))
		return None
	return ASTJump(target)

def parse_command(line):
	"""Parse one command line (``Set``, ``Return`` or ``Jump``).

	``line`` is a list of tokens. It may be empty, e.g. when a conditional
	branch keyword such as ``Then`` is not followed by a command; that is
	reported as a failure instead of raising IndexError.

	Returns an ASTCommand, or None after printing a diagnostic.
	"""
	if not line:
		print("not a command? line: %s" % (' '.join(line)))
		return None
	instr = line[0]
	command = None
	if instr == "Set":
		command = parse_set(line)
	elif instr == "Return":
		command = parse_return(line)
	elif instr == "Jump":
		command = parse_jump(line)
	if command is None:
		print("not a command? line: %s" % (' '.join(line)))
		return None
	return ASTCommand(command)

def parse_conditional(lines):
	"""Parse an ``If``/``ElseIf``/``Else`` chain from the front of ``lines``.

	Each ``If``/``ElseIf`` line holds the test call and must be followed by a
	``Then <command>`` line. An ``Else <command>`` line is a clause without a
	test.

	Returns ``(ASTConditional, remaining_lines)``, or ``(None, [])`` after
	printing a diagnostic. Running out of lines is reported as an error (or
	ends the chain) instead of raising IndexError, and a second ``If`` ends
	the chain so that it starts a new statement instead of being merged in
	as another clause.
	"""
	clauses = []
	while lines:
		header = lines[0]
		op = header[0]
		if op == "ElseIf" or (op == "If" and not clauses):
			test_call = parse_call(header, 2)
			if test_call is None:
				print("not an if/elseif call? line: %s" % (' '.join(header)))
				return (None, [])
			lines = lines[1:]
			if not lines:
				# The test line was the last line: the Then part is missing
				print("no then/else? %s" % (' '.join(header)))
				return (None, [])
		elif op == "Else":
			test_call = None
		else:
			break # No more clauses
		body = lines[0]
		op = body[0]
		if op == "Then" or (op == "Else" and test_call is None):
			success_command = parse_command(body[1:])
			if success_command is None:
				print("not a then/else command? line: %s" % (' '.join(body)))
				return (None, [])
			lines = lines[1:]
		else:
			print("no then/else? %s" % (' '.join(body)))
			return (None, [])
		clauses.append(ASTClause(test_call, success_command))
	conditional = ASTConditional(clauses)
	return (conditional, lines)

def parse_statement(lines):
	"""Parse one statement from the front of ``lines``.

	A line starting with ``If`` is handed to the conditional parser, which may
	consume several lines; anything else is parsed as a single command line.

	Returns ``(ASTStatement, remaining_lines)`` or ``(None, [])`` on failure.
	"""
	first = lines[0]
	if first[0] == "If":
		(parsed, remaining) = parse_conditional(lines)
	else:
		remaining = lines[1:]
		parsed = parse_command(first)
	if parsed is None:
		print("not a statement? line: %s" % (' '.join(first)))
		return (None, [])
	return (ASTStatement(parsed), remaining)

def parse_function(lines):
	"""Parse one function definition from the front of ``lines``.

	The header has the form
	``Public|Private Class <class> Function <name> [Args <arg>...]``.
	Body statements are read until an ``EndFunction`` line or the end of input.

	Returns ``(ASTFunction, remaining_lines)`` or ``(None, [])`` after printing
	a diagnostic.
	"""
	header, body = lines[0], lines[1:]
	malformed = len(header) < 5 or header[1] != "Class" or header[3] != "Function"
	if malformed:
		print("not a function? line: %s" % (' '.join(header)))
		return (None, [])
	privacy = header[0]
	if privacy not in ("Public", "Private"):
		print("not a valid privacy? line: %s" % (' '.join(header)))
		return (None, [])
	public = privacy == "Public"
	class_, name = header[2], header[4]
	args = []
	if len(header) > 5: # an argument list follows the name
		args = header[6:]
		if header[5] != "Args" or not args:
			print("no function args? line: %s" % (' '.join(header)))
			return (None, [])
		if len(set(args)) < len(args):
			print("duplicate function args? line: %s" % (' '.join(header)))
			return (None, [])
	statements = []
	while body:
		if body[0][0] == "EndFunction":
			body = body[1:]
			break
		(statement, body) = parse_statement(body)
		if statement is None:
			return (None, [])
		statements.append(statement)
	return (ASTFunction(public, class_, name, args, statements), body)

class ASTMetadata:
	"""AST node for the module header: an id, the module name and its ``Use`` list."""

	def __init__(self, id, name, uses):
		self.id = id # identifier supplied by the caller of the parser
		self.name = name # value of the Module line
		self.uses = uses # values of the Use lines, in order

	def __repr__(self):
		return f'ASTMetadata(id={self.id}, name={self.name}, uses={self.uses})'

def parse_keyvalue(lines, key):
	"""Parse a ``<key> <value words...>`` line from the front of ``lines``.

	Returns ``(value, remaining_lines)`` where value is the remaining tokens
	joined by single spaces. Returns ``(None, [])`` when ``lines`` is empty
	or the first line does not start with ``key``.
	"""
	if not lines:
		# Nothing left to parse: report "no match" instead of raising IndexError
		return (None, [])
	line = lines[0]
	if not line or line[0] != key:
		return (None, [])
	value = ' '.join(line[1:])
	return (value, lines[1:])

def parse_metadata(lines, id):
	"""Parse the module header: one ``Module`` line followed by any ``Use`` lines.

	Returns ``(ASTMetadata, remaining_lines)``. When the ``Module`` line is
	missing a diagnostic is printed and ``(None, [])`` is returned, which the
	caller treats as a parse failure.
	"""
	if lines:
		first_line = lines[0]
		(module, lines) = parse_keyvalue(lines, "Module")
	else:
		first_line = []
		module = None
	if module is None:
		print("no module name? line: %s" % (' '.join(first_line)))
		return (None, [])
	use_list = []
	while len(lines) != 0:
		(use, rest) = parse_keyvalue(lines, "Use")
		if use is None:
			# Not a Use line: leave it for the function parser
			break
		use_list.append(use)
		lines = rest
	metadata = ASTMetadata(id, module, use_list)
	return (metadata, lines)

def parse_toplevel(lines, id):
	"""Parse a whole module: the metadata header followed by functions.

	Returns a list whose first element is the ASTMetadata and whose remaining
	elements are ASTFunction nodes, or None when any part fails to parse.
	"""
	(metadata, remaining) = parse_metadata(lines, id)
	if metadata is None:
		return None
	ast = [metadata]
	while remaining:
		(func, remaining) = parse_function(remaining)
		if func is None:
			return None
		ast.append(func)
	return ast

## IR

class IRNumber:
	"""IR node for a numeric constant."""

	def __init__(self, number):
		# The constant's value
		self.number = number

	def __repr__(self):
		template = 'IRNumber(number=%d)'
		return template % self.number

class IRBoolean:
	"""IR node for a boolean constant."""

	def __init__(self, value):
		# The constant's value
		self.value = value

	def __repr__(self):
		return f'IRBoolean(value={self.value})'

class IRNone:
	"""IR node for the ``None`` constant; it carries no data."""

	def __init__(self):
		# Nothing to store
		return

	def __repr__(self):
		text = 'IRNone()'
		return text

class IRAllocate:
	"""IR node carrying a number of slots to allocate."""

	def __init__(self, size):
		# Number of slots
		self.size = size

	def __repr__(self):
		template = 'IRAllocate(size=%d)'
		return template % self.size

class IRDrop:
	"""IR node carrying a number of slots to drop."""

	def __init__(self, size):
		# Number of slots
		self.size = size

	def __repr__(self):
		# The repr labels the count "number" although the attribute is "size"
		template = 'IRDrop(number=%d)'
		return template % self.size

class IRSelf:
	"""IR node standing for a reference to ``Self``."""

	def __repr__(self):
		text = 'IRSelf'
		return text

class IRLoad:
	"""IR node that loads a named variable from a slot position."""

	def __init__(self, variable, pos):
		self.variable = variable # variable name
		self.pos = pos # slot position
		# (the IR generator passes -1 until registers are allocated)

	def __repr__(self):
		fields = (self.variable, self.pos)
		return 'IRLoad(variable=%s, pos=%d)' % fields

class IRLoadArg:
	"""IR node that loads a named argument by index."""

	def __init__(self, arg, index):
		self.arg = arg # argument name
		self.index = index # position in the argument list

	def __repr__(self):
		fields = (self.arg, self.index)
		return 'IRLoadArg(arg=%s, index=%d)' % fields

class IRStore:
	"""IR node that stores into a named variable's slot."""

	def __init__(self, variable, pos, create):
		self.variable = variable # variable name
		self.pos = pos # slot position
		self.create = create # whether this store may introduce the variable

	def __repr__(self):
		fields = (self.variable, self.pos, self.create)
		return 'IRStore(variable=%s, pos=%d, create=%s)' % fields

class IRCall:
	"""IR node for calling a verb with a given argument count."""

	def __init__(self, name, args):
		self.name = name # verb name
		self.args = args # number of arguments

	def __repr__(self):
		fields = (self.name, self.args)
		return 'IRCall(name="%s", args=%d)' % fields

class IRTailCall:
	"""IR node for a tail call of a verb with a given argument count."""

	def __init__(self, name, args):
		self.name = name # verb name
		self.args = args # number of arguments

	def __repr__(self):
		fields = (self.name, self.args)
		return 'IRTailCall(name="%s", args=%d)' % fields

class IRReturn:
	"""IR node marking a return from the current function."""

	def __repr__(self):
		text = 'IRReturn()'
		return text

class IRFunction:
	"""IR for one function: its header fields plus a flat list of IR nodes."""

	def __init__(self, public, class_, name, args, statements):
		self.public = public # visibility flag
		self.class_ = class_ # owning class name
		self.name = name # function name
		self.args = args # argument names
		self.statements = statements # IR nodes of the body

	def __repr__(self):
		return (
			f'IRFunction(public={self.public}, class="{self.class_}", '
			f'name="{self.name}", args="{self.args}", statements={self.statements})'
		)

class IRDepthCheck:
	"""IR node carrying an expected stack depth."""

	def __init__(self, depth):
		# Expected depth
		self.depth = depth

	def __repr__(self):
		template = 'IRDepthCheck(depth=%d)'
		return template % self.depth

class IRLabel:
	"""IR node marking a named jump target."""

	def __init__(self, name):
		# Label name
		self.name = name

	def __repr__(self):
		return f'IRLabel(name={self.name})'

class IRJumpFalse:
	"""IR node for a conditional jump to a named label."""

	def __init__(self, name):
		# Target label name
		self.name = name

	def __repr__(self):
		return f'IRJumpFalse(name={self.name})'

class IRJumpAlways:
	"""IR node for an unconditional jump to a named label."""

	def __init__(self, name):
		# Target label name
		self.name = name

	def __repr__(self):
		return f'IRJumpAlways(name={self.name})'

class IRMetadata:
	"""IR counterpart of the module header: id, module name and used modules."""

	def __init__(self, id, name, uses):
		self.id = id # module identifier
		self.name = name # module name
		self.uses = uses # names of used modules

	def __repr__(self):
		return f'IRMetadata(id={self.id}, name={self.name}, uses={self.uses})'

class IRClass:
	"""IR node describing a class by name together with its function names."""

	def __init__(self, name, functions):
		self.name = name # class name
		self.functions = functions # list of function names

	def __repr__(self):
		return f'IRClass(name={self.name}, functions={self.functions})'

## IR Generator

def generate_ir_value(value):
	"""Translate a value AST node into a one-element list of IR nodes.

	Numbers, booleans and None map to their IR constants. A reference named
	``Self`` becomes IRSelf; any other reference becomes an IRLoad whose slot
	position is the placeholder -1.

	Returns None after printing a diagnostic for an unknown node type.
	"""
	if isinstance(value, ASTNumber):
		return [IRNumber(value.number)]
	elif isinstance(value, ASTBoolean):
		return [IRBoolean(value.value)]
	elif isinstance(value, ASTNone):
		return [IRNone()]
	elif isinstance(value, ASTReference):
		if value.name == "Self":
			return [IRSelf()]
		else:
			return [IRLoad(value.name, -1)]
	else:
		# Report the offending node (this used to reference an undefined name)
		print("Unknown value ast node: %s" % (value,))
		return None

def generate_ir_call(ast, is_jump):
	"""Translate an ASTCall into a list of IR nodes.

	Without a verb the result is just the subject's IR. Otherwise the emitted
	order is: one allocated slot, the arguments, the subject, then the call
	node (IRTailCall when ``is_jump`` is true, IRCall otherwise).

	Returns None after printing a diagnostic when the subject or an argument
	cannot be translated.
	"""
	subject_ir = generate_ir_value(ast.subject)
	if subject_ir is None:
		# Report the offending node (this used to reference an undefined name)
		print("Unknown subject ast node: %s" % (ast.subject,))
		return None
	if ast.verb is None:
		return subject_ir
	args_ir = []
	for arg in ast.args:
		arg_ir = generate_ir_value(arg)
		if arg_ir is None:
			print("Unknown arg ast node: %s" % (arg,))
			return None
		args_ir.extend(arg_ir)
	args_count = len(ast.args)
	if is_jump:
		call_node = IRTailCall(ast.verb, args_count)
	else:
		call_node = IRCall(ast.verb, args_count)
	return [IRAllocate(1)] + args_ir + subject_ir + [call_node]

def generate_ir_jump(ast):
	"""Translate an ASTJump by generating its call in jump (tail-call) mode."""
	# juggle args to arg count then cleanup stack
	jump_call = ast.call
	return generate_ir_call(jump_call, True)

def generate_ir_set(ast, create):
	"""Translate an ASTSet into the IR of its call followed by an IRStore.

	``create`` is recorded on the store. The slot position is the placeholder
	-1. Returns None after printing a diagnostic when the call fails.
	"""
	command = ast.command
	sub_ir = generate_ir_call(command, False)
	if sub_ir is None:
		# Report the offending node (this used to reference an undefined name)
		print("Unknown set ast node: %s" % (command,))
		return None
	store = IRStore(ast.name, -1, create)
	return sub_ir + [store]

def generate_ir_return(ast):
	"""Translate an ASTReturn: value IR, a store into ``Return``, then IRReturn.

	Returns None after printing a diagnostic when the value cannot be
	translated.
	"""
	value_ir = generate_ir_value(ast.value)
	if value_ir is None:
		# Report the offending node (this used to reference an undefined name)
		print("Unknown return ast node: %s" % (ast.value,))
		return None
	store_ret = IRStore("Return", -1, False)
	ret = IRReturn()
	return value_ir + [store_ret, ret]

def generate_ir_command(ast, create):
	"""Translate an ASTCommand by dispatching on the wrapped node's type.

	``create`` is forwarded to Set commands only. Returns the IR list, or None
	after printing a diagnostic.
	"""
	node = ast.value
	if isinstance(node, ASTSet):
		result = generate_ir_set(node, create)
	elif isinstance(node, ASTReturn):
		result = generate_ir_return(node)
	elif isinstance(node, ASTJump):
		result = generate_ir_jump(node)
	else:
		result = None
	if result is None:
		print("Unknown command ast node: %s" % (node))
		return None
	return result

def generate_ir_conditional(ast, statement_id):
	"""Translate an ASTConditional into labelled IR.

	Each clause gets a label ``Statement<id>Clause<n>``. A clause with a test
	emits the test followed by IRJumpFalse to the next clause's label. Every
	clause body ends with IRJumpAlways to ``Statement<id>ClauseEnd``. The
	label after the last clause and the end label are both emitted at the end.

	Returns None when a test or a clause body cannot be translated, instead
	of raising TypeError while concatenating None.
	"""
	ir = []
	clause_prefix = "Statement" + str(statement_id) + "Clause"
	label_next = ""
	label_end = clause_prefix + "End"
	for clause_num, clause in enumerate(ast.clauses, start=1):
		label_next = clause_prefix + str(clause_num + 1)
		ir.append(IRLabel(clause_prefix + str(clause_num)))
		if clause.test is not None:
			test_ir = generate_ir_call(clause.test, False)
			if test_ir is None:
				return None
			ir += test_ir
			ir.append(IRJumpFalse(label_next))
		# Commands inside a conditional are generated with create=False
		success_ir = generate_ir_command(clause.success, False)
		if success_ir is None:
			return None
		ir += success_ir
		ir.append(IRJumpAlways(label_end))
	# End
	ir.append(IRLabel(label_next))
	ir.append(IRLabel(label_end))
	return ir

def generate_ir_statement(ast, id):
	"""Translate an ASTStatement by dispatching on the wrapped node's type.

	``id`` is the statement number, used to build conditional label names.
	Returns the IR list, or None after printing a diagnostic.
	"""
	node = ast.value
	if isinstance(node, ASTConditional):
		result = generate_ir_conditional(node, id)
	elif isinstance(node, ASTCommand):
		# Top-level commands are generated with create=True
		result = generate_ir_command(node, True)
	elif isinstance(node, ASTJump):
		result = generate_ir_jump(node)
	else:
		result = None
	if result is None:
		print("Unknown statement ast node: %s" % (node))
		return None
	return result

def generate_ir_function(ast):
	"""Translate an ASTFunction into an IRFunction with a flat IR body.

	Statements are numbered from 1; the number is handed to the statement
	generator. Returns None after printing a diagnostic when a statement
	cannot be translated.
	"""
	body = []
	for statement_id, node in enumerate(ast.statements, start=1):
		if isinstance(node, ASTStatement):
			sub_ir = generate_ir_statement(node, statement_id)
		else:
			sub_ir = None
		if sub_ir is None:
			print("Unknown function statement ast node: %s" % (node))
			return None
		body = body + sub_ir
	return IRFunction(ast.public, ast.class_, ast.name, ast.args, body)

def generate_ir_metadata(ast):
	"""Copy the fields of an ASTMetadata node into an IRMetadata node."""
	module_id, module_name, used = ast.id, ast.name, ast.uses
	return IRMetadata(module_id, module_name, used)

def generate_ir(ast):
	"""Translate the top-level AST list into a list of IR nodes.

	Metadata and function nodes are supported. Returns None after printing a
	diagnostic for any other node or when a translation fails.
	"""
	result = []
	for node in ast:
		if isinstance(node, ASTMetadata):
			translated = generate_ir_metadata(node)
		elif isinstance(node, ASTFunction):
			translated = generate_ir_function(node)
		else:
			translated = None
		if translated is None:
			print("Unknown ast node: %s" % (node))
			return None
		result.append(translated)
	return result

## Assign classes

def group_functions(ir):
	"""Group function names by class name.

	Returns a dict mapping class name to the list of its function names in
	order of appearance, or None after printing a diagnostic when a class
	defines the same function name twice.
	"""
	groups = {}
	for node in ir:
		if not isinstance(node, IRFunction):
			continue
		members = groups.setdefault(node.class_, [])
		if node.name in members:
			print("Duplicate class function: %s" % node.name)
			return None
		members.append(node.name)
	return groups

def make_module_class(ir):
	"""Create an empty IRClass named after the module.

	The name comes from the last IRMetadata node in ``ir``. Returns None after
	printing a diagnostic when there is no metadata node.
	"""
	found = [node for node in ir if isinstance(node, IRMetadata)]
	if not found:
		print("Unable to find module metadata?")
		return None
	return IRClass(found[-1].name, [])

def map_class_functions(functions, classes):
	"""Build an IRClass for every function group.

	``functions`` maps class name to function names; ``classes`` lists the
	known classes. Returns the new IRClass list, or None after printing a
	diagnostic when a group names a class that is not known.
	"""
	mapped = []
	for name, funcs in functions.items():
		known = any(candidate.name == name for candidate in classes)
		if not known:
			print("Unknown class %s" % (name))
			return None
		mapped.append(IRClass(name, funcs))
	return mapped

def update_classes(old_ir, new_classes):
	"""Return a copy of ``old_ir`` with each IRClass swapped for its updated version.

	A class node is replaced by the last entry of ``new_classes`` that has the
	same name; all other nodes are kept as they are.
	"""
	result = []
	for node in old_ir:
		chosen = node
		if isinstance(node, IRClass):
			matches = [c for c in new_classes if c.name == node.name]
			if matches:
				chosen = matches[-1]
		if chosen is None:
			print("No class for IR: %s" % (node))
			return None
		result.append(chosen)
	return result

def assign_classes(ir):
	"""Append the module's class to the IR and fill in its function names.

	Returns the new IR list, or None when grouping, metadata lookup or class
	mapping fails.
	"""
	grouped = group_functions(ir)
	if grouped is None:
		return None
	module_class = make_module_class(ir)
	if module_class is None:
		return None
	# The module class is the only class known at this point
	mapped_classes = map_class_functions(grouped, [module_class])
	if mapped_classes is None:
		return None
	return update_classes(ir + [module_class], mapped_classes)

## Register allocation

# Register allocation here is very simple: Each variable gets one stack slot
# No further analysis is done to optimize use of the stack

def find_module_args(ast):
	"""Return the ordered list of module arguments, i.e. the module's uses.

	The list is taken from the last ASTMetadata node in ``ast``. Returns None
	after printing a diagnostic when there is no metadata node.
	"""
	found = [node for node in ast if isinstance(node, ASTMetadata)]
	if not found:
		print("Unable to find AST metadata")
		return None
	# Ordered list of arguments: First ones are module uses
	return found[-1].uses

def find_variables(module_args, ir):
	"""Assign a slot number to every variable of a function.

	``Return`` gets slot 0. Module arguments get negative numbers
	(``-1 - index``). Function arguments and then newly created locals get
	1, 2, ... in order of first appearance.

	Returns the name -> slot dict, or None after printing a diagnostic when:
	a module argument is assigned; a variable is loaded before being set; or
	a non-creating store targets a variable that is never created. The last
	case used to slip through and fail later with a KeyError.
	"""
	slots = {'Return': 0}
	for index, arg in enumerate(module_args):
		slots[arg] = -1 - index
	var_count = 0
	for arg in ir.args:
		var_count += 1
		slots[arg] = var_count
	# Non-creating stores seen before their variable exists; the variable may
	# still be created by a later store, so they are checked after the scan.
	pending = []
	for node in ir.statements:
		if isinstance(node, IRStore):
			var = node.variable
			if var in module_args:
				print("Setting module arg: %s" % (var))
				return None
			if var not in slots:
				if node.create:
					var_count += 1
					slots[var] = var_count
				else:
					pending.append(var)
		elif isinstance(node, IRLoad):
			var = node.variable
			if var not in slots:
				print("Unset variable: %s" % (var))
				return None
	for var in pending:
		if var not in slots:
			print("Unset variable: %s" % (var))
			return None
	return slots

def replace_variables(ir, variables):
	"""Rewrite a function's IR so loads and stores carry their slot numbers.

	Emits an initial depth check and, if needed, an allocation for locals.
	Loads with a negative slot become IRLoadArg. Every IRCall is followed by
	a depth check, and every IRReturn is preceded by a drop of all slots
	except ``Return``.

	Returns the new list of IR nodes.
	"""
	arg_slots = len(ir.args) + 1 # Include Return
	# NOTE(review): every entry of `variables` beyond Return and the function
	# arguments is counted as a local here — confirm that is intended
	local_slots = len(variables) - arg_slots
	total_slots = arg_slots + local_slots
	drop_count = total_slots - 1 # Retain Return
	result = [IRDepthCheck(arg_slots)]
	if local_slots != 0:
		result.append(IRAllocate(local_slots))
	for node in ir.statements:
		if isinstance(node, IRLoad):
			slot = variables[node.variable]
			if slot < 0:
				replacement = IRLoadArg(node.variable, -slot - 1)
			else:
				replacement = IRLoad(node.variable, slot)
			result.append(replacement)
		elif isinstance(node, IRStore):
			slot = variables[node.variable]
			result.append(IRStore(node.variable, slot, node.create))
		elif isinstance(node, IRCall):
			result.append(node)
			# Include Return and function call return
			result.append(IRDepthCheck(total_slots + 1))
		elif isinstance(node, IRReturn):
			if drop_count != 0:
				result.append(IRDrop(drop_count))
			result.append(IRReturn())
		else:
			result.append(node)
	return result

def registers_allocate_func(module_args, ir):
	"""Allocate slots for one IRFunction and return the rewritten function.

	Returns None when the variables cannot be resolved.
	"""
	slots = find_variables(module_args, ir)
	if slots is None:
		return None
	body = replace_variables(ir, slots)
	return IRFunction(ir.public, ir.class_, ir.name, ir.args, body)

def registers_allocate(module_args, ir):
	"""Run slot allocation on every IRFunction in ``ir``.

	Other nodes are passed through unchanged. Returns the new IR list, or
	None as soon as allocation fails for a function, so the caller's failure
	check can fire instead of a None entry ending up in the list.
	"""
	new_ir = []
	for node in ir:
		if isinstance(node, IRFunction):
			sub_ir = registers_allocate_func(module_args, node)
			if sub_ir is None:
				return None
		else:
			sub_ir = node
		new_ir.append(sub_ir)
	return new_ir

## Bytecode generation

class Bytecode:
	"""Encoded function body plus the bookkeeping used to finish and print it."""

	def __init__(self, bytes, comments, labels, fixups):
		self.bytes = bytes # encoded instruction bytes
		self.comments = comments # byte offset -> list of comment strings
		self.labels = labels # label name -> byte offset
		self.fixups = fixups # byte offset of a jump operand -> label name

	def __repr__(self):
		fields = (self.bytes, self.comments, self.labels, self.fixups)
		return 'Bytecode(bytes=%s, comments=%s, labels=%s, fixups=%s)' % fields

def generate_bytecode_function(ir):
	"""Encode the IR nodes of one function into bytecode.

	Every node's repr is recorded as a comment at the byte offset where its
	encoding starts. Labels record their byte offset and emit nothing. Jumps
	emit a zero placeholder operand and register a fixup at that operand's
	offset. An OP_END byte is appended at the end.

	Returns a Bytecode object, or None after printing a diagnostic for an
	unknown node.

	Single-byte operands are encoded with an explicit byte order, because
	``int.to_bytes`` only has a default byte order from Python 3.11 on.
	"""
	code = bytearray()
	comments = {}
	labels = {}
	fixups = {}

	def emit(opcode, operand=None):
		# Append an opcode byte, optionally followed by a one-byte operand
		code.append(opcode)
		if operand is not None:
			code.extend(int(operand).to_bytes(1, 'little'))

	def emit_call(opcode, node):
		# Operand is the argument count plus one, then the NUL-terminated name
		emit(opcode, node.args + 1)
		code.extend(node.name.encode('utf-8'))
		code.append(0x00) # NULL terminator

	def emit_jump(opcode, node):
		code.append(opcode)
		fixups[len(code)] = node.name
		code.append(0x00) # Fixed up later

	for node in ir.statements:
		pos = len(code)
		comments.setdefault(pos, []).append(str(node))
		if isinstance(node, IRNumber):
			code.append(0x01) # OP_RATIO1
			code.extend(node.number.to_bytes(4, 'little'))
		elif isinstance(node, IRBoolean):
			emit(0x0B, int(node.value == 1)) # OP_BOOLEAN
		elif isinstance(node, IRNone):
			emit(0x05) # OP_NONE
		elif isinstance(node, IRAllocate):
			for _ in range(0, node.size):
				emit(0x05) # OP_NONE
		elif isinstance(node, IRDrop):
			emit(0x08, node.size) # OP_DROP
		elif isinstance(node, IRLoad):
			emit(0x06, node.pos) # OP_GET
		elif isinstance(node, IRLoadArg):
			emit(0x0F, node.index) # OP_GET_ARG
		elif isinstance(node, IRStore):
			emit(0x07, node.pos) # OP_SET
		elif isinstance(node, IRCall):
			emit_call(0x04, node) # OP_CALL
		elif isinstance(node, IRTailCall):
			emit_call(0x0E, node) # OP_TAIL_CALL
		elif isinstance(node, IRReturn):
			emit(0x03) # OP_RET
		elif isinstance(node, IRDepthCheck):
			emit(0x09, node.depth) # OP_DEPTH_CHECK
		elif isinstance(node, IRSelf):
			emit(0x0A) # OP_SELF
		elif isinstance(node, IRLabel):
			labels[node.name] = pos
		elif isinstance(node, IRJumpFalse):
			emit_jump(0x0C, node) # OP_JUMP_FALSE
		elif isinstance(node, IRJumpAlways):
			emit_jump(0x0D, node) # OP_JUMP_ALWAYS
		else:
			print("Unknown bytecode node: %s" % (node))
			return None
	comments[len(code)] = ["OP_END trailer"]
	code.append(0x00)
	return Bytecode(code, comments, labels, fixups)

def bytecode_fixup(bytecode):
	"""Patch every jump operand with the relative offset of its label.

	The offset is counted from the byte after the operand. The byte buffer is
	patched in place and returned in a new Bytecode with the labels and
	fixups cleared.
	"""
	patched = bytecode.bytes
	notes = bytecode.comments
	for (position, label) in bytecode.fixups.items():
		# Count from after this byte
		patched[position] = bytecode.labels[label] - position - 1
	return Bytecode(patched, notes, {}, {})

def generate_bytecode_c(bytecode):
	"""Render bytecode as the body of a C array initialiser.

	Each byte becomes ``0xNN, ``. Comments recorded for a byte offset are
	written as ``//`` lines before that byte, which then starts on a new line.
	"""
	pieces = []
	for index, byte in enumerate(bytecode.bytes):
		notes = bytecode.comments.get(index)
		if notes:
			pieces.extend("\n\t// %s" % (note) for note in notes)
			pieces.append("\n\t")
		pieces.append("0x%02x, " % (byte))
	pieces.append("\n")
	return "".join(pieces)

## C file output

def generate_header(source):
	"""Return the top of the generated C file.

	It consists of a comment quoting ``source``, the MODULE_INTERNAL_API
	define and the fixed list of ``#include`` lines.
	"""
	headers = (
		'"bytecode.h"',
		'"module.h"',
		'"object.h"',
		'"vm.h"',
		'<stddef.h>',
	)
	include_output = "".join("\n#include %s" % header for header in headers)
	banner = "/* Autogenerated by compile.py.\n\n%s\n*/\n\n#define MODULE_INTERNAL_API\n%s"
	return banner % (source, include_output)

def generate_bytecode(node):
	"""Return the C array definition holding the bytecode of one IRFunction.

	The array is named ``bytecode_<class>_<function>``. Returns None after
	printing a diagnostic when the function cannot be encoded.
	"""
	raw = generate_bytecode_function(node)
	if raw is None:
		print("Unknown bytecode node: %s" % (node))
		return None
	c_body = generate_bytecode_c(bytecode_fixup(raw))
	declaration = "static const unsigned char bytecode_%s_%s [] = {%s};\n\n"
	return declaration % (node.class_, node.name, c_body)

def generate_module_args(module_args):
	"""Return the C helpers ``create_args`` and ``free_args``.

	Only the number of module arguments is used; it is baked into
	``create_args`` as ``args_count``.
	"""
	args_count = len(module_args)
	template = """static ObjectList create_args(VmState state, ObjectList use_modules) {
	(void)state; (void)use_modules;
	int args_count = %i;
	ObjectList list = object_list_create(state, args_count);
	vm_abort_if(state, object_list_length(state, use_modules) < args_count, "create_args: Not enough use_modules");
	for(int i = 0; i < args_count; ++i) {
		Object obj = object_list_get(state, use_modules, i);
		object_list_set(state, use_modules, i, object_none());
		object_list_set(state, list, i, obj);
	}
	return list;
}

static void free_args(VmState state, ObjectList *args) {
	for(int i = 0; i < object_list_length(state, *args); ++i)
		object_list_set(state, *args, i, object_none());
	object_list_free(state, args);
}"""
	return template % (args_count)

def generate_class_boilerplate():
	"""Return the C forward declaration, constructor and cleanup for a class.

	The text uses the placeholder ``CLASSNAME``, which the caller replaces
	with the real class name.
	"""
	forward_declaration = "static struct object_class CLASSNAME_class;"
	create_function = """static Object CLASSNAME_create(VmState state, ObjectList use_modules) {
	Object obj = object_create(state, &CLASSNAME_class, sizeof(ObjectList));
	ObjectList *args = (ObjectList *)object_priv(state, obj, &CLASSNAME_class);
	*args = create_args(state, use_modules);
	return obj;
}"""
	cleanup_function = """static void CLASSNAME_cleanup(VmState state, Object obj) {
	ObjectList *args = (ObjectList *)object_priv(state, obj, &CLASSNAME_class);
	free_args(state, args);
}"""
	# The three parts are separated by one blank line each
	return "\n\n".join((forward_declaration, create_function, cleanup_function))

def generate_class_call(class_name, func_name, index):
	"""Return one C ``object_call`` initialiser entry for a function.

	``index`` is stored as ``.priv``. ``class_name`` is not used; the handler
	is written with the ``CLASSNAME`` placeholder.
	"""
	entry = '{{.name = "{}", .handler = CLASSNAME_call_bytecode, .priv = {} }},\n\t'
	return entry.format(func_name, index)

def generate_class_structs(class_ir):
	calls = ""
	bytecodes = ""
	index = 0
	for func in class_ir.functions:
		calls += generate_class_call(class_ir.name, func, index)
		bytecodes += "bytecode_%s_%s,\n\t" % (class_ir.name, func)
		index += 1
	return """static const unsigned char *CLASSNAME_bytecodes[] = {
	%sNULL,
};

static void CLASSNAME_call_bytecode(VmState state, Object obj, int priv) {
	struct object_list **args = (struct object_list **)object_priv(state, obj, &CLASSNAME_class);
	const unsigned char *bytecode = CLASSNAME_bytecodes[priv];
	bytecode_run(state, obj, bytecode, *args);
}

static struct object_call CLASSNAME_calls[] = {
	%s{.name = NULL, /* end */ },
};

static struct object_class CLASSNAME_class = {
	.cleanup = CLASSNAME_cleanup,
	.calls = &CLASSNAME_calls[0],
};""" % (bytecodes, calls)

def generate_c_class(class_ir):
	"""Return the full C code for one class.

	The boilerplate and the tables are joined by a blank line, then every
	``CLASSNAME`` placeholder is replaced with the class name.
	"""
	sections = [generate_class_boilerplate(), generate_class_structs(class_ir)]
	return "\n\n".join(sections).replace('CLASSNAME', class_ir.name)

def generate_metadata(node):
	"""Return the C ``module_uses`` array and ``module_info_<id>`` struct.

	``node`` supplies ``uses`` (one quoted array entry each), ``id`` (struct
	name suffix) and ``name`` (module name and ``<name>_create`` function).
	"""
	uses = "".join("\t\"%s\",\n" % (use) for use in node.uses)
	template = """static const char* module_uses[] = {
%s	NULL, /* end */
};

const struct module_info module_info_%s = {
	.name = "%s",
	.uses = module_uses,
	.create = %s_create,
};
"""
	return template % (uses, node.id, node.name, node.name)

def generate_c_file(module_args, ir, source):
	"""Assemble the complete C file for a module.

	The output consists of the header (quoting ``source``), all function
	bytecode arrays, the module-argument helpers, the class definitions and
	the module metadata.

	Returns the C text, or None after printing a diagnostic when a node
	cannot be generated. A failed generator result is checked before it is
	appended, so the failure is reported instead of raising TypeError.
	"""
	metadata = ""
	bytecodes = ""
	classes = ""
	for node in ir:
		generated = None
		if isinstance(node, IRMetadata):
			generated = generate_metadata(node)
			if generated is not None:
				metadata = generated
		elif isinstance(node, IRFunction):
			generated = generate_bytecode(node)
			if generated is not None:
				bytecodes += generated
		elif isinstance(node, IRClass):
			generated = generate_c_class(node)
			if generated is not None:
				classes += "\n\n/* CLASS %s */\n\n" % (node.name)
				classes += generated
		if generated is None:
			print("Can't generate node: %s" % (node))
			return None
	output = generate_header(source)
	output += "\n\n/* BYTECODES */\n\n"
	output += bytecodes
	output += "/* MODULE ARGS */\n\n"
	output += generate_module_args(module_args)
	output += classes
	output += "\n\n/* METADATA */\n\n"
	output += metadata
	return output

## Main wrapper

def compile(code, id):
	"""Compile source text into C code for the module identified by ``id``.

	Runs tokenizing, parsing, module-argument lookup, IR generation, class
	assignment, register allocation and C output in that order. Returns the
	C text, or None after printing which phase failed.
	"""
	def fail(message):
		# Report the failing phase and signal failure to the caller
		print(message)
		return None

	lines = tokenize(code)
	if lines is None:
		return fail("Failed to tokenize file")
	ast = parse_toplevel(lines, id)
	if ast is None:
		return fail("Failed to parse file")
	module_args = find_module_args(ast)
	if module_args is None:
		return fail("Failed to find module arguments")
	ir = generate_ir(ast)
	if ir is None:
		return fail("Failed to generate IR")
	ir_class = assign_classes(ir)
	if ir_class is None:
		return fail("Failed to generate classes")
	ir_reg = registers_allocate(module_args, ir_class)
	if ir_reg is None:
		return fail("Failed to allocate registers")
	c_code = generate_c_file(module_args, ir_reg, code)
	if c_code is None:
		return fail("Failed to generate C file")
	return c_code

if __name__ == "__main__":
	# Usage: compile.py <source file> <output C file> <module id>
	if len(sys.argv) < 4:
		print("usage: %s <source file> <output C file> <module id>" % (sys.argv[0]))
		sys.exit(1)
	source_path, output_path, module_id = sys.argv[1:4]
	# Context managers close both files even if compilation raises
	with open(source_path, "r") as source_file, open(output_path, "w") as output_file:
		code = source_file.read()
		c_code = compile(code, module_id)
		if c_code:
			output_file.write(c_code)
	# Exit status 1 signals a failed compilation
	sys.exit(c_code is None)