Newer
Older
Tardis / lang / compile.py
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2023 John Watts and the LuminaSensum contributors

import sys

# This is a fairly basic compiler for a simple programming language
# It is split into four phases:
# Parsing (AST and parse sections)
# IR (IR and IR generator sections)
# Register allocation
# Output (Bytecode section)

## AST

class ASTNumber():
	"""AST leaf holding an integer literal."""

	def __init__(self, number):
		# Value of the literal
		self.number = number

	def __repr__(self):
		# %i renders the stored value as a decimal integer
		rendered = 'ASTNumber(number=%i)' % self.number
		return rendered

class ASTReference():
	"""AST leaf naming a value to look up."""

	def __init__(self, name):
		# Name being referred to
		self.name = name

	def __repr__(self):
		rendered = 'ASTReference(name="%s")' % self.name
		return rendered

class ASTCall():
	"""AST node for a subject, a verb applied to it and the verb's arguments."""

	def __init__(self, subject, verb, args):
		self.subject = subject # Value node the verb is applied to
		self.verb = verb       # Verb name
		self.args = args       # List of argument value nodes

	def __repr__(self):
		# Leading newline and tabs nest the call under its statement when printed
		fields = (self.subject, self.verb, self.args)
		return '\n\t\tASTCall(subject=%s, verb="%s", args=%s)' % fields

class ASTSet():
	"""AST statement pairing a variable name with the command that produces its value."""

	def __init__(self, name, command):
		self.name = name       # Variable being set
		self.command = command # Node producing the value

	def __repr__(self):
		fields = (self.name, self.command)
		return '\n\tASTSet(name="%s", command=%s)' % fields

class ASTReturn():
	"""AST statement carrying the value to return."""

	def __init__(self, value):
		# Value node being returned
		self.value = value

	def __repr__(self):
		rendered = '\n\tASTReturn(value=%s)' % self.value
		return rendered

class ASTFunction():
	"""AST node for a named function and its list of statements."""

	def __init__(self, name, statements):
		self.name = name             # Function name
		self.statements = statements # Statement nodes in source order

	def __repr__(self):
		fields = (self.name, self.statements)
		return 'ASTFunction(name="%s", statements=%s)' % fields

## Parser

def tokenize(code):
	"""Split source text into a list of lines, each a list of word tokens.

	Words are separated by any run of whitespace, so repeated spaces,
	indentation and trailing whitespace never produce empty tokens.
	Blank lines are skipped, and "\\r\\n" line endings are handled.

	Returns an empty list when the source contains no tokens.
	"""
	lines = []
	for raw_line in code.splitlines():
		words = raw_line.split()
		# A line with no words carries no statement
		if words:
			lines.append(words)
	return lines

def parse_value(val):
	"""Parse a single token into a value node.

	A token starting with a digit is an integer literal; any other
	non-empty token is a reference to a named value.

	Returns an ASTNumber or ASTReference, or None (after printing a
	message) when the token is empty or is a malformed number.
	"""
	if not val:
		print("not a value? empty token")
		return None
	if not val[0].isdigit():
		return ASTReference(val)
	try:
		number = int(val)
	except ValueError:
		# Starts like a number but is not one, e.g. "12ab"
		print("not a number? value: %s" % (val,))
		return None
	return ASTNumber(number)

def parse_call(line, offset):
	"""Parse the call part of a line: a subject, an optional verb, then arguments.

	line is the full list of tokens; offset is the 1-based position of
	the subject within it. The verb is None when only a subject is given.

	Returns an ASTCall, or None after printing a message on failure.
	"""
	words = line[offset-1:]
	if not words:
		print("not a call? line: %s" % (' '.join(line)))
		return None

	subject = parse_value(words[0])
	if not subject:
		print("not a subject? line: %s" % (' '.join(line)))
		return None

	verb = words[1] if len(words) > 1 else None

	# Everything after the verb is an argument value
	args = []
	for word in words[2:]:
		parsed = parse_value(word)
		if not parsed:
			print("not a arg value? line: %s" % (' '.join(line)))
			return None
		args.append(parsed)
	return ASTCall(subject, verb, args)

def parse_set(line):
	"""Parse a "Set <name> To <call...>" line into an ASTSet.

	Returns None when the line is not shaped like a Set statement (a
	message is printed) or when the call part fails to parse.
	"""
	well_formed = len(line) >= 4 and line[0] == "Set" and line[2] == "To"
	if not well_formed:
		print("not a set? line: %s" % (' '.join(line)))
		return None
	# The call starts at the fourth word, right after "To"
	command = parse_call(line, 4)
	if command:
		return ASTSet(line[1], command)
	return None

