Newer
Older
Tardis / lang / module.c
// SPDX-License-Identifier: MIT
// Copyright (c) 2023 John Watts and the LuminaSensum contributors

#define MODULE_INTERNAL_API
#include "module.h"
#include "error.h"
#include "object.h"
#include "string.h"
#include "vm.h"

// Number of elements in an array.
// Only valid on true arrays: a pointer (including an array function
// parameter, which decays to a pointer) gives a meaningless result.
// The argument is parenthesized so expression arguments index correctly.
#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))

// Use X macros to declare external module variables.
// <module_ids.x> presumably contains one X(id) line per module; here each
// expansion declares the ModuleInfo named module_info_<id>, which must be
// defined in some other translation unit.
#define X(id) extern ModuleInfo module_info_##id;
#include <module_ids.x>
#undef X

// Use X macros to build a list of modules: one pointer per X(id) entry in
// <module_ids.x>, in the order they appear there. info_by_name() searches
// this table when resolving module names.
#define X(id) &module_info_##id,
static ModuleInfo *modules_list[] = {
#include <module_ids.x>
};
#undef X

// Capacity of a stack, in elements
#define STACK_SIZE 64

// A fixed-capacity stack of ModuleInfo pointers.
// `next` is both the number of elements in use and the index of the next
// free slot; entries at indices [0, next) are valid.
struct module_stack {
	const ModuleInfo *elems[STACK_SIZE];
	int next;
};

// A fixed-capacity stack of Objects; `next` has the same meaning as in
// struct module_stack.
struct object_stack {
	Object elems[STACK_SIZE];
	int next;
};

// A mapping between ModuleInfo and Object.
// The info at index i of `infos` corresponds to the object at index i of
// `objects`. `objects` can lag behind `infos`: infos are pushed first and
// their objects are created and pushed afterwards (see create_modules).
struct module_mapping {
	struct module_stack infos;
	struct object_stack objects;
};

// Looks up a registered module by name.
// Returns the first entry of modules_list whose name compares equal to
// `name` with strcmp, or NULL when no registered module has that name.
static const ModuleInfo *info_by_name(const char *name) {
	const ModuleInfo *found = NULL;
	unsigned int idx = 0;
	while (found == NULL && idx < ARRAY_SIZE(modules_list)) {
		if (!strcmp(modules_list[idx]->name, name))
			found = modules_list[idx];
		++idx;
	}
	return found;
}

// Reports whether `info` is currently on `stack`.
// Only the occupied slots [0, stack->next) are searched, and entries are
// compared by pointer identity. Returns 1 if found, 0 otherwise.
static int in_stack(const ModuleInfo *info, struct module_stack *stack) {
	const ModuleInfo **it = stack->elems;
	const ModuleInfo **const end = stack->elems + stack->next;
	for (; it < end; ++it) {
		if (*it == info)
			return 1;
	}
	return 0;
}

// Traverses a module's use dependency graph, checking for circular
// dependencies.
// On return, every module reachable from `start` (including `start` itself)
// has been appended to `visited`, and each module is appended only after
// all of its uses. `visited` is therefore in load order. Modules already in
// `visited` are skipped; if `start` is already there, nothing happens.
// This is an iterative depth-first search. The static `traversing` stack
// holds the current path from `start`, so meeting a use that is already on
// that path means a cycle.
// Calls abort_print() if a use names an unknown module, if a cycle is
// found, or if there are more modules than a stack can hold.
static void traverse_module_uses(
	const ModuleInfo *start, struct module_stack *visited) {
	static struct module_stack traversing = {0};
	// Compare element counts rather than byte sizes
	if (ARRAY_SIZE(modules_list) > ARRAY_SIZE(traversing.elems))
		abort_print("Modules too big to traverse");
	if (in_stack(start, visited))
		return;
	// Push the start module to the top of the traversal stack
	traversing.next = 0;
	traversing.elems[traversing.next++] = start;
	do {
		// Use the top of traversal stack as our current module
		const ModuleInfo *cur = traversing.elems[traversing.next - 1];
		const char **use = NULL;
		// Iterate through module uses, stopping at the first unvisited one
		for (use = cur->uses; *use != NULL; ++use) {
			// Find the info associated with the use
			const ModuleInfo *info = info_by_name(*use);
			if (info == NULL)
				abort_print("Couldn't find module");
			// If we are already traversing this use then
			// we have a cyclic dependency. Bail
			if (in_stack(info, &traversing))
				abort_print("Cyclic dependency found");
			// If this use is unvisited, push it to the stack
			// and traverse it
			if (!in_stack(info, visited)) {
				traversing.elems[traversing.next++] = info;
				break;
			}
		}
		// No unvisited uses? Pop this module, mark it as visited and
		// resume the module we were traversing before it
		if (*use == NULL) {
			// Pre-decrement: the top element lives at index next - 1.
			// A post-decrement would write one slot past the top, which
			// is out of bounds when the stack is full.
			traversing.elems[--traversing.next] = NULL;
			visited->elems[visited->next++] = cur;
		}
	} while (traversing.next != 0);

	// The traversal stack is now empty, all done!
}

