Newer
Older
Tardis / lang / module.c
// SPDX-License-Identifier: MIT
// Copyright (c) 2023 John Watts and the LuminaSensum contributors

#define MODULE_INTERNAL_API
#include "module.h"
#include "error.h"
#include "list.h"
#include "object.h"
#include "string.h"
#include "vm.h"
#include <stdbool.h>

// Useful macro for getting the number of elements in an array
// Only valid on true arrays, not on pointers (e.g. decayed array parameters)
// The argument is parenthesized so any array expression is indexed correctly
#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))

// Use X macros to declare external module variables
// module_ids.x is expected to contain one X(id) entry per module; each entry
// declares a ModuleInfo named module_info_<id>, presumably defined in that
// module's own source file (NOTE(review): confirm against the build)
#define X(id) extern const ModuleInfo module_info_##id;
#include <module_ids.x>
#undef X

// Use X macros to build a list of modules
// This is the table info_by_name searches; its order follows module_ids.x
// NOTE(review): an empty module_ids.x would yield a zero-length array, which
// standard C does not allow
#define X(id) &module_info_##id,
static const ModuleInfo *modules_list[] = {
#include <module_ids.x>
};
#undef X

// Capacity of each fixed-size stack in this file, in elements
#define STACK_SIZE 64

// Section 1: Module dependency resolution

// A fixed-capacity stack of ModuleInfo pointers
// elems[0] .. elems[next - 1] hold the pushed entries; next is both the
// element count and the index of the next free slot
// No bounds checking is done on push; users check capacity themselves
struct module_stack {
	const ModuleInfo *elems[STACK_SIZE];
	int next;
};

// Looks up a registered module's ModuleInfo in modules_list by name
// Returns the first entry whose name matches, or NULL if none does
static const ModuleInfo *info_by_name(const char *name) {
	const ModuleInfo **cursor = modules_list;
	const ModuleInfo **end = modules_list + ARRAY_SIZE(modules_list);
	while (cursor != end) {
		if (!strcmp((*cursor)->name, name))
			return *cursor;
		++cursor;
	}
	return NULL;
}

// Reports whether info is one of the entries currently pushed onto stack
// (elems[0] up to elems[next - 1]); entries are compared by pointer identity
static bool in_stack(const ModuleInfo *info, struct module_stack *stack) {
	int idx = stack->next;
	while (idx-- > 0) {
		if (stack->elems[idx] == info)
			return true;
	}
	return false;
}

// Traverses a module use dependency tree, checking for circular dependencies
// The result is the visited stack containing modules in load order: a module
// is only pushed once every module it uses is already in visited
// This function may be called using an existing visited stack; modules
// already in it are neither traversed nor pushed again
// Aborts if a use names an unknown module or a cyclic dependency is found
// NOTE(review): visited is not bounds-checked here; callers must ensure it
// can hold every module (module_find checks this)
static void traverse_module_uses(
	const ModuleInfo *start, struct module_stack *visited) {
	// This function works using an iterative depth-first search.
	// traversing holds the path from start down to the current module.
	static struct module_stack traversing = {0};
	// A module can appear on the path at most once (a repeat is a cycle),
	// so the path is never longer than the number of modules
	if (ARRAY_SIZE(modules_list) > ARRAY_SIZE(traversing.elems))
		abort_print("Modules too big to traverse");
	if (in_stack(start, visited))
		return;
	// Push the start module to the top of the traversal stack
	traversing.next = 0;
	traversing.elems[traversing.next++] = start;
	do {
		// Use the top of traversal stack as our current module
		const ModuleInfo *cur = traversing.elems[traversing.next - 1];
		const char **use = NULL;
		// Iterate through module uses (a NULL-terminated list of names)
		for (use = cur->uses; *use != NULL; ++use) {
			// Find the info associated with the use
			const ModuleInfo *info = info_by_name(*use);
			if (info == NULL)
				abort_print("Couldn't find module");
			// If we are already traversing this use then
			// we have a cyclic dependency. Bail
			if (in_stack(info, &traversing))
				abort_print("Cyclic dependency found");
			// If this use is unvisited, push it to the stack
			// and traverse it
			if (!in_stack(info, visited)) {
				traversing.elems[traversing.next++] = info;
				break;
			}
		}
		// No unvisited uses? Pop this module, mark it as visited and
		// retry the module we were traversing before
		if (*use == NULL) {
			// The top entry lives at next - 1, so decrement before
			// clearing. Clearing elems[next] would touch the slot
			// past the top, which is out of bounds when full.
			traversing.elems[--traversing.next] = NULL;
			visited->elems[visited->next++] = cur;
		}
	} while (traversing.next != 0);

	// The traversal stack is now empty, all done!
}