def parse_return(line):
	"""Parse a "Return <value>" line into an ASTReturn.

	Only the word right after "Return" is used as the value.
	Returns None after printing a message when the line does not parse.
	"""
	is_return = len(line) >= 2 and line[0] == "Return"
	if not is_return:
		print("not a return? line: %s" % (' '.join(line)))
		return None
	result = parse_value(line[1])
	if result:
		return ASTReturn(result)
	print("not a return value? line: %s" % (' '.join(line)))
	return None

def parse_statement(lines):
	"""Parse the first line of lines as a statement.

	Returns a (statement, remaining_lines) tuple. For an unrecognised
	first word the result is (None, lines after the first); when a
	recognised statement fails to parse the result is (None, []).
	"""
	current = lines[0]
	remaining = lines[1:]
	keyword = current[0]
	if keyword == "Set":
		parsed = parse_set(current)
	elif keyword == "Return":
		parsed = parse_return(current)
	else:
		print("not a statement? line: %s" % (' '.join(current)))
		return (None, remaining)
	if not parsed:
		return (None, [])
	return (parsed, remaining)

def parse_function(lines):
	"""Parse one "Function <name>" ... "EndFunction" group from the front of lines.

	Returns (ASTFunction, remaining_lines), or (None, []) when the header
	line or any statement fails to parse. Running out of lines also ends
	the function body.
	"""
	header, lines = lines[0], lines[1:]
	if not (len(header) == 2 and header[0] == "Function"):
		print("not a function? line: %s" % (' '.join(header)))
		return (None, [])
	name = header[1]
	body = []
	while lines:
		if lines[0][0] == "EndFunction":
			# Consume the terminator and stop
			lines = lines[1:]
			break
		(statement, lines) = parse_statement(lines)
		if statement is None:
			return (None, [])
		body.append(statement)
	return (ASTFunction(name, body), lines)

def parse_toplevel(lines):
	"""Parse every function in lines.

	Returns the list of parsed functions, or None as soon as one of
	them fails to parse.
	"""
	functions = []
	remaining = lines
	while remaining:
		func, remaining = parse_function(remaining)
		if func is None:
			return None
		functions.append(func)
	return functions

## IR

class IRNumber():
	"""IR instruction carrying an integer constant."""

	def __init__(self, number):
		# Constant value
		self.number = number

	def __repr__(self):
		rendered = '\n\tIRNumber(number=%i)' % self.number
		return rendered

class IRAllocate():
	"""IR instruction holding a slot count to allocate."""

	def __init__(self, size):
		# Number of slots
		self.size = size

	def __repr__(self):
		rendered = '\n\tIRAllocate(size=%i)' % self.size
		return rendered

class IRDrop():
	"""IR instruction holding a slot count to drop."""

	def __init__(self, size):
		# Number of slots
		self.size = size

	def __repr__(self):
		# Label the field by its attribute name, matching IRAllocate
		return '\n\tIRDrop(size=%i)' % (self.size,)

class IRSelf():
	"""IR instruction generated for a reference named "Self"."""

	def __repr__(self):
		rendered = '\n\tIRSelf'
		return rendered

class IRLoad():
	"""IR instruction that reads a variable.

	variable is a name before register allocation and a slot number after.
	"""

	def __init__(self, variable):
		self.variable = variable

	def __repr__(self):
		rendered = '\n\tIRLoad(variable=%s)' % self.variable
		return rendered

class IRStore():
	"""IR instruction that writes a variable.

	variable is a name before register allocation and a slot number after.
	"""

	def __init__(self, variable):
		self.variable = variable

	def __repr__(self):
		rendered = '\n\tIRStore(variable=%s)' % self.variable
		return rendered

class IRCall():
	"""IR instruction that calls a verb by name."""

	def __init__(self, name, args):
		self.name = name # Verb name
		self.args = args # Number of arguments (an integer count)

	def __repr__(self):
		fields = (self.name, self.args)
		return '\n\tIRCall(name="%s", args=%i)' % fields

class IRReturn():
	"""IR instruction that returns from the function."""

	def __repr__(self):
		rendered = '\n\tIRReturn()'
		return rendered

