Newer
Older
barebox / fs / tftp.c
@Uwe Kleine-König Uwe Kleine-König on 3 Mar 2017 13 KB tftp: implement forward seeking
/*
 * tftp.c
 *
 * Copyright (c) 2011 Sascha Hauer <s.hauer@pengutronix.de>, Pengutronix
 *
 * See file CREDITS for list of people who contributed to this
 * project.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2
 * as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 */
#include <common.h>
#include <command.h>
#include <net.h>
#include <driver.h>
#include <clock.h>
#include <fs.h>
#include <errno.h>
#include <libgen.h>
#include <fcntl.h>
#include <getopt.h>
#include <fs.h>
#include <init.h>
#include <linux/stat.h>
#include <linux/err.h>
#include <kfifo.h>
#include <linux/sizes.h>

#define TFTP_PORT	69	/* Well known TFTP port number */

/*
 * Seconds to wait before remote server is allowed to resend a lost packet.
 * This value is also sent to the server as the "timeout" option of the
 * RRQ/WRQ (see tftp_send()).
 */
#define TIMEOUT		5

/* After this time without a response from the server we will resend a packet */
#define TFTP_RESEND_TIMEOUT	SECOND

/* After this time without progress we will bail out */
#define TFTP_TIMEOUT		((TIMEOUT * 3) * SECOND)

/*
 *	TFTP operations.
 */
#define TFTP_RRQ	1
#define TFTP_WRQ	2
#define TFTP_DATA	3
#define TFTP_ACK	4
#define TFTP_ERROR	5
#define TFTP_OACK	6

/*
 * Transfer states kept in file_priv::state:
 *  STATE_RRQ     - read request sent, waiting for OACK or the first DATA
 *  STATE_WRQ     - write request sent, waiting for OACK or ACK 0
 *  STATE_RDATA   - download running, DATA blocks are being received
 *  STATE_WDATA   - upload: the next DATA block may be sent
 *  STATE_OACK    - download: OACK received and ACK 0 sent
 *  STATE_WAITACK - upload: a DATA block was sent, waiting for its ACK
 *  STATE_LAST    - upload: the final (short) block was sent
 *  STATE_DONE    - transfer finished or aborted, result in file_priv::err
 */
#define STATE_RRQ	1
#define STATE_WRQ	2
#define STATE_RDATA	3
#define STATE_WDATA	4
#define STATE_OACK	5
#define STATE_WAITACK	6
#define STATE_LAST	7
#define STATE_DONE	8

#define TFTP_BLOCK_SIZE		512	/* default TFTP block size */
/* capacity of the kfifo that buffers data between network and read/write */
#define TFTP_FIFO_SIZE		4096

/* positive tftp_poll() return value: the caller should resend its packet */
#define TFTP_ERR_RESEND	1

/* Per-open-file state of one TFTP transfer (upload or download). */
struct file_priv {
	/* UDP connection to the server; destination port follows the server TID */
	struct net_connection *tftp_con;
	/* nonzero when the file was opened for writing (upload) */
	int push;
	/*
	 * Upload: number of the DATA block being (re)sent.
	 * Download: number of the last received block; this is what
	 * tftp_send() acknowledges.
	 */
	uint16_t block;
	/*
	 * Upload: number of the last sent block, i.e. the ACK we wait for.
	 * Download: number of the last accepted block (duplicate detection).
	 */
	uint16_t last_block;
	/* one of the STATE_* values */
	int state;
	/* transfer result; set to -EINVAL on open, 0 after a clean download */
	int err;
	/* remote path with its first character (presumably '/') skipped */
	const char *filename;
	/* "tsize" reported in the OACK, 0 if the server did not report one */
	int filesize;
	/* get_time_ns() timestamp of the last (re)send, drives resends */
	uint64_t resend_timeout;
	/* get_time_ns() timestamp of the last progress, drives the abort */
	uint64_t progress_timeout;
	/* buffers payload between the network and tftp_read()/tftp_write() */
	struct kfifo *fifo;
	/* bounce buffer of blocksize bytes for outgoing DATA blocks */
	void *buf;
	/* negotiated block size, TFTP_BLOCK_SIZE unless changed by OACK */
	int blocksize;
	/* block number last ACKed by tftp_send(); -1 forces the next ACK out */
	int block_requested;
};