// Section 2: Runtime module creation

// A fixed-capacity stack of Objects
// elems[0] .. elems[next - 1] hold the pushed entries; next is both the
// element count and the index of the next free slot
struct object_stack {
	Object elems[STACK_SIZE];
	int next;
};

// A mapping between ModuleInfo and Object
// Each stack element should have a corresponding element
// in the other stack at the same index
// objects may lag behind infos: infos entries at index objects.next and
// above have been traversed but not yet created (create_new_modules
// fills them in)
struct module_mapping {
	struct module_stack infos;
	struct object_stack objects;
};

// Finds a module object by its info using a module_mapping
Object get_module_by_info(
	struct module_mapping *mapping, const ModuleInfo *info) {
	struct module_stack *infos = &mapping->infos;
	for (int i = 0; i < infos->next; ++i) {
		const ModuleInfo *elem = infos->elems[i];
		if (elem == info)
			return mapping->objects.elems[i];
	}
	abort_print("Couldn't find module by info");
}

// Creates a module and returns it; storing it in the mapping is left to the
// caller (create_new_modules)
// Assumes all dependencies already exist in the mapping
// Each name in info->uses is resolved to its module Object and passed to
// info->create in a list, in the same order as info->uses
// Aborts if a use names an unknown module
Object create_module(
	VmState state, struct module_mapping *mapping, const ModuleInfo *info) {
	// Count the uses so the list can be created with the right length
	int use_count = 0;
	for (const char **use = info->uses; *use != NULL; ++use)
		++use_count;
	Object use_modules = object_list_create(state, use_count);
	int use_modules_next = 0;
	for (const char **use = info->uses; *use != NULL; ++use) {
		const ModuleInfo *use_info = info_by_name(*use);
		// Check the looked-up dependency, not the module being created
		if (use_info == NULL)
			abort_print("Couldn't find module");
		Object use_module = get_module_by_info(mapping, use_info);
		// Create a new reference for the module before storing it in
		// the list (presumably the list owns what it is given; confirm
		// against object_list_set)
		object_hold(state, use_module);
		object_list_set(
			state, use_modules, use_modules_next++, use_module);
	}
	return info->create(state, use_modules);
}

// Creates Objects for every ModuleInfo in the mapping that lacks one yet,
// i.e. infos entries from index objects.next onwards, in stack order
// infos is expected to be in load order (as traverse_module_uses produces),
// so each module's dependencies are created before the module itself
void create_new_modules(VmState state, struct module_mapping *mapping) {
	const struct module_stack *infos = &mapping->infos;
	struct object_stack *objects = &mapping->objects;
	int idx = objects->next;
	while (idx < infos->next) {
		const ModuleInfo *pending = infos->elems[idx];
		Object created = create_module(state, mapping, pending);
		objects->elems[objects->next++] = created;
		++idx;
	}
}

// Global mapping used for finding and freeing modules
// It accumulates every module loaded through module_find until
// modules_free drops them and zeroes it
static struct module_mapping mapping = {0};

// Finds a module by name, loading it and every module it uses if needed
// Loaded modules are kept in the global mapping, so each is created once
// Returns a new reference to the module for the caller
Object module_find(VmState state, const char *name) {
	const ModuleInfo *info = info_by_name(name);
	vm_abort_if(state, !info, "Unable to find module!");
	// Every known module may end up in the mapping, so both stacks must be
	// able to hold all of them. Compare element counts, not byte sizes:
	// the objects stack holds Objects rather than ModuleInfo pointers, so
	// a byte comparison is only right if the two types happen to have the
	// same size.
	if (ARRAY_SIZE(modules_list) > ARRAY_SIZE(mapping.infos.elems))
		abort_print("Modules too big to track");
	if (ARRAY_SIZE(modules_list) > ARRAY_SIZE(mapping.objects.elems))
		abort_print("Modules too big to load");
	// Traverse modules then create them in the correct order
	traverse_module_uses(info, &mapping.infos);
	create_new_modules(state, &mapping);
	Object module = get_module_by_info(&mapping, info);
	// Create a new reference for the caller
	object_hold(state, module);
	return module;
}

// Drops the mapping's reference to every module that has been created,
// then zeroes the global mapping so modules are loaded afresh next time
void modules_free(VmState state) {
	struct object_stack *loaded = &mapping.objects;
	for (int idx = 0; idx < loaded->next; idx++)
		object_drop(state, loaded->elems + idx);
	memset(&mapping, 0, sizeof mapping);
}