summaryrefslogtreecommitdiff
path: root/drivers/virtio/virtio_pci_modern.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/virtio/virtio_pci_modern.c')
-rw-r--r--drivers/virtio/virtio_pci_modern.c20
1 files changed, 16 insertions, 4 deletions
diff --git a/drivers/virtio/virtio_pci_modern.c b/drivers/virtio/virtio_pci_modern.c
index 2c1b0ebfce..6fe5e76572 100644
--- a/drivers/virtio/virtio_pci_modern.c
+++ b/drivers/virtio/virtio_pci_modern.c
@@ -397,18 +397,27 @@ static int virtio_pci_notify(struct udevice *udev, struct virtqueue *vq)
*
* @udev: the transport device
* @cfg_type: the VIRTIO_PCI_CAP_* value we seek
+ * @cap_size: expected size of the capability
*
* Return: offset of the configuration structure
*/
-static int virtio_pci_find_capability(struct udevice *udev, u8 cfg_type)
+static int virtio_pci_find_capability(struct udevice *udev, u8 cfg_type,
+ size_t cap_size)
{
int pos;
int offset;
u8 type, bar;
+ assert(cap_size >= sizeof(struct virtio_pci_cap));
+ assert(cap_size <= PCI_CFG_SPACE_SIZE);
+
for (pos = dm_pci_find_capability(udev, PCI_CAP_ID_VNDR);
pos > 0;
pos = dm_pci_find_next_capability(udev, pos, PCI_CAP_ID_VNDR)) {
+ /* Ensure the capability is within bounds */
+ if (PCI_CFG_SPACE_SIZE - cap_size < pos)
+ return 0;
+
offset = pos + offsetof(struct virtio_pci_cap, cfg_type);
dm_pci_read_config8(udev, offset, &type);
offset = pos + offsetof(struct virtio_pci_cap, bar);
@@ -496,7 +505,8 @@ static int virtio_pci_probe(struct udevice *udev)
uc_priv->vendor = subvendor;
/* Check for a common config: if not, use legacy mode (bar 0) */
- common = virtio_pci_find_capability(udev, VIRTIO_PCI_CAP_COMMON_CFG);
+ common = virtio_pci_find_capability(udev, VIRTIO_PCI_CAP_COMMON_CFG,
+ sizeof(struct virtio_pci_cap));
if (!common) {
printf("(%s): leaving for legacy driver\n", udev->name);
return -ENODEV;
@@ -510,7 +520,8 @@ static int virtio_pci_probe(struct udevice *udev)
}
/* If common is there, notify should be too */
- notify = virtio_pci_find_capability(udev, VIRTIO_PCI_CAP_NOTIFY_CFG);
+ notify = virtio_pci_find_capability(udev, VIRTIO_PCI_CAP_NOTIFY_CFG,
+ sizeof(struct virtio_pci_notify_cap));
if (!notify) {
printf("(%s): missing capabilities %i/%i\n", udev->name,
common, notify);
@@ -524,7 +535,8 @@ static int virtio_pci_probe(struct udevice *udev)
* Device capability is only mandatory for devices that have
* device-specific configuration.
*/
- device = virtio_pci_find_capability(udev, VIRTIO_PCI_CAP_DEVICE_CFG);
+ device = virtio_pci_find_capability(udev, VIRTIO_PCI_CAP_DEVICE_CFG,
+ sizeof(struct virtio_pci_cap));
if (device) {
offset = device + offsetof(struct virtio_pci_cap, length);
dm_pci_read_config32(udev, offset, &priv->device_len);