class IRFunction():
	"""A named function together with its list of IR instructions."""

	def __init__(self, name, statements):
		self.name = name             # Function name
		self.statements = statements # IR instructions in order

	def __repr__(self):
		fields = (self.name, self.statements)
		return 'IRFunction(name="%s", statements=%s)' % fields

class IRDepthCheck():
	"""IR instruction carrying a depth value to check."""

	def __init__(self, depth):
		self.depth = depth

	def __repr__(self):
		rendered = '\n\tIRDepthCheck(depth=%i)' % self.depth
		return rendered

## IR Generator

def generate_ir_value(value):
	"""Lower a value AST node to a one-element list of IR instructions.

	Numbers become IRNumber, the reference "Self" becomes IRSelf and any
	other reference becomes an IRLoad of that name.

	Returns None after printing a message for any other node type.
	"""
	if isinstance(value, ASTNumber):
		return [IRNumber(value.number)]
	if isinstance(value, ASTReference):
		# "Self" gets its own instruction instead of a variable load
		if value.name == "Self":
			return [IRSelf()]
		return [IRLoad(value.name)]
	# Report the offending node itself
	print("Unknown value ast node: %s" % (value,))
	return None

def generate_ir_call(ast):
	"""Lower an ASTCall to a list of IR instructions.

	Without a verb the result is just the subject's IR. With a verb the
	sequence is: IRAllocate(1), the IR of each argument in order, the
	subject's IR, then IRCall(verb, number of arguments).

	Returns None after printing a message when the subject or an
	argument cannot be lowered.
	"""
	subject_ir = generate_ir_value(ast.subject)
	if not subject_ir:
		print("Unknown subject ast node: %s" % (ast.subject,))
		return None
	if not ast.verb:
		return subject_ir
	args_ir = []
	for arg in ast.args:
		arg_ir = generate_ir_value(arg)
		if not arg_ir:
			print("Unknown arg ast node: %s" % (arg,))
			return None
		args_ir.extend(arg_ir)
	alloc_ir = [IRAllocate(1)]
	call_ir = [IRCall(ast.verb, len(ast.args))]
	return alloc_ir + args_ir + subject_ir + call_ir

def generate_ir_set(ast):
	"""Lower an ASTSet to IR: the command's IR followed by an IRStore to the name.

	Returns None after printing a message when the command cannot be lowered.
	"""
	sub_ir = generate_ir_call(ast.command)
	if not sub_ir:
		print("Unknown set ast node: %s" % (ast,))
		return None
	return sub_ir + [IRStore(ast.name)]

def generate_ir_return(ast):
	"""Lower an ASTReturn to IR.

	The value's IR is followed by an IRStore to the variable "Return"
	and then an IRReturn.

	Returns None after printing a message when the value cannot be lowered.
	"""
	value_ir = generate_ir_value(ast.value)
	if not value_ir:
		print("Unknown return ast node: %s" % (ast,))
		return None
	return value_ir + [IRStore("Return"), IRReturn()]

def generate_ir_function(ast):
	"""Lower an ASTFunction to an IRFunction with the same name.

	Each statement's IR is appended in order. Returns None after
	printing a message when a statement is of an unknown type or
	cannot be lowered.
	"""
	name = ast.name
	instructions = []
	for statement in ast.statements:
		if isinstance(statement, ASTSet):
			lowered = generate_ir_set(statement)
		elif isinstance(statement, ASTReturn):
			lowered = generate_ir_return(statement)
		else:
			lowered = None
		if not lowered:
			print("Unknown statement ast node: %s" % (statement))
			return None
		instructions.extend(lowered)
	return IRFunction(name, instructions)

def generate_ir(ast):
	"""Lower a list of top-level AST nodes to a list of IR functions.

	Returns None after printing a message when a node is not an
	ASTFunction or cannot be lowered.
	"""
	functions = []
	for node in ast:
		lowered = generate_ir_function(node) if isinstance(node, ASTFunction) else None
		if not lowered:
			print("Unknown ast node: %s" % (node))
			return None
		functions.append(lowered)
	return functions

## Register allocation

# Register allocation here is very simple: Each variable gets one stack slot
# No further analysis is done to optimize use of the stack

def find_variables(ir):
	"""Assign a slot number to every variable stored to in ir.

	"Return" always has slot 0; other variables are numbered from 1 in
	order of their first store. Returns the name-to-slot dict, or None
	after printing a message when a variable is loaded before any store.
	"""
	slots = {'Return': 0}
	for instruction in ir:
		if isinstance(instruction, IRStore):
			name = instruction.variable
			if name not in slots:
				# The next free slot equals the number of names seen so far
				slots[name] = len(slots)
		elif isinstance(instruction, IRLoad):
			if instruction.variable not in slots:
				print("Unset variable: %s" % (instruction.variable))
				return None
	return slots