/* Per-mount state: the server address resolved from the backing store. */
struct tftp_priv {
	IPaddr_t server;
};

/*
 * Creating a file needs no server round trip: the file comes into
 * existence when it is opened for writing, so report success.
 */
static int tftp_create(struct device_d *dev, const char *pathname, mode_t mode)
{
	(void)dev;
	(void)pathname;
	(void)mode;

	return 0;
}

/* TFTP has no request to delete a remote file. */
static int tftp_unlink(struct device_d *dev, const char *pathname)
{
	(void)dev;
	(void)pathname;

	return -ENOSYS;
}

/* TFTP has no notion of directories, so they cannot be created. */
static int tftp_mkdir(struct device_d *dev, const char *pathname)
{
	(void)dev;
	(void)pathname;

	return -ENOSYS;
}

/* TFTP has no notion of directories, so they cannot be removed. */
static int tftp_rmdir(struct device_d *dev, const char *pathname)
{
	(void)dev;
	(void)pathname;

	return -ENOSYS;
}

/*
 * Writing over TFTP always replaces the remote file as a whole, so there
 * is nothing to truncate; accept the request.
 */
static int tftp_truncate(struct device_d *dev, FILE *f, ulong size)
{
	(void)dev;
	(void)f;
	(void)size;

	return 0;
}

/*
 * Send the request/acknowledge packet that belongs to the current state.
 *
 * STATE_RRQ/STATE_WRQ: the read or write request including the "timeout",
 *                      "tsize" and "blksize" options.
 * STATE_RDATA:         ACK for priv->block, unless that block was already
 *                      acknowledged (priv->block_requested).
 * STATE_OACK:          ACK for priv->block (block 0 after an OACK).
 *
 * In any other state there is nothing to send and 0 is returned.
 * Otherwise the result of net_udp_send() is returned.
 */
static int tftp_send(struct file_priv *priv)
{
	unsigned char *xp;
	int len = 0;
	uint16_t *s;
	unsigned char *pkt = net_udp_get_payload(priv->tftp_con);
	int ret;

	debug("%s: state %d\n", __func__, priv->state);

	switch (priv->state) {
	case STATE_RRQ:
	case STATE_WRQ:
		xp = pkt;
		s = (uint16_t *)pkt;
		if (priv->state == STATE_RRQ)
			*s++ = htons(TFTP_RRQ);
		else
			*s++ = htons(TFTP_WRQ);
		pkt = (unsigned char *)s;
		/*
		 * The embedded %c/0 pairs produce the NUL separators of the
		 * option list; sprintf() counts them in its return value.
		 * NOTE(review): the filename length is not checked against the
		 * size of the buffer from net_udp_get_payload().
		 */
		pkt += sprintf((unsigned char *)pkt,
				"%s%c"
				"octet%c"
				"timeout%c"
				"%d%c"
				"tsize%c"
				"%d%c"
				"blksize%c"
				"1432",
				priv->filename, 0,
				0,
				0,
				TIMEOUT, 0,
				0,
				priv->filesize, 0,
				0);
		/* include the terminating NUL that sprintf() did not count */
		pkt++;
		len = pkt - xp;
		break;

	case STATE_RDATA:
		/* this block has been acknowledged already */
		if (priv->block == priv->block_requested)
			return 0;
		/* fallthrough */
	case STATE_OACK:
		xp = pkt;
		s = (uint16_t *)pkt;
		*s++ = htons(TFTP_ACK);
		*s++ = htons(priv->block);
		priv->block_requested = priv->block;
		pkt = (unsigned char *)s;
		len = pkt - xp;
		break;

	default:
		/*
		 * Nothing to send in this state; without this the empty
		 * payload (len == 0) would be sent as a zero-length datagram.
		 */
		return 0;
	}

	ret = net_udp_send(priv->tftp_con, len);

	return ret;
}

/*
 * Send @len bytes from @buf as DATA block priv->block and remember that
 * block as the one whose ACK we wait for.
 *
 * A block shorter than priv->blocksize (possibly empty) terminates the
 * transfer, so the state becomes STATE_LAST. Otherwise it becomes
 * STATE_WAITACK. The state is decided once, after sending, so that a
 * resend of the final block keeps STATE_LAST instead of silently turning
 * it into STATE_WAITACK (which made the final ACK not finish the upload).
 *
 * Returns the result of net_udp_send().
 */
