150397Sobrien// SPDX-License-Identifier: GPL-2.0+
250397Sobrien/*
350397Sobrien * Copyright (c) 2019, Linaro Limited
450397Sobrien */
550397Sobrien
650397Sobrien#include <common.h>
750397Sobrien#include <dm.h>
850397Sobrien#include <log.h>
950397Sobrien#include <rng.h>
1050397Sobrien#include <virtio_types.h>
1150397Sobrien#include <virtio.h>
1250397Sobrien#include <virtio_ring.h>
1350397Sobrien
1450397Sobrien#define BUFFER_SIZE	16UL
1550397Sobrien
1650397Sobrienstruct virtio_rng_priv {
1750397Sobrien	struct virtqueue *rng_vq;
1850397Sobrien};
1950397Sobrien
2050397Sobrienstatic int virtio_rng_read(struct udevice *dev, void *data, size_t len)
2150397Sobrien{
2250397Sobrien	int ret;
2350397Sobrien	unsigned int rsize = 1;
2450397Sobrien	unsigned char buf[BUFFER_SIZE] __aligned(4);
2550397Sobrien	unsigned char *ptr = data;
2650397Sobrien	struct virtio_sg sg;
2750397Sobrien	struct virtio_sg *sgs[1];
2850397Sobrien	struct virtio_rng_priv *priv = dev_get_priv(dev);
2950397Sobrien
3050397Sobrien	while (len) {
3150397Sobrien		sg.addr = buf;
3250397Sobrien		/*
3350397Sobrien		 * Work around implementations which always return 8 bytes
34		 * less than requested, down to 0 bytes, which would
35		 * cause an endless loop otherwise.
36		 */
37		sg.length = min(rsize ? len : len + 8, sizeof(buf));
38		sgs[0] = &sg;
39
40		ret = virtqueue_add(priv->rng_vq, sgs, 0, 1);
41		if (ret)
42			return ret;
43
44		virtqueue_kick(priv->rng_vq);
45
46		while (!virtqueue_get_buf(priv->rng_vq, &rsize))
47			;
48
49		if (rsize > sg.length)
50			return -EIO;
51
52		memcpy(ptr, buf, rsize);
53		len -= rsize;
54		ptr += rsize;
55	}
56
57	return 0;
58}
59
60static int virtio_rng_bind(struct udevice *dev)
61{
62	struct virtio_dev_priv *uc_priv = dev_get_uclass_priv(dev->parent);
63
64	/* Indicate what driver features we support */
65	virtio_driver_features_init(uc_priv, NULL, 0, NULL, 0);
66
67	return 0;
68}
69
70static int virtio_rng_probe(struct udevice *dev)
71{
72	struct virtio_rng_priv *priv = dev_get_priv(dev);
73	int ret;
74
75	ret = virtio_find_vqs(dev, 1, &priv->rng_vq);
76	if (ret < 0) {
77		debug("%s: virtio_find_vqs failed\n", __func__);
78		return ret;
79	}
80
81	return 0;
82}
83
84static const struct dm_rng_ops virtio_rng_ops = {
85	.read	= virtio_rng_read,
86};
87
88U_BOOT_DRIVER(virtio_rng) = {
89	.name	= VIRTIO_RNG_DRV_NAME,
90	.id	= UCLASS_RNG,
91	.bind	= virtio_rng_bind,
92	.probe	= virtio_rng_probe,
93	.remove = virtio_reset,
94	.ops	= &virtio_rng_ops,
95	.priv_auto	= sizeof(struct virtio_rng_priv),
96	.flags	= DM_FLAG_ACTIVE_DMA,
97};
98