def replace_variables(ir, variables):
	"""Rewrite ir so loads and stores use slot numbers instead of names.

	The result starts with IRDepthCheck(1) and, when there are local
	variables, an IRAllocate for them. Every IRCall is followed by an
	IRDepthCheck, and every IRReturn is preceded by an IRDrop of the
	locals when there are any. Other instructions pass through as-is.
	"""
	rewritten = [IRDepthCheck(1)]
	local_count = len(variables) - 1 # Ignore Return
	if local_count != 0:
		rewritten.append(IRAllocate(local_count))
	for instruction in ir:
		if isinstance(instruction, IRLoad):
			rewritten.append(IRLoad(variables[instruction.variable]))
		elif isinstance(instruction, IRStore):
			rewritten.append(IRStore(variables[instruction.variable]))
		elif isinstance(instruction, IRCall):
			rewritten.append(instruction)
			# Include Return and function call return
			rewritten.append(IRDepthCheck(local_count + 2))
		elif isinstance(instruction, IRReturn):
			if local_count != 0:
				rewritten.append(IRDrop(local_count))
			rewritten.append(IRReturn())
		else:
			rewritten.append(instruction)
	return rewritten

def registers_allocate_func(ir):
	"""Return a copy of the IR function ir with variables replaced by slots.

	Returns None when variable discovery fails.
	"""
	body = ir.statements
	slots = find_variables(body)
	if not slots:
		return None
	allocated = replace_variables(body, slots)
	return IRFunction(ir.name, allocated)

def registers_allocate(ir):
	"""Run register allocation over every IR function in ir.

	Returns the list of rewritten functions, or None when a node is not
	an IRFunction or its allocation fails.
	"""
	allocated = []
	for node in ir:
		result = registers_allocate_func(node) if isinstance(node, IRFunction) else None
		if not result:
			return None
		allocated.append(result)
	return allocated

## Bytecode generation

def _encode_node(node):
	"""Return the bytecode for one IR node, or None for an unknown node type.

	Raises OverflowError when an operand does not fit its encoded width.
	"""
	if isinstance(node, IRNumber):
		# OP_NUM followed by the value as 4 little-endian bytes
		return b"\x01" + node.number.to_bytes(4, 'little')
	if isinstance(node, IRAllocate):
		# One OP_NULL per slot
		return b"\x05" * node.size
	if isinstance(node, IRDrop):
		return b"\x08" + node.size.to_bytes(1, 'little') # OP_DROP
	if isinstance(node, IRLoad):
		return b"\x06" + int(node.variable).to_bytes(1, 'little') # OP_GET
	if isinstance(node, IRStore):
		return b"\x07" + int(node.variable).to_bytes(1, 'little') # OP_SET
	if isinstance(node, IRCall):
		# OP_CALL, then args + 1 as one byte, then the NUL-terminated name
		count = (node.args + 1).to_bytes(1, 'little')
		return b"\x04" + count + node.name.encode('utf-8') + b"\x00"
	if isinstance(node, IRReturn):
		return b"\x03" # OP_RET
	if isinstance(node, IRDepthCheck):
		return b"\x09" + int(node.depth).to_bytes(1, 'little') # OP_DEPTH_CHECK
	if isinstance(node, IRSelf):
		return b"\x0A" # OP_SELF
	return None

def generate_bytecode_function(ir):
	"""Encode the statements of a register-allocated IR function as bytecode.

	The byte order is passed explicitly to int.to_bytes so this also runs
	on Python versions older than 3.11, where it is a required argument.

	Returns the bytes, terminated by OP_END, or None after printing a
	message when a node is of an unknown type or has an operand that
	does not fit its encoded width.
	"""
	chunks = []
	for node in ir.statements:
		try:
			encoded = _encode_node(node)
		except OverflowError:
			print("Operand out of range in bytecode node: %s" % (node,))
			return None
		if encoded is None:
			print("Unknown bytecode node: %s" % (node,))
			return None
		chunks.append(encoded)
	chunks.append(b"\x00") # OP_END
	return b"".join(chunks)

## C file output