static int tftp_send_write(struct file_priv *priv, void *buf, int len)
{
	uint16_t *s;
	unsigned char *pkt = net_udp_get_payload(priv->tftp_con);
	int ret;
	int last = len < priv->blocksize;

	s = (uint16_t *)pkt;
	*s++ = htons(TFTP_DATA);
	*s++ = htons(priv->block);
	memcpy((void *)s, buf, len);
	len += 4;	/* opcode + block number */

	ret = net_udp_send(priv->tftp_con, len);
	priv->last_block = priv->block;
	priv->state = last ? STATE_LAST : STATE_WAITACK;

	return ret;
}

/* End the transfer with error @err, record it in @priv and return it. */
static int tftp_poll_abort(struct file_priv *priv, int err)
{
	priv->state = STATE_DONE;
	priv->err = err;

	return err;
}

/*
 * One iteration of the transfer loop.
 *
 * Returns -EINTR on ^C or -ETIMEDOUT after TFTP_TIMEOUT without progress
 * (both end the transfer with STATE_DONE). Returns TFTP_ERR_RESEND when
 * TFTP_RESEND_TIMEOUT elapsed and the caller should resend its last
 * packet. Otherwise it processes incoming packets via net_poll() and
 * returns 0.
 */
static int tftp_poll(struct file_priv *priv)
{
	if (ctrlc())
		return tftp_poll_abort(priv, -EINTR);

	if (is_timeout(priv->resend_timeout, TFTP_RESEND_TIMEOUT)) {
		printf("T ");
		priv->resend_timeout = get_time_ns();
		/* forget the last ACK so that tftp_send() sends it again */
		priv->block_requested = -1;
		return TFTP_ERR_RESEND;
	}

	if (is_timeout(priv->progress_timeout, TFTP_TIMEOUT))
		return tftp_poll_abort(priv, -ETIMEDOUT);

	net_poll();

	return 0;
}

static void tftp_parse_oack(struct file_priv *priv, unsigned char *pkt, int len)
{
	unsigned char *opt, *val, *s;

	pkt[len - 1] = 0;

	debug("got OACK\n");
#ifdef DEBUG
	memory_display(pkt, 0, len, 1, 0);
#endif

	s = pkt;

	while (s < pkt + len) {
		opt = s;
		val = s + strlen(s) + 1;
		if (val > s + len)
			return;
		if (!strcmp(opt, "tsize"))
			priv->filesize = simple_strtoul(val, NULL, 10);
		if (!strcmp(opt, "blksize"))
			priv->blocksize = simple_strtoul(val, NULL, 10);
		debug("OACK opt: %s val: %s\n", opt, val);
		s = val + strlen(val) + 1;
	}
}

/* Restart both the resend timer and the progress timer from "now". */
static void tftp_timer_reset(struct file_priv *priv)
{
	uint64_t now = get_time_ns();

	priv->resend_timeout = now;
	priv->progress_timeout = now;
}

