summaryrefslogtreecommitdiff
path: root/drivers/iommu
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu')
-rw-r--r--drivers/iommu/arm-smmu-v3.c18
-rw-r--r--drivers/iommu/arm-smmu.c1
2 files changed, 12 insertions, 7 deletions
diff --git a/drivers/iommu/arm-smmu-v3.c b/drivers/iommu/arm-smmu-v3.c
index 2e3e235..3ea4d57 100644
--- a/drivers/iommu/arm-smmu-v3.c
+++ b/drivers/iommu/arm-smmu-v3.c
@@ -1809,13 +1809,13 @@ static int arm_smmu_add_device(struct device *dev)
smmu = arm_smmu_get_for_pci_dev(pdev);
if (!smmu) {
ret = -ENOENT;
- goto out_put_group;
+ goto out_remove_dev;
}
smmu_group = kzalloc(sizeof(*smmu_group), GFP_KERNEL);
if (!smmu_group) {
ret = -ENOMEM;
- goto out_put_group;
+ goto out_remove_dev;
}
smmu_group->ste.valid = true;
@@ -1831,20 +1831,20 @@ static int arm_smmu_add_device(struct device *dev)
for (i = 0; i < smmu_group->num_sids; ++i) {
/* If we already know about this SID, then we're done */
if (smmu_group->sids[i] == sid)
- return 0;
+ goto out_put_group;
}
/* Check the SID is in range of the SMMU and our stream table */
if (!arm_smmu_sid_in_range(smmu, sid)) {
ret = -ERANGE;
- goto out_put_group;
+ goto out_remove_dev;
}
/* Ensure l2 strtab is initialised */
if (smmu->features & ARM_SMMU_FEAT_2_LVL_STRTAB) {
ret = arm_smmu_init_l2_strtab(smmu, sid);
if (ret)
- goto out_put_group;
+ goto out_remove_dev;
}
/* Resize the SID array for the group */
@@ -1854,16 +1854,20 @@ static int arm_smmu_add_device(struct device *dev)
if (!sids) {
smmu_group->num_sids--;
ret = -ENOMEM;
- goto out_put_group;
+ goto out_remove_dev;
}
/* Add the new SID */
sids[smmu_group->num_sids - 1] = sid;
smmu_group->sids = sids;
- return 0;
out_put_group:
iommu_group_put(group);
+ return 0;
+
+out_remove_dev:
+ iommu_group_remove_device(dev);
+ iommu_group_put(group);
return ret;
}
diff --git a/drivers/iommu/arm-smmu.c b/drivers/iommu/arm-smmu.c
index 1ce4b85..6ed169b 100644
--- a/drivers/iommu/arm-smmu.c
+++ b/drivers/iommu/arm-smmu.c
@@ -1355,6 +1355,7 @@ static int arm_smmu_add_device(struct device *dev)
if (IS_ERR(group))
return PTR_ERR(group);
+ iommu_group_put(group);
return 0;
}