def generate_header(ir, source):
	"""Build the top of the generated C file.

	The result is a C comment holding the source text and the printed
	form of each entry in ir, followed by the #include lines and the
	fixed module boilerplate.
	"""
	parts = ["/* Autogenerated by compile.py.\n\n", source]
	parts.extend("\n" + str(entry) + "\n" for entry in ir)
	parts.append("*/\n\n")
	for include in (
		'"bytecode.h"',
		'"error.h"',
		'"object.h"',
		'"vm.h"',
		'<stddef.h>',
	):
		parts.append("#include %s\n" % include)
	parts.append("""
static struct object_class module_class;

Object module_create(void) {
	Object obj = object_create(&module_class, 1);
	abort_if(!obj, "unable to allocate module");
	return obj;
}

static void module_cleanup(Object obj) { (void)obj; }

static void module_call(VmState state, Object obj, void *priv) {
	bytecode_run(state, obj, (const unsigned char *)priv);
}\n\n""")
	return "".join(parts)

def generate_footer(verbs):
	"""Build the end of the generated C file.

	verbs is a list of already-formatted call table entries. The result
	is the calls[] array (closed by a NULL-named entry) followed by the
	module_class definition.
	"""
	parts = ["""static struct object_call calls[] = {\n"""]
	parts.extend("\t" + verb + "\n" for verb in verbs)
	parts.append("\t{.name = NULL, /* end */}\n};")
	parts.append("""\n
static struct object_class module_class = {
	.cleanup = module_cleanup,
	.calls = &calls[0],
};""")
	return "".join(parts)

def generate_bytecode(node):
	"""Render an IR function as a C unsigned char array named bytecode_<name>.

	Bytes are written as hex literals, ten per row. Returns None after
	printing a message when the function cannot be encoded.
	"""
	encoded = generate_bytecode_function(node)
	if encoded is None:
		print("Unknown bytecode node: %s" % (node))
		return None
	pieces = ["static const unsigned char "]
	pieces.append("bytecode_%s [] = {\n\t" % (node.name))
	for position, byte in enumerate(encoded, start=1):
		pieces.append("0x%02x, " % (byte))
		# Start a new row after every tenth byte
		if position % 10 == 0:
			pieces.append("\n\t")
	# Finish a partially filled row
	if len(encoded) % 10 != 0:
		pieces.append("\n")
	pieces.append("};\n\n")
	return "".join(pieces)

def generate_call(node):
	"""Format the calls[] table entry that binds node.name to its bytecode array."""
	head = '{.name = "%s", .handler = module_call, ' % (node.name)
	tail = '\n\t\t.priv = (void *)bytecode_%s},' % (node.name)
	return head + tail

def generate_c_file(ir, source):
	"""Assemble the complete C file for the IR functions in ir.

	The output is the header, one bytecode array per function, then the
	footer with the call table. Returns None after printing a message
	when a node cannot be generated.
	"""
	sections = [generate_header(ir, source)]
	calls = []
	for node in ir:
		if isinstance(node, IRFunction):
			body = generate_bytecode(node)
			calls.append(generate_call(node))
		else:
			body = None
		if body is None:
			print("Can't generate node: %s" % (node))
			return None
		sections.append(body)
	footer = generate_footer(calls)
	if footer is None:
		print("Can't generate footer")
		return None
	sections.append(footer)
	return "".join(sections)

## Main wrapper

def compile(code):
	"""Compile source text to the text of a C file.

	Runs tokenizing, parsing, IR generation, register allocation and C
	output in turn. Returns None as soon as any stage yields an empty
	or failed result.
	"""
	# Each stage only runs when the previous one produced something
	lines = tokenize(code)
	ast = parse_toplevel(lines) if lines else None
	ir = generate_ir(ast) if ast else None
	ir_reg = registers_allocate(ir) if ir else None
	c_code = generate_c_file(ir_reg, code) if ir_reg else None
	return c_code if c_code else None

# Command line entry point: compile.py <input> <output>
if __name__ == "__main__":
	if len(sys.argv) < 3:
		print("usage: %s <input> <output>" % (sys.argv[0]), file=sys.stderr)
		sys.exit(2)
	with open(sys.argv[1], "r") as source_file:
		code = source_file.read()
	c_code = compile(code)
	if c_code is None:
		print("compilation failed", file=sys.stderr)
		sys.exit(1)
	# Only open the output once there is something to write, so a failed
	# compile does not leave behind an empty file
	with open(sys.argv[2], "w") as output_file:
		output_file.write(c_code)