Newer
Older
Tardis / lang / rational.c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024 John Watts and the LuminaSensum contributors

#include "rational.h"
#include "boolean.h"
#include "object.h"
#include "vm.h"
#include <stddef.h>

/* Private payload stored inside every rational object: one GMP rational. */
struct rational {
	mpq_t value;
};

/* Tentative definition of the class descriptor so the functions below can
 * refer to it; its initializer appears at the end of this file, after the
 * handlers it points at. */
static struct object_class ratio_class;

/*
 * Allocate a new rational object whose value is a copy of `value`.
 *
 * The object gets its own mpq_t (initialised here and released by
 * rational_cleanup); `value` is only read, so the caller still owns it and
 * must clear it itself.
 */
Object rational_create_mpq(VmState state, mpq_t value) {
	Object created =
		object_create(state, &ratio_class, sizeof(struct rational));
	struct rational *payload =
		(struct rational *)object_priv(state, created, &ratio_class);
	mpq_t *slot = &payload->value;
	mpq_init(*slot);
	mpq_set(*slot, value);
	return created;
}

/*
 * Create a rational object with the value p/q.
 *
 * The result is canonicalised: common factors are removed and the sign is
 * carried by the numerator, so rational_create(s, 2, -4) yields -1/2.
 * A zero denominator is rejected through vm_abort_if.
 */
Object rational_create(VmState state, int p, int q) {
	vm_abort_if(state, q == 0,
		"rational_create called with zero denominator");
	mpq_t value;
	mpq_init(value);
	/* Set numerator and denominator as signed integers. mpq_set_ui would
	 * convert a negative int to a huge unsigned long (-1 -> ULONG_MAX), and
	 * mpq_set_si only accepts an unsigned denominator. mpq_canonicalize
	 * then reduces the fraction and makes the denominator positive. */
	mpz_set_si(mpq_numref(value), p);
	mpz_set_si(mpq_denref(value), q);
	mpq_canonicalize(value);
	Object ret = rational_create_mpq(state, value);
	mpq_clear(value);
	return ret;
}

/* Class cleanup hook: frees the GMP storage held by a rational object's
 * payload (the counterpart of the mpq_init in rational_create_mpq). */
static void rational_cleanup(VmState state, Object obj) {
	struct rational *payload =
		(struct rational *)object_priv(state, obj, &ratio_class);
	mpq_t *slot = &payload->value;
	mpq_clear(*slot);
}

/*
 * Convert a rational object to a C int, truncating toward zero
 * (7/2 -> 3, -7/2 -> -3). Integer-valued rationals convert exactly.
 *
 * NOTE(review): values outside the range of long/int are not checked.
 * mpz_get_si keeps only the low bits and the int cast narrows further, so
 * an out-of-range value yields an unspecified small number.
 */
int rational_integer(VmState state, Object obj) {
	struct rational *ratio =
		(struct rational *)object_priv(state, obj, &ratio_class);
	mpz_t out;
	mpz_init(out);
	/* Divide instead of reading the numerator alone: for a non-integer such
	 * as 7/2 the numerator (7) is not the integer part (3). For a canonical
	 * integer the denominator is 1, so this equals the numerator. */
	mpz_tdiv_q(out, mpq_numref(ratio->value), mpq_denref(ratio->value));
	int out_int = (int)mpz_get_si(out);
	mpz_clear(out);
	return out_int;
}

/*
 * Return a pointer to the mpq_t stored inside a rational object.
 *
 * The pointer refers to the object's own storage, not a copy, so it is only
 * meaningful while the object is alive.
 */
mpq_t *rational_priv_mpq(VmState state, Object obj) {
	struct rational *payload =
		(struct rational *)object_priv(state, obj, &ratio_class);
	mpq_t *inner = &payload->value;
	return inner;
}

/* Operation selectors passed as the `priv` value of the calls table below;
 * rational_math dispatches on them. */
#define OP_ADD 1
#define OP_SUBTRACT 2
#define OP_DIVIDE 3
#define OP_MULTIPLY 4