// Finds a module object by its info using a module_mapping
Object get_module_by_info(
	struct module_mapping *mapping, const ModuleInfo *info) {
	struct module_stack *infos = &mapping->infos;
	for (int i = 0; i < infos->next; ++i) {
		const ModuleInfo *elem = infos->elems[i];
		if (elem == info)
			return mapping->objects.elems[i];
	}
	abort_print("Couldn't find module by info");
}

// Creates and returns the module object described by `info`.
// Each module named in info->uses is resolved through `mapping`, so those
// modules must already have objects there. create_modules() ensures this by
// creating modules in load order. The resolved objects are collected into an
// object_none()-terminated argument list, which is not yet passed to
// info->create().
// The returned module is NOT stored in `mapping`; the caller does that.
// Calls abort_print() if a use names an unknown module or if there are too
// many uses to fit in the argument list.
Object create_module(
	VmState state, struct module_mapping *mapping, const ModuleInfo *info) {
	static Object args[16];
	unsigned int args_next = 0;
	for (const char **use = info->uses; *use != NULL; ++use) {
		const ModuleInfo *use_info = info_by_name(*use);
		// Validate the looked-up use, not the module being created
		if (use_info == NULL)
			abort_print("Couldn't find module");
		Object use_module = get_module_by_info(mapping, use_info);
		// Check there's room for this arg plus the trailing object_none
		// before writing, rather than after
		if (args_next + 1 >= ARRAY_SIZE(args))
			abort_print("Too many module args");
		args[args_next++] = use_module;
	}
	// Terminate the argument list
	args[args_next++] = object_none();
	Object module = info->create(state);
	(void)args; // Unused for now
	return module;
}

// Creates objects for every info in `mapping` that does not have one yet.
// Infos at indices [objects.next, infos.next) are created in order, and each
// resulting object is pushed onto mapping->objects. This keeps the two
// stacks index-aligned. Because infos are appended in load order (by
// traverse_module_uses), each module's uses already have objects by the
// time it is created.
void create_modules(VmState state, struct module_mapping *mapping) {
	struct module_stack *infos = &mapping->infos;
	struct object_stack *objs = &mapping->objects;
	for (int i = objs->next; i < infos->next; ++i) {
		// elems stores const pointers; keep the qualifier instead of
		// silently discarding it
		const ModuleInfo *info = infos->elems[i];
		Object module = create_module(state, mapping, info);
		objs->elems[objs->next++] = module;
	}
}

// Global mapping used for finding and freeing modules
static struct module_mapping mapping = {0};

// Returns the module named `name`, creating it and its uses if needed.
// The module's use graph is traversed into the global mapping, and any
// modules without objects yet are created in load order. The module's
// object is then returned with a new reference held for the caller.
// Modules created by earlier calls are reused.
Object module_find(VmState state, const char *name) {
	const ModuleInfo *info = info_by_name(name);
	vm_abort_if(state, !info, "Unable to find module!");
	// Compare element counts, not byte sizes: Object is not necessarily
	// the same size as a ModuleInfo pointer.
	if (ARRAY_SIZE(modules_list) > ARRAY_SIZE(mapping.infos.elems))
		abort_print("Modules too big to track");
	if (ARRAY_SIZE(modules_list) > ARRAY_SIZE(mapping.objects.elems))
		abort_print("Modules too big to load");
	traverse_module_uses(info, &mapping.infos);
	create_modules(state, &mapping);
	Object module = get_module_by_info(&mapping, info);
	// Create a new reference for the caller
	object_hold(state, module);
	return module;
}

// Drops the global mapping's reference to every module object created so
// far, in creation order, then zeroes the mapping. A later module_find()
// will therefore traverse and create modules afresh.
void modules_free(VmState state) {
	struct object_stack *const loaded = &mapping.objects;
	int idx = 0;
	while (idx < loaded->next) {
		object_drop(state, loaded->elems + idx);
		++idx;
	}
	memset(&mapping, 0, sizeof mapping);
}