// SPDX-License-Identifier: GPL-2.0+
// Copyright 2017 IBM Corp.
#include <linux/sched/mm.h>
#include <linux/mutex.h>
#include <linux/mm.h>
#include <linux/mm_types.h>
#include <linux/mmu_context.h>
#include <linux/mmu_notifier.h>
#include <linux/irqdomain.h>
#include <asm/copro.h>
#include <asm/pnv-ocxl.h>
#include <asm/xive.h>
#include <misc/ocxl.h>
#include "ocxl_internal.h"
#include "trace.h"


#define SPA_PASID_BITS		15
#define SPA_PASID_MAX		((1 << SPA_PASID_BITS) - 1)
#define SPA_PE_MASK		SPA_PASID_MAX
#define SPA_SPA_SIZE_LOG	22 /* Each SPA is 4 MB */
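
/*
 * Size note (derived from the definitions above and the BUILD_BUG_ON in
 * ocxl_link_add_pe()): the SPA holds one 128-byte process element per
 * PASID, so 2^SPA_PASID_BITS entries * 128 bytes = 2^22 bytes = 4 MB,
 * which is where SPA_SPA_SIZE_LOG comes from.
 */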

#define SPA_CFG_SF		(1ull << (63-0))
#define SPA_CFG_TA		(1ull << (63-1))
#define SPA_CFG_HV		(1ull << (63-3))
#define SPA_CFG_UV		(1ull << (63-4))
#define SPA_CFG_XLAT_hpt	(0ull << (63-6)) /* Hashed page table (HPT) mode */
#define SPA_CFG_XLAT_roh	(2ull << (63-6)) /* Radix on HPT mode */
#define SPA_CFG_XLAT_ror	(3ull << (63-6)) /* Radix on Radix mode */
#define SPA_CFG_PR		(1ull << (63-49))
#define SPA_CFG_TC		(1ull << (63-54))
#define SPA_CFG_DR		(1ull << (63-59))

#define SPA_XSL_TF		(1ull << (63-3))  /* Translation fault */
#define SPA_XSL_S		(1ull << (63-38)) /* Store operation */
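
/*
 * Note: the SPA_CFG_* and SPA_XSL_* masks above use IBM MSB-0 bit
 * numbering, so (1ull << (63 - n)) selects bit n counting from the most
 * significant bit, i.e. the same value as PPC_BIT(n). For example,
 * SPA_XSL_TF == PPC_BIT(3) == 0x1000000000000000ull.
 */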

#define SPA_PE_VALID		0x80000000

struct ocxl_link;

struct pe_data {
	struct mm_struct *mm;
	/* callback to trigger when a translation fault occurs */
	void (*xsl_err_cb)(void *data, u64 addr, u64 dsisr);
	/* opaque pointer to be passed to the above callback */
	void *xsl_err_data;
	struct rcu_head rcu;
	struct ocxl_link *link;
	struct mmu_notifier mmu_notifier;
};

struct spa {
	struct ocxl_process_element *spa_mem;
	int spa_order;
	struct mutex spa_lock;
	struct radix_tree_root pe_tree; /* Maps PE handles to pe_data */
	char *irq_name;
	int virq;
	void __iomem *reg_dsisr;
	void __iomem *reg_dar;
	void __iomem *reg_tfc;
	void __iomem *reg_pe_handle;
	/*
	 * The following fields are used by the memory fault
	 * interrupt handler. We can only have one interrupt at a
	 * time. The NPU won't raise another interrupt until the
	 * previous one has been ack'd by writing to the TFC register
	 */
	struct xsl_fault {
		struct work_struct fault_work;
		u64 pe;
		u64 dsisr;
		u64 dar;
		struct pe_data pe_data;
	} xsl_fault;
};

/*
 * An opencapi link can be used by several PCI functions. We have
 * one link per device slot.
 *
 * A linked list of opencapi links should suffice, as there's a
 * limited number of opencapi slots on a system and lookup is only
 * done when the device is probed
 */
struct ocxl_link {
	struct list_head list;
	struct kref ref;
	int domain;
	int bus;
	int dev;
	void __iomem *arva;     /* ATSD register virtual address */
	spinlock_t atsd_lock;   /* to serialize shootdowns */
	atomic_t irq_available;
	struct spa *spa;
	void *platform_data;
};
static LIST_HEAD(links_list);
static DEFINE_MUTEX(links_list_lock);

enum xsl_response {
	CONTINUE,
	ADDRESS_ERROR,
	RESTART,
};


static void read_irq(struct spa *spa, u64 *dsisr, u64 *dar, u64 *pe)
{
	u64 reg;

	*dsisr = in_be64(spa->reg_dsisr);
	*dar = in_be64(spa->reg_dar);
	reg = in_be64(spa->reg_pe_handle);
	*pe = reg & SPA_PE_MASK;
}

static void ack_irq(struct spa *spa, enum xsl_response r)
{
	u64 reg = 0;

	/* continue is not supported */
	if (r == RESTART)
		reg = PPC_BIT(31);
	else if (r == ADDRESS_ERROR)
		reg = PPC_BIT(30);
	else
		WARN(1, "Invalid irq response %d\n", r);

	if (reg) {
		trace_ocxl_fault_ack(spa->spa_mem, spa->xsl_fault.pe,
				spa->xsl_fault.dsisr, spa->xsl_fault.dar, reg);
		out_be64(spa->reg_tfc, reg);
	}
}

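/*
 * Translation fault handling overview: xsl_fault_handler() (further
 * below) is the hardware interrupt handler. It reads the fault
 * registers, looks up the pe_data for the faulting PE under RCU, takes
 * a reference on mm_users and schedules xsl_fault_handler_bh(), the
 * work item below. The bottom half resolves the fault with
 * copro_handle_mm_fault() (plus hash_page_mm() on hash MMU) and then
 * acknowledges it by writing the TFC register via ack_irq(), which
 * allows the NPU to raise the next interrupt.
 */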
static void xsl_fault_handler_bh(struct work_struct *fault_work)
{
	vm_fault_t flt = 0;
	unsigned long access, flags, inv_flags = 0;
	enum xsl_response r;
	struct xsl_fault *fault = container_of(fault_work, struct xsl_fault,
					fault_work);
	struct spa *spa = container_of(fault, struct spa, xsl_fault);
	int rc;

	/*
	 * We must release a reference on mm_users whenever exiting this
	 * function (taken in the memory fault interrupt handler)
	 */
	rc = copro_handle_mm_fault(fault->pe_data.mm, fault->dar, fault->dsisr,
				&flt);
	if (rc) {
		pr_debug("copro_handle_mm_fault failed: %d\n", rc);
		if (fault->pe_data.xsl_err_cb) {
			fault->pe_data.xsl_err_cb(
				fault->pe_data.xsl_err_data,
				fault->dar, fault->dsisr);
		}
		r = ADDRESS_ERROR;
		goto ack;
	}

	if (!radix_enabled()) {
		/*
		 * update_mmu_cache() will not have loaded the hash
		 * since current->trap is not a 0x400 or 0x300, so
		 * just call hash_page_mm() here.
		 */
		access = _PAGE_PRESENT | _PAGE_READ;
		if (fault->dsisr & SPA_XSL_S)
			access |= _PAGE_WRITE;

		if (get_region_id(fault->dar) != USER_REGION_ID)
			access |= _PAGE_PRIVILEGED;

		local_irq_save(flags);
		hash_page_mm(fault->pe_data.mm, fault->dar, access, 0x300,
			inv_flags);
		local_irq_restore(flags);
	}
	r = RESTART;
ack:
	mmput(fault->pe_data.mm);
	ack_irq(spa, r);
}

static irqreturn_t xsl_fault_handler(int irq, void *data)
{
	struct ocxl_link *link = data;
	struct spa *spa = link->spa;
	u64 dsisr, dar, pe_handle;
	struct pe_data *pe_data;
	struct ocxl_process_element *pe;
	int pid;
	bool schedule = false;

	read_irq(spa, &dsisr, &dar, &pe_handle);
	trace_ocxl_fault(spa->spa_mem, pe_handle, dsisr, dar, -1);

	WARN_ON(pe_handle > SPA_PE_MASK);
	pe = spa->spa_mem + pe_handle;
	pid = be32_to_cpu(pe->pid);
	/* We could be reading all null values here if the PE is being
	 * removed while an interrupt kicks in. It's not supposed to
	 * happen if the driver notified the AFU to terminate the
	 * PASID, and the AFU waited for pending operations before
	 * acknowledging. But even if it happens, we won't find a
	 * memory context below and fail silently, so it should be ok.
	 */
	if (!(dsisr & SPA_XSL_TF)) {
		WARN(1, "Invalid xsl interrupt fault register %#llx\n", dsisr);
		ack_irq(spa, ADDRESS_ERROR);
		return IRQ_HANDLED;
	}

	rcu_read_lock();
	pe_data = radix_tree_lookup(&spa->pe_tree, pe_handle);
	if (!pe_data) {
		/*
		 * Could only happen if the driver didn't notify the
		 * AFU about PASID termination before removing the PE,
		 * or the AFU didn't wait for all memory access to
		 * have completed.
		 *
		 * Either way, we fail early, but we shouldn't log an
		 * error message, as it is a valid (if unexpected)
		 * scenario
		 */
		rcu_read_unlock();
		pr_debug("Unknown mm context for xsl interrupt\n");
		ack_irq(spa, ADDRESS_ERROR);
		return IRQ_HANDLED;
	}

	if (!pe_data->mm) {
		/*
		 * translation fault from a kernel context - an OpenCAPI
		 * device tried to access a bad kernel address
		 */
		rcu_read_unlock();
		pr_warn("Unresolved OpenCAPI xsl fault in kernel context\n");
		ack_irq(spa, ADDRESS_ERROR);
		return IRQ_HANDLED;
	}
	WARN_ON(pe_data->mm->context.id != pid);

	if (mmget_not_zero(pe_data->mm)) {
		spa->xsl_fault.pe = pe_handle;
		spa->xsl_fault.dar = dar;
		spa->xsl_fault.dsisr = dsisr;
		spa->xsl_fault.pe_data = *pe_data;
		schedule = true;
		/* mm_users count released by bottom half */
	}
	rcu_read_unlock();
	if (schedule)
		schedule_work(&spa->xsl_fault.fault_work);
	else
		ack_irq(spa, ADDRESS_ERROR);
	return IRQ_HANDLED;
}

static void unmap_irq_registers(struct spa *spa)
{
	pnv_ocxl_unmap_xsl_regs(spa->reg_dsisr, spa->reg_dar, spa->reg_tfc,
				spa->reg_pe_handle);
}

static int map_irq_registers(struct pci_dev *dev, struct spa *spa)
{
	return pnv_ocxl_map_xsl_regs(dev, &spa->reg_dsisr, &spa->reg_dar,
				&spa->reg_tfc, &spa->reg_pe_handle);
}

static int setup_xsl_irq(struct pci_dev *dev, struct ocxl_link *link)
{
	struct spa *spa = link->spa;
	int rc;
	int hwirq;

	rc = pnv_ocxl_get_xsl_irq(dev, &hwirq);
	if (rc)
		return rc;

	rc = map_irq_registers(dev, spa);
	if (rc)
		return rc;

	spa->irq_name = kasprintf(GFP_KERNEL, "ocxl-xsl-%x-%x-%x",
				link->domain, link->bus, link->dev);
	if (!spa->irq_name) {
		dev_err(&dev->dev, "Can't allocate name for xsl interrupt\n");
		rc = -ENOMEM;
		goto err_xsl;
	}
	/*
	 * At some point, we'll need to look into allowing a higher
	 * number of interrupts. Could we have an IRQ domain per link?
	 */
	spa->virq = irq_create_mapping(NULL, hwirq);
	if (!spa->virq) {
		dev_err(&dev->dev,
			"irq_create_mapping failed for translation interrupt\n");
		rc = -EINVAL;
		goto err_name;
	}

	dev_dbg(&dev->dev, "hwirq %d mapped to virq %d\n", hwirq, spa->virq);

	rc = request_irq(spa->virq, xsl_fault_handler, 0, spa->irq_name,
			link);
	if (rc) {
		dev_err(&dev->dev,
			"request_irq failed for translation interrupt: %d\n",
			rc);
		rc = -EINVAL;
		goto err_mapping;
	}
	return 0;

err_mapping:
	irq_dispose_mapping(spa->virq);
err_name:
	kfree(spa->irq_name);
err_xsl:
	unmap_irq_registers(spa);
	return rc;
}

static void release_xsl_irq(struct ocxl_link *link)
{
	struct spa *spa = link->spa;

	if (spa->virq) {
		free_irq(spa->virq, link);
		irq_dispose_mapping(spa->virq);
	}
	kfree(spa->irq_name);
	unmap_irq_registers(spa);
}

static int alloc_spa(struct pci_dev *dev, struct ocxl_link *link)
{
	struct spa *spa;

	spa = kzalloc(sizeof(struct spa), GFP_KERNEL);
	if (!spa)
		return -ENOMEM;

	mutex_init(&spa->spa_lock);
	INIT_RADIX_TREE(&spa->pe_tree, GFP_KERNEL);
	INIT_WORK(&spa->xsl_fault.fault_work, xsl_fault_handler_bh);

	spa->spa_order = SPA_SPA_SIZE_LOG - PAGE_SHIFT;
	spa->spa_mem = (struct ocxl_process_element *)
		__get_free_pages(GFP_KERNEL | __GFP_ZERO, spa->spa_order);
	if (!spa->spa_mem) {
		dev_err(&dev->dev, "Can't allocate Shared Process Area\n");
		kfree(spa);
		return -ENOMEM;
	}
	pr_debug("Allocated SPA for %x:%x:%x at %p\n", link->domain, link->bus,
		link->dev, spa->spa_mem);

	link->spa = spa;
	return 0;
}

static void free_spa(struct ocxl_link *link)
{
	struct spa *spa = link->spa;

	pr_debug("Freeing SPA for %x:%x:%x\n", link->domain, link->bus,
		link->dev);

	if (spa && spa->spa_mem) {
		free_pages((unsigned long) spa->spa_mem, spa->spa_order);
		kfree(spa);
		link->spa = NULL;
	}
}

static int alloc_link(struct pci_dev *dev, int PE_mask, struct ocxl_link **out_link)
{
	struct ocxl_link *link;
	int rc;

	link = kzalloc(sizeof(struct ocxl_link), GFP_KERNEL);
	if (!link)
		return -ENOMEM;

	kref_init(&link->ref);
	link->domain = pci_domain_nr(dev->bus);
	link->bus = dev->bus->number;
	link->dev = PCI_SLOT(dev->devfn);
	atomic_set(&link->irq_available, MAX_IRQ_PER_LINK);
	spin_lock_init(&link->atsd_lock);

	rc = alloc_spa(dev, link);
	if (rc)
		goto err_free;

	rc = setup_xsl_irq(dev, link);
	if (rc)
		goto err_spa;

	/* platform specific hook */
	rc = pnv_ocxl_spa_setup(dev, link->spa->spa_mem, PE_mask,
				&link->platform_data);
	if (rc)
		goto err_xsl_irq;

	/*
	 * If link->arva is not defined, MMIO registers are not used to
	 * generate TLB invalidates and PowerBus snooping is enabled.
	 * Otherwise, PowerBus snooping is disabled and TLB invalidates
	 * are initiated using the MMIO registers.
	 */
	pnv_ocxl_map_lpar(dev, mfspr(SPRN_LPID), 0, &link->arva);

	*out_link = link;
	return 0;

err_xsl_irq:
	release_xsl_irq(link);
err_spa:
	free_spa(link);
err_free:
	kfree(link);
	return rc;
}

static void free_link(struct ocxl_link *link)
{
	release_xsl_irq(link);
	free_spa(link);
	kfree(link);
}

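/*
 * Example (illustrative sketch only, error handling elided): a driver
 * probing an OpenCAPI function would typically pair these calls as
 *
 *	void *link_handle;
 *
 *	rc = ocxl_link_setup(pdev, 0, &link_handle);
 *	...
 *	ocxl_link_release(pdev, link_handle);
 *
 * where 'pdev' is the function's struct pci_dev and the PE_mask of 0 is
 * shown purely for illustration. Functions sharing the same slot get
 * the same refcounted link.
 */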
int ocxl_link_setup(struct pci_dev *dev, int PE_mask, void **link_handle)
{
	int rc = 0;
	struct ocxl_link *link;

	mutex_lock(&links_list_lock);
	list_for_each_entry(link, &links_list, list) {
		/* The functions of a device all share the same link */
		if (link->domain == pci_domain_nr(dev->bus) &&
			link->bus == dev->bus->number &&
			link->dev == PCI_SLOT(dev->devfn)) {
			kref_get(&link->ref);
			*link_handle = link;
			goto unlock;
		}
	}
	rc = alloc_link(dev, PE_mask, &link);
	if (rc)
		goto unlock;

	list_add(&link->list, &links_list);
	*link_handle = link;
unlock:
	mutex_unlock(&links_list_lock);
	return rc;
}
EXPORT_SYMBOL_GPL(ocxl_link_setup);

static void release_xsl(struct kref *ref)
{
	struct ocxl_link *link = container_of(ref, struct ocxl_link, ref);

	if (link->arva) {
		pnv_ocxl_unmap_lpar(link->arva);
		link->arva = NULL;
	}

	list_del(&link->list);
	/* call platform code before releasing data */
	pnv_ocxl_spa_release(link->platform_data);
	free_link(link);
}

void ocxl_link_release(struct pci_dev *dev, void *link_handle)
{
	struct ocxl_link *link = link_handle;

	mutex_lock(&links_list_lock);
	kref_put(&link->ref, release_xsl);
	mutex_unlock(&links_list_lock);
}
EXPORT_SYMBOL_GPL(ocxl_link_release);

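/*
 * Called by the MMU notifier core whenever the host invalidates a range
 * of translations for an mm attached to a PE. The notifier is only
 * registered when the link's ATSD registers are mapped (link->arva), so
 * the invalidation is forwarded to the device one page at a time
 * through those registers.
 */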
static void arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
					struct mm_struct *mm,
					unsigned long start, unsigned long end)
{
	struct pe_data *pe_data = container_of(mn, struct pe_data, mmu_notifier);
	struct ocxl_link *link = pe_data->link;
	unsigned long addr, pid, page_size = PAGE_SIZE;

	pid = mm->context.id;
	trace_ocxl_mmu_notifier_range(start, end, pid);

	spin_lock(&link->atsd_lock);
	for (addr = start; addr < end; addr += page_size)
		pnv_ocxl_tlb_invalidate(link->arva, pid, addr, page_size);
	spin_unlock(&link->atsd_lock);
}

static const struct mmu_notifier_ops ocxl_mmu_notifier_ops = {
	.arch_invalidate_secondary_tlbs = arch_invalidate_secondary_tlbs,
};

static u64 calculate_cfg_state(bool kernel)
{
	u64 state;

	state = SPA_CFG_DR;
	if (mfspr(SPRN_LPCR) & LPCR_TC)
		state |= SPA_CFG_TC;
	if (radix_enabled())
		state |= SPA_CFG_XLAT_ror;
	else
		state |= SPA_CFG_XLAT_hpt;
	state |= SPA_CFG_HV;
	if (kernel) {
		if (mfmsr() & MSR_SF)
			state |= SPA_CFG_SF;
	} else {
		state |= SPA_CFG_PR;
		if (!test_tsk_thread_flag(current, TIF_32BIT))
			state |= SPA_CFG_SF;
	}
	return state;
}

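/*
 * Example (illustrative sketch only, not a verbatim caller): attaching
 * a user context could look like
 *
 *	rc = ocxl_link_add_pe(link_handle, pasid, mm->context.id, 0, amr,
 *			pci_dev_id(pdev), mm, xsl_fault_error_cb, ctx);
 *
 * and later, once the AFU has been told to stop using the PASID,
 *
 *	rc = ocxl_link_remove_pe(link_handle, pasid);
 *
 * For user contexts, pidr is the mm context id (xsl_fault_handler()
 * WARNs if the faulting PE's pid does not match mm->context.id);
 * 'xsl_fault_error_cb' and 'ctx' stand for a caller-provided callback
 * and cookie.
 */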
int ocxl_link_add_pe(void *link_handle, int pasid, u32 pidr, u32 tidr,
		u64 amr, u16 bdf, struct mm_struct *mm,
		void (*xsl_err_cb)(void *data, u64 addr, u64 dsisr),
		void *xsl_err_data)
{
	struct ocxl_link *link = link_handle;
	struct spa *spa = link->spa;
	struct ocxl_process_element *pe;
	int pe_handle, rc = 0;
	struct pe_data *pe_data;

	BUILD_BUG_ON(sizeof(struct ocxl_process_element) != 128);
	if (pasid > SPA_PASID_MAX)
		return -EINVAL;

	mutex_lock(&spa->spa_lock);
	pe_handle = pasid & SPA_PE_MASK;
	pe = spa->spa_mem + pe_handle;

	if (pe->software_state) {
		rc = -EBUSY;
		goto unlock;
	}

	pe_data = kmalloc(sizeof(*pe_data), GFP_KERNEL);
	if (!pe_data) {
		rc = -ENOMEM;
		goto unlock;
	}

	pe_data->mm = mm;
	pe_data->xsl_err_cb = xsl_err_cb;
	pe_data->xsl_err_data = xsl_err_data;
	pe_data->link = link;
	pe_data->mmu_notifier.ops = &ocxl_mmu_notifier_ops;

	memset(pe, 0, sizeof(struct ocxl_process_element));
	pe->config_state = cpu_to_be64(calculate_cfg_state(pidr == 0));
	pe->pasid = cpu_to_be32(pasid << (31 - 19));
	pe->bdf = cpu_to_be16(bdf);
	pe->lpid = cpu_to_be32(mfspr(SPRN_LPID));
	pe->pid = cpu_to_be32(pidr);
	pe->tid = cpu_to_be32(tidr);
	pe->amr = cpu_to_be64(amr);
	pe->software_state = cpu_to_be32(SPA_PE_VALID);

	/*
	 * For user contexts, register a copro so that TLBIs are seen
	 * by the nest MMU. If we have a kernel context, TLBIs are
	 * already global.
	 */
	if (mm) {
		mm_context_add_copro(mm);
		if (link->arva) {
			/* Use MMIO registers for the TLB Invalidate
			 * operations.
			 */
			trace_ocxl_init_mmu_notifier(pasid, mm->context.id);
			mmu_notifier_register(&pe_data->mmu_notifier, mm);
		}
	}

	/*
	 * Barrier is to make sure PE is visible in the SPA before it
	 * is used by the device. It also helps with the global TLBI
	 * invalidation
	 */
	mb();
	radix_tree_insert(&spa->pe_tree, pe_handle, pe_data);

	/*
	 * The mm must stay valid for as long as the device uses it. We
	 * lower the count when the context is removed from the SPA.
	 *
	 * We grab mm_count (and not mm_users), as we don't want to
	 * end up in a circular dependency if a process mmaps its
	 * mmio, therefore incrementing the file ref count when
	 * calling mmap(), and forgets to unmap before exiting. In
	 * that scenario, when the kernel handles the death of the
	 * process, the file is not cleaned because unmap was not
	 * called, and the mm wouldn't be freed because we would still
	 * have a reference on mm_users. Incrementing mm_count solves
	 * the problem.
	 */
	if (mm)
		mmgrab(mm);
	trace_ocxl_context_add(current->pid, spa->spa_mem, pasid, pidr, tidr);
unlock:
	mutex_unlock(&spa->spa_lock);
	return rc;
}
EXPORT_SYMBOL_GPL(ocxl_link_add_pe);

int ocxl_link_update_pe(void *link_handle, int pasid, __u16 tid)
{
	struct ocxl_link *link = link_handle;
	struct spa *spa = link->spa;
	struct ocxl_process_element *pe;
	int pe_handle, rc;

	if (pasid > SPA_PASID_MAX)
		return -EINVAL;

	pe_handle = pasid & SPA_PE_MASK;
	pe = spa->spa_mem + pe_handle;

	mutex_lock(&spa->spa_lock);

	pe->tid = cpu_to_be32(tid);

	/*
	 * The barrier makes sure the PE is updated before we clear
	 * the NPU context cache below, so that the old PE cannot be
	 * reloaded erroneously.
	 */
	mb();

	/*
	 * hook to platform code
	 * On powerpc, the entry needs to be cleared from the context
	 * cache of the NPU.
	 */
	rc = pnv_ocxl_spa_remove_pe_from_cache(link->platform_data, pe_handle);
	WARN_ON(rc);

	mutex_unlock(&spa->spa_lock);
	return rc;
}

int ocxl_link_remove_pe(void *link_handle, int pasid)
{
	struct ocxl_link *link = link_handle;
	struct spa *spa = link->spa;
	struct ocxl_process_element *pe;
	struct pe_data *pe_data;
	int pe_handle, rc;

	if (pasid > SPA_PASID_MAX)
		return -EINVAL;

	/*
	 * About synchronization with our memory fault handler:
	 *
	 * Before removing the PE, the driver is supposed to have
	 * notified the AFU, which should have cleaned up and made
	 * sure the PASID is no longer in use, including pending
	 * interrupts. However, there's no way to be sure...
	 *
	 * We clear the PE and remove the context from our radix
	 * tree. From that point on, any new interrupt for that
	 * context will fail silently, which is ok. As mentioned
	 * above, that's not expected, but it could happen if the
	 * driver or AFU didn't do the right thing.
	 *
	 * There could still be a bottom half running, but we don't
	 * need to wait/flush, as it is managing a reference count on
	 * the mm it reads from the radix tree.
	 */
	pe_handle = pasid & SPA_PE_MASK;
	pe = spa->spa_mem + pe_handle;

	mutex_lock(&spa->spa_lock);

	if (!(be32_to_cpu(pe->software_state) & SPA_PE_VALID)) {
		rc = -EINVAL;
		goto unlock;
	}

	trace_ocxl_context_remove(current->pid, spa->spa_mem, pasid,
				be32_to_cpu(pe->pid), be32_to_cpu(pe->tid));

	memset(pe, 0, sizeof(struct ocxl_process_element));
	/*
	 * The barrier makes sure the PE is removed from the SPA
	 * before we clear the NPU context cache below, so that the
	 * old PE cannot be reloaded erroneously.
	 */
	mb();

	/*
	 * hook to platform code
	 * On powerpc, the entry needs to be cleared from the context
	 * cache of the NPU.
	 */
	rc = pnv_ocxl_spa_remove_pe_from_cache(link->platform_data, pe_handle);
	WARN_ON(rc);

	pe_data = radix_tree_delete(&spa->pe_tree, pe_handle);
	if (!pe_data) {
		WARN(1, "Couldn't find pe data when removing PE\n");
	} else {
		if (pe_data->mm) {
			if (link->arva) {
				trace_ocxl_release_mmu_notifier(pasid,
								pe_data->mm->context.id);
				mmu_notifier_unregister(&pe_data->mmu_notifier,
							pe_data->mm);
				spin_lock(&link->atsd_lock);
				pnv_ocxl_tlb_invalidate(link->arva,
							pe_data->mm->context.id,
							0ull,
							PAGE_SIZE);
				spin_unlock(&link->atsd_lock);
			}
			mm_context_remove_copro(pe_data->mm);
			mmdrop(pe_data->mm);
		}
		kfree_rcu(pe_data, rcu);
	}
unlock:
	mutex_unlock(&spa->spa_lock);
	return rc;
}
EXPORT_SYMBOL_GPL(ocxl_link_remove_pe);

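/*
 * Example (illustrative sketch only): allocating an AFU interrupt
 * source on the link and releasing it later
 *
 *	int hw_irq;
 *
 *	rc = ocxl_link_irq_alloc(link_handle, &hw_irq);
 *	...
 *	ocxl_link_free_irq(link_handle, hw_irq);
 *
 * Allocations are limited to MAX_IRQ_PER_LINK per link, tracked by
 * link->irq_available.
 */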
int ocxl_link_irq_alloc(void *link_handle, int *hw_irq)
{
	struct ocxl_link *link = link_handle;
	int irq;

	if (atomic_dec_if_positive(&link->irq_available) < 0)
		return -ENOSPC;

	irq = xive_native_alloc_irq();
	if (!irq) {
		atomic_inc(&link->irq_available);
		return -ENXIO;
	}

	*hw_irq = irq;
	return 0;
}
EXPORT_SYMBOL_GPL(ocxl_link_irq_alloc);

void ocxl_link_free_irq(void *link_handle, int hw_irq)
{
	struct ocxl_link *link = link_handle;

	xive_native_free_irq(hw_irq);
	atomic_inc(&link->irq_available);
}
EXPORT_SYMBOL_GPL(ocxl_link_free_irq);