static void tftp_recv(struct file_priv *priv,
			uint8_t *pkt, unsigned len, uint16_t uh_sport)
{
	uint16_t opcode;

	/* according to RFC1350 minimal tftp packet length is 4 bytes */
	if (len < 4)
		return;

	opcode = ntohs(*(uint16_t *)pkt);

	/* skip tftp opcode 2-byte field */
	len -= 2;
	pkt += 2;

	debug("%s: opcode 0x%04x\n", __func__, opcode);

	switch (opcode) {
	case TFTP_RRQ:
	case TFTP_WRQ:
	default:
		break;
	case TFTP_ACK:
		if (!priv->push)
			break;

		priv->block = ntohs(*(uint16_t *)pkt);
		if (priv->block != priv->last_block) {
			debug("ack %d != %d\n", priv->block, priv->last_block);
			break;
		}

		priv->block++;

		tftp_timer_reset(priv);

		if (priv->state == STATE_LAST) {
			priv->state = STATE_DONE;
			break;
		}
		priv->tftp_con->udp->uh_dport = uh_sport;
		priv->state = STATE_WDATA;
		break;

	case TFTP_OACK:
		tftp_parse_oack(priv, pkt, len);
		priv->tftp_con->udp->uh_dport = uh_sport;

		if (priv->push) {
			/* send first block */
			priv->state = STATE_WDATA;
			priv->block = 1;
		} else {
			/* send ACK */
			priv->state = STATE_OACK;
			priv->block = 0;
			tftp_send(priv);
		}

		break;
	case TFTP_DATA:
		len -= 2;
		priv->block = ntohs(*(uint16_t *)pkt);

		if (priv->state == STATE_RRQ || priv->state == STATE_OACK) {
			/* first block received */
			priv->state = STATE_RDATA;
			priv->tftp_con->udp->uh_dport = uh_sport;
			priv->last_block = 0;

			if (priv->block != 1) {	/* Assertion */
				printf("error: First block is not block 1 (%d)\n",
					priv->block);
				priv->err = -EINVAL;
				priv->state = STATE_DONE;
				break;
			}
		}

		if (priv->block == priv->last_block)
			/* Same block again; ignore it. */
			break;

		priv->last_block = priv->block;

		tftp_timer_reset(priv);

		kfifo_put(priv->fifo, pkt + 2, len);

		if (len < priv->blocksize) {
			tftp_send(priv);
			priv->err = 0;
			priv->state = STATE_DONE;
		}

		break;

	case TFTP_ERROR:
		debug("\nTFTP error: '%s' (%d)\n",
				pkt + 2, ntohs(*(uint16_t *)pkt));
		switch (ntohs(*(uint16_t *)pkt)) {
		case 1:
			priv->err = -ENOENT;
			break;
		case 2:
			priv->err = -EACCES;
			break;
		default:
			priv->err = -EINVAL;
			break;
		}
		priv->state = STATE_DONE;
		break;
	}
}

/*
 * UDP receive callback registered with net_udp_new(). It extracts the
 * UDP payload, its length and the sender's port from the received frame
 * and passes them to tftp_recv(). The @len argument is not used.
 */
static void tftp_handler(void *ctx, char *packet, unsigned len)
{
	struct file_priv *priv = ctx;
	char *payload;
	struct udphdr *udp;

	(void)len;

	payload = net_eth_to_udp_payload(packet);
	udp = net_eth_to_udphdr(packet);

	tftp_recv(priv, payload, net_eth_to_udplen(packet), udp->uh_sport);
}

static struct file_priv *tftp_do_open(struct device_d *dev,
		int accmode, const char *filename)
{
	struct file_priv *priv;
	struct tftp_priv *tpriv = dev->priv;
	int ret;

	priv = xzalloc(sizeof(*priv));

	filename++;

	switch (accmode & O_ACCMODE) {
	case O_RDONLY:
		priv->push = 0;
		priv->state = STATE_RRQ;
		break;
	case O_WRONLY:
		priv->push = 1;
		priv->state = STATE_WRQ;
		if (!(accmode & O_TRUNC)) {
			/*
			 * TFTP always truncates the existing file, so this
			 * flag is mandatory when opening a file for writing.
			 */
			ret = -ENOSYS;
			goto out;
		}
		break;
	case O_RDWR:
		ret = -ENOSYS;
		goto out;
	}

	priv->block = 1;
	priv->err = -EINVAL;
	priv->filename = filename;
	priv->blocksize = TFTP_BLOCK_SIZE;
	priv->block_requested = -1;

	priv->fifo = kfifo_alloc(TFTP_FIFO_SIZE);
	if (!priv->fifo) {
		ret = -ENOMEM;
		goto out;
	}

	priv->tftp_con = net_udp_new(tpriv->server, TFTP_PORT, tftp_handler,
			priv);
	if (IS_ERR(priv->tftp_con)) {
		ret = PTR_ERR(priv->tftp_con);
		goto out1;
	}

	ret = tftp_send(priv);
	if (ret)
		goto out2;

	tftp_timer_reset(priv);
	while (priv->state != STATE_RDATA &&
			priv->state != STATE_DONE &&
			priv->state != STATE_WDATA) {
		ret = tftp_poll(priv);
		if (ret == TFTP_ERR_RESEND)
			tftp_send(priv);
		if (ret < 0)
			goto out2;
	}

	if (priv->state == STATE_DONE && priv->err) {
		ret = priv->err;
		goto out2;
	}

	priv->buf = xmalloc(priv->blocksize);

	return priv;
out2:
	net_unregister(priv->tftp_con);
out1:
	kfifo_free(priv->fifo);
out:
	free(priv);

	return ERR_PTR(ret);
}