/*
 * Shared handler for the Add/Subtract/Divide/Multiply calls. `priv` selects
 * the operation and must be one of the OP_* values.
 *
 * Expects exactly two stack entries. Slot 1 holds the right-hand operand,
 * which must also be a rational; `obj` is the left-hand operand. Stores a
 * new rational holding `obj <op> arg1` in stack slot 0, drops one stack
 * entry and calls object_drop on the operand.
 *
 * Division by a zero rational and an unknown `priv` are reported through
 * vm_abort_msg. In those cases the result would be 0 if vm_abort_msg ever
 * returns (NOTE(review): confirm whether it does).
 */
static void rational_math(VmState state, Object obj, int priv) {
	int arg_count = vm_stack_depth(state);
	vm_abort_if(state, arg_count != 2,
		"rational_math called without 2 arguments");
	Object arg1 = vm_stack_get(state, 1);
	mpq_t *ratio1 = rational_priv_mpq(state, obj);
	mpq_t *ratio2 = rational_priv_mpq(state, arg1);
	mpq_t result;
	mpq_init(result);
	switch (priv) {
	case OP_ADD:
		mpq_add(result, *ratio1, *ratio2);
		break;
	case OP_SUBTRACT:
		mpq_sub(result, *ratio1, *ratio2);
		break;
	case OP_DIVIDE:
		/* GMP handles a zero divisor by deliberately performing a
		 * hardware division by zero, which kills the whole process.
		 * Report it through the VM instead. */
		if (mpq_sgn(*ratio2) == 0) {
			vm_abort_msg(state,
				"rational_math called to divide by zero");
		} else {
			mpq_div(result, *ratio1, *ratio2);
		}
		break;
	case OP_MULTIPLY:
		mpq_mul(result, *ratio1, *ratio2);
		break;
	default:
		vm_abort_msg(state, "rational_math called with invalid priv");
		break;
	}
	vm_stack_set(state, 0, rational_create_mpq(state, result));
	vm_stack_drop(state, 1);
	object_drop(state, &arg1);
	mpq_clear(result);
}

/*
 * Handler for the "Equals" call. `priv` is unused.
 *
 * Expects exactly two stack entries. Slot 1 holds the other operand, which
 * must also be a rational. Stores a boolean in stack slot 0 that is true
 * when `obj` and the operand have the same value. It then drops one stack
 * entry and calls object_drop on the operand.
 *
 * mpq_equal assumes both values are in canonical form. Values built by
 * rational_create and by GMP arithmetic on canonical inputs are canonical.
 * NOTE(review): values passed to rational_create_mpq by other callers are
 * assumed to be canonical too.
 */
static void rational_equals(VmState state, Object obj, int priv) {
	(void)priv;
	int arg_count = vm_stack_depth(state);
	vm_abort_if(state, arg_count != 2,
		"rational_equals called without 2 arguments");
	Object arg1 = vm_stack_get(state, 1);
	mpq_t *ratio1 = rational_priv_mpq(state, obj);
	mpq_t *ratio2 = rational_priv_mpq(state, arg1);
	/* Compare full values, not just numerators: comparing numerators would
	 * make 1/2 equal to 1/3. */
	bool equals = mpq_equal(*ratio1, *ratio2) != 0;
	vm_stack_set(state, 0, boolean_create(state, equals));
	vm_stack_drop(state, 1);
	object_drop(state, &arg1);
}

/* Method table for rational objects. Each entry maps a call name to its
 * handler and the `priv` value that handler receives; the entry with a
 * NULL name terminates the table. */
static struct object_call calls[] = {
	{.name = "Add", .priv = OP_ADD, .handler = rational_math},
	{.name = "Subtract", .priv = OP_SUBTRACT, .handler = rational_math},
	{.name = "Divide", .priv = OP_DIVIDE, .handler = rational_math},
	{.name = "Multiply", .priv = OP_MULTIPLY, .handler = rational_math},
	{.name = "Equals", .priv = 0, .handler = rational_equals},
	{.name = NULL}, /* terminator */
};

/* Class descriptor shared by all rational objects: the cleanup hook that
 * frees the GMP payload, plus the method table above. */
static struct object_class ratio_class = {
	.calls = calls,
	.cleanup = rational_cleanup,
};