/*
 * fs open hook: start the transfer and attach its context to @file.
 * The file size is not taken from the server here; SZ_2G is reported
 * as the size of every opened file.
 */
static int tftp_open(struct device_d *dev, FILE *file, const char *filename)
{
	struct file_priv *priv = tftp_do_open(dev, file->flags, filename);

	if (IS_ERR(priv))
		return PTR_ERR(priv);

	file->priv = priv;
	file->size = SZ_2G;
	return 0;
}

/*
 * Finish and tear down a transfer.
 *
 * Upload: whatever is still buffered (at most one block, possibly nothing)
 * is sent as the final DATA block, and we wait for its ACK.
 * Unfinished download: the server is told with an ERROR packet.
 * In all cases the connection, fifo and buffers are released.
 *
 * Returns 0 on success. For uploads, a failed transfer (timeout, ^C, an
 * ERROR packet from the server, or an earlier failed write) is returned,
 * so that a lost final block does not go unnoticed.
 */
static int tftp_do_close(struct file_priv *priv)
{
	int err = 0;

	if (priv->push && priv->state != STATE_DONE) {
		int len;
		int ret;

		len = kfifo_get(priv->fifo, priv->buf, priv->blocksize);

		/* from here on priv->err reflects the final exchange only */
		priv->err = 0;

		tftp_send_write(priv, priv->buf, len);
		priv->state = STATE_LAST;

		tftp_timer_reset(priv);

		while (priv->state != STATE_DONE) {
			ret = tftp_poll(priv);
			if (ret == TFTP_ERR_RESEND) {
				tftp_send_write(priv, priv->buf, len);
				/*
				 * Still waiting for the ACK of the last block:
				 * keep STATE_LAST regardless of what
				 * tftp_send_write() left in priv->state, or
				 * the ACK would not complete the transfer.
				 */
				priv->state = STATE_LAST;
			}
			if (ret < 0)
				break;
		}
	}

	if (priv->push)
		err = priv->err;

	if (!priv->push && priv->state != STATE_DONE) {
		/* abort the download: ERROR, code 0, empty message */
		uint16_t *pkt = net_udp_get_payload(priv->tftp_con);
		*pkt++ = htons(TFTP_ERROR);
		*pkt++ = 0;
		*pkt++ = 0;
		net_udp_send(priv->tftp_con, 6);
	}

	net_unregister(priv->tftp_con);
	kfifo_free(priv->fifo);
	free(priv->buf);
	free(priv);

	return err;
}

/* fs close hook: finish the transfer attached to @f and free its context. */
static int tftp_close(struct device_d *dev, FILE *f)
{
	(void)dev;

	return tftp_do_close(f->priv);
}

static int tftp_write(struct device_d *_dev, FILE *f, const void *inbuf,
		size_t insize)
{
	struct file_priv *priv = f->priv;
	size_t size, now;
	int ret;

	debug("%s: %zu\n", __func__, insize);

	size = insize;

	while (size) {
		now = kfifo_put(priv->fifo, inbuf, size);

		while (kfifo_len(priv->fifo) >= priv->blocksize) {
			kfifo_get(priv->fifo, priv->buf, priv->blocksize);

			tftp_send_write(priv, priv->buf, priv->blocksize);
			tftp_timer_reset(priv);

			while (priv->state == STATE_WAITACK) {
				ret = tftp_poll(priv);
				if (ret == TFTP_ERR_RESEND)
					tftp_send_write(priv, priv->buf,
							priv->blocksize);
				if (ret < 0)
					return ret;
			}
		}
		size -= now;
		inbuf += now;
	}

	return insize;
}

/*
 * fs read hook: copy up to @insize bytes of downloaded data into @buf,
 * acknowledging further blocks whenever the fifo has room for one.
 *
 * Returns the number of bytes copied. Once the transfer has ended and
 * the fifo is drained, it returns 0 after a clean end of file, or the
 * recorded error (e.g. an ERROR packet from the server). If polling
 * fails after some bytes were copied, those bytes are returned first and
 * the error is reported by the next call.
 */
static int tftp_read(struct device_d *dev, FILE *f, void *buf, size_t insize)
{
	struct file_priv *priv = f->priv;
	size_t outsize = 0, now;
	int ret;

	debug("%s %zu\n", __func__, insize);

	while (insize) {
		now = kfifo_get(priv->fifo, buf, insize);
		outsize += now;
		buf += now;
		insize -= now;
		if (priv->state == STATE_DONE) {
			if (outsize)
				return outsize;
			/* 0 after a clean last block, error otherwise */
			return priv->err;
		}

		/* only ACK the next block if the fifo can take all of it */
		if (TFTP_FIFO_SIZE - kfifo_len(priv->fifo) >= priv->blocksize)
			tftp_send(priv);

		ret = tftp_poll(priv);
		if (ret == TFTP_ERR_RESEND)
			tftp_send(priv);
		if (ret < 0)
			return outsize ? outsize : ret;
	}

	return outsize;
}

/*
 * fs lseek hook. Only forward seeks are possible: the data in between is
 * read and discarded in chunks of up to 1024 bytes. Returns @pos, -ENOSYS
 * for a backward seek, -EINVAL if end of file is hit before @pos, or the
 * error from tftp_read().
 */
static loff_t tftp_lseek(struct device_d *dev, FILE *f, loff_t pos)
{
	loff_t result = pos;
	char *scratch;

	/* We cannot seek backwards without reloading or caching the file */
	if (pos < f->pos)
		return -ENOSYS;

	scratch = xmalloc(1024);

	while (f->pos < pos) {
		size_t chunk = min_t(size_t, 1024, pos - f->pos);
		loff_t got = tftp_read(dev, f, scratch, chunk);

		if (got == 0) {
			/* EOF, so the desired pos is invalid. */
			result = -EINVAL;
			break;
		}
		if (got < 0) {
			result = got;
			break;
		}

		f->pos += got;
	}

	free(scratch);
	return result;
}

/* Directory listings cannot be requested over TFTP. */
static DIR *tftp_opendir(struct device_d *dev, const char *pathname)
{
	(void)dev;
	(void)pathname;

	return NULL;
}

/*
 * fs stat hook: start a read transfer to learn whether the file exists and,
 * if the server reports "tsize", its size. The transfer is then closed.
 * Without a reported size, FILESIZE_MAX is returned.
 */
static int tftp_stat(struct device_d *dev, const char *filename, struct stat *s)
{
	struct file_priv *priv = tftp_do_open(dev, O_RDONLY, filename);

	if (IS_ERR(priv))
		return PTR_ERR(priv);

	s->st_mode = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO;

	if (!priv->filesize)
		s->st_size = FILESIZE_MAX;
	else
		s->st_size = priv->filesize;

	tftp_do_close(priv);
	return 0;
}

/* Mount hook: resolve the server named by the backing store once. */
static int tftp_probe(struct device_d *dev)
{
	struct fs_device_d *fsdev = dev_to_fs_device(dev);
	struct tftp_priv *tpriv;

	tpriv = xzalloc(sizeof(*tpriv));
	dev->priv = tpriv;
	tpriv->server = resolv(fsdev->backingstore);

	return 0;
}

/* Unmount hook: release the per-mount state allocated in tftp_probe(). */
static void tftp_remove(struct device_d *dev)
{
	free(dev->priv);
}

/*
 * Filesystem driver table for "tftp". Operations TFTP cannot express are
 * stubs: create and truncate succeed without doing anything, unlink, mkdir
 * and rmdir return -ENOSYS, and opendir returns NULL.
 */
static struct fs_driver_d tftp_driver = {
	.open      = tftp_open,
	.close     = tftp_close,
	.read      = tftp_read,
	.lseek     = tftp_lseek,	/* forward seeks only */
	.opendir   = tftp_opendir,
	.stat      = tftp_stat,
	.create    = tftp_create,
	.unlink    = tftp_unlink,
	.mkdir     = tftp_mkdir,
	.rmdir     = tftp_rmdir,
	.write     = tftp_write,
	.truncate  = tftp_truncate,
	.flags     = 0,	/* NOTE(review): flag semantics are defined by the fs core */
	.drv = {
		.probe  = tftp_probe,
		.remove = tftp_remove,
		.name = "tftp",
	}
};

/* Register the "tftp" filesystem driver with the fs core. */
static int tftp_init(void)
{
	int ret = register_fs_driver(&tftp_driver);

	return ret;
}
coredevice_initcall(tftp_init);