cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

mmu.c (20074B)


      1// SPDX-License-Identifier: GPL-2.0
      2
      3/*
      4 * Copyright 2016-2020 HabanaLabs, Ltd.
      5 * All Rights Reserved.
      6 */
      7
      8#include <linux/slab.h>
      9
     10#include "../habanalabs.h"
     11
     12/**
     13 * hl_mmu_get_funcs() - get MMU functions structure
     14 * @hdev: habanalabs device structure.
     15 * @pgt_residency: page table residency.
     16 * @is_dram_addr: true if we need HMMU functions
     17 *
     18 * @return appropriate MMU functions structure
     19 */
     20static struct hl_mmu_funcs *hl_mmu_get_funcs(struct hl_device *hdev, int pgt_residency,
     21									bool is_dram_addr)
     22{
     23	return &hdev->mmu_func[pgt_residency];
     24}
     25
     26bool hl_is_dram_va(struct hl_device *hdev, u64 virt_addr)
     27{
     28	struct asic_fixed_properties *prop = &hdev->asic_prop;
     29
     30	return hl_mem_area_inside_range(virt_addr, prop->dmmu.page_size,
     31					prop->dmmu.start_addr,
     32					prop->dmmu.end_addr);
     33}
     34
     35/**
     36 * hl_mmu_init() - initialize the MMU module.
     37 * @hdev: habanalabs device structure.
     38 *
     39 * Return: 0 for success, non-zero for failure.
     40 */
     41int hl_mmu_init(struct hl_device *hdev)
     42{
     43	int rc = -EOPNOTSUPP;
     44
     45	if (!hdev->mmu_enable)
     46		return 0;
     47
     48	if (hdev->mmu_func[MMU_DR_PGT].init != NULL) {
     49		rc = hdev->mmu_func[MMU_DR_PGT].init(hdev);
     50		if (rc)
     51			return rc;
     52	}
     53
     54	if (hdev->mmu_func[MMU_HR_PGT].init != NULL)
     55		rc = hdev->mmu_func[MMU_HR_PGT].init(hdev);
     56
     57	return rc;
     58}
     59
     60/**
     61 * hl_mmu_fini() - release the MMU module.
     62 * @hdev: habanalabs device structure.
     63 *
     64 * This function does the following:
     65 * - Disable MMU in H/W.
     66 * - Free the pgt_infos pool.
     67 *
     68 * All contexts should be freed before calling this function.
     69 */
     70void hl_mmu_fini(struct hl_device *hdev)
     71{
     72	if (!hdev->mmu_enable)
     73		return;
     74
     75	if (hdev->mmu_func[MMU_DR_PGT].fini != NULL)
     76		hdev->mmu_func[MMU_DR_PGT].fini(hdev);
     77
     78	if (hdev->mmu_func[MMU_HR_PGT].fini != NULL)
     79		hdev->mmu_func[MMU_HR_PGT].fini(hdev);
     80}
     81
     82/**
     83 * hl_mmu_ctx_init() - initialize a context for using the MMU module.
     84 * @ctx: pointer to the context structure to initialize.
     85 *
     86 * Initialize a mutex to protect the concurrent mapping flow, a hash to hold all
     87 * page tables hops related to this context.
     88 * Return: 0 on success, non-zero otherwise.
     89 */
     90int hl_mmu_ctx_init(struct hl_ctx *ctx)
     91{
     92	struct hl_device *hdev = ctx->hdev;
     93	int rc = -EOPNOTSUPP;
     94
     95	if (!hdev->mmu_enable)
     96		return 0;
     97
     98	mutex_init(&ctx->mmu_lock);
     99
    100	if (hdev->mmu_func[MMU_DR_PGT].ctx_init != NULL) {
    101		rc = hdev->mmu_func[MMU_DR_PGT].ctx_init(ctx);
    102		if (rc)
    103			return rc;
    104	}
    105
    106	if (hdev->mmu_func[MMU_HR_PGT].ctx_init != NULL)
    107		rc = hdev->mmu_func[MMU_HR_PGT].ctx_init(ctx);
    108
    109	return rc;
    110}
    111
    112/*
    113 * hl_mmu_ctx_fini - disable a ctx from using the mmu module
    114 *
    115 * @ctx: pointer to the context structure
    116 *
    117 * This function does the following:
    118 * - Free any pgts which were not freed yet
    119 * - Free the mutex
    120 * - Free DRAM default page mapping hops
    121 */
    122void hl_mmu_ctx_fini(struct hl_ctx *ctx)
    123{
    124	struct hl_device *hdev = ctx->hdev;
    125
    126	if (!hdev->mmu_enable)
    127		return;
    128
    129	if (hdev->mmu_func[MMU_DR_PGT].ctx_fini != NULL)
    130		hdev->mmu_func[MMU_DR_PGT].ctx_fini(ctx);
    131
    132	if (hdev->mmu_func[MMU_HR_PGT].ctx_fini != NULL)
    133		hdev->mmu_func[MMU_HR_PGT].ctx_fini(ctx);
    134
    135	mutex_destroy(&ctx->mmu_lock);
    136}
    137
    138/*
    139 * hl_mmu_get_real_page_size - get real page size to use in map/unmap operation
    140 *
    141 * @hdev: pointer to device data.
    142 * @mmu_prop: MMU properties.
    143 * @page_size: page size
    144 * @real_page_size: set here the actual page size to use for the operation
    145 * @is_dram_addr: true if DRAM address, otherwise false.
    146 *
    147 * @return 0 on success, otherwise non 0 error code
    148 *
    149 * note that this is general implementation that can fit most MMU arch. but as this is used as an
    150 * MMU function:
    151 * 1. it shall not be called directly- only from mmu_func structure instance
    152 * 2. each MMU may modify the implementation internally
    153 */
    154int hl_mmu_get_real_page_size(struct hl_device *hdev, struct hl_mmu_properties *mmu_prop,
    155				u32 page_size, u32 *real_page_size, bool is_dram_addr)
    156{
    157	/*
    158	 * The H/W handles mapping of specific page sizes. Hence if the page
    159	 * size is bigger, we break it to sub-pages and map them separately.
    160	 */
    161	if ((page_size % mmu_prop->page_size) == 0) {
    162		*real_page_size = mmu_prop->page_size;
    163		return 0;
    164	}
    165
    166	dev_err(hdev->dev, "page size of %u is not %uKB aligned, can't map\n",
    167						page_size, mmu_prop->page_size >> 10);
    168
    169	return -EFAULT;
    170}
    171
    172static struct hl_mmu_properties *hl_mmu_get_prop(struct hl_device *hdev, u32 page_size,
    173							bool is_dram_addr)
    174{
    175	struct asic_fixed_properties *prop = &hdev->asic_prop;
    176
    177	if (is_dram_addr)
    178		return &prop->dmmu;
    179	else if ((page_size % prop->pmmu_huge.page_size) == 0)
    180		return &prop->pmmu_huge;
    181
    182	return &prop->pmmu;
    183}
    184
    185/*
    186 * hl_mmu_unmap_page - unmaps a virtual addr
    187 *
    188 * @ctx: pointer to the context structure
    189 * @virt_addr: virt addr to map from
    190 * @page_size: size of the page to unmap
    191 * @flush_pte: whether to do a PCI flush
    192 *
    193 * This function does the following:
    194 * - Check that the virt addr is mapped
    195 * - Unmap the virt addr and frees pgts if possible
    196 * - Returns 0 on success, -EINVAL if the given addr is not mapped
    197 *
    198 * Because this function changes the page tables in the device and because it
    199 * changes the MMU hash, it must be protected by a lock.
    200 * However, because it maps only a single page, the lock should be implemented
    201 * in a higher level in order to protect the entire mapping of the memory area
    202 *
    203 * For optimization reasons PCI flush may be requested once after unmapping of
    204 * large area.
    205 */
    206int hl_mmu_unmap_page(struct hl_ctx *ctx, u64 virt_addr, u32 page_size, bool flush_pte)
    207{
    208	struct hl_device *hdev = ctx->hdev;
    209	struct hl_mmu_properties *mmu_prop;
    210	struct hl_mmu_funcs *mmu_funcs;
    211	int i, pgt_residency, rc = 0;
    212	u32 real_page_size, npages;
    213	u64 real_virt_addr;
    214	bool is_dram_addr;
    215
    216	if (!hdev->mmu_enable)
    217		return 0;
    218
    219	is_dram_addr = hl_is_dram_va(hdev, virt_addr);
    220	mmu_prop = hl_mmu_get_prop(hdev, page_size, is_dram_addr);
    221
    222	pgt_residency = mmu_prop->host_resident ? MMU_HR_PGT : MMU_DR_PGT;
    223	mmu_funcs = hl_mmu_get_funcs(hdev, pgt_residency, is_dram_addr);
    224
    225	rc = hdev->asic_funcs->mmu_get_real_page_size(hdev, mmu_prop, page_size, &real_page_size,
    226							is_dram_addr);
    227	if (rc)
    228		return rc;
    229
    230	npages = page_size / real_page_size;
    231	real_virt_addr = virt_addr;
    232
    233	for (i = 0 ; i < npages ; i++) {
    234		rc = mmu_funcs->unmap(ctx, real_virt_addr, is_dram_addr);
    235		if (rc)
    236			break;
    237
    238		real_virt_addr += real_page_size;
    239	}
    240
    241	if (flush_pte)
    242		mmu_funcs->flush(ctx);
    243
    244	return rc;
    245}
    246
    247/*
    248 * hl_mmu_map_page - maps a virtual addr to physical addr
    249 *
    250 * @ctx: pointer to the context structure
    251 * @virt_addr: virt addr to map from
    252 * @phys_addr: phys addr to map to
    253 * @page_size: physical page size
    254 * @flush_pte: whether to do a PCI flush
    255 *
    256 * This function does the following:
    257 * - Check that the virt addr is not mapped
    258 * - Allocate pgts as necessary in order to map the virt addr to the phys
    259 * - Returns 0 on success, -EINVAL if addr is already mapped, or -ENOMEM.
    260 *
    261 * Because this function changes the page tables in the device and because it
    262 * changes the MMU hash, it must be protected by a lock.
    263 * However, because it maps only a single page, the lock should be implemented
    264 * in a higher level in order to protect the entire mapping of the memory area
    265 *
    266 * For optimization reasons PCI flush may be requested once after mapping of
    267 * large area.
    268 */
    269int hl_mmu_map_page(struct hl_ctx *ctx, u64 virt_addr, u64 phys_addr, u32 page_size,
    270			bool flush_pte)
    271{
    272	int i, rc, pgt_residency, mapped_cnt = 0;
    273	struct hl_device *hdev = ctx->hdev;
    274	struct hl_mmu_properties *mmu_prop;
    275	u64 real_virt_addr, real_phys_addr;
    276	struct hl_mmu_funcs *mmu_funcs;
    277	u32 real_page_size, npages;
    278	bool is_dram_addr;
    279
    280
    281	if (!hdev->mmu_enable)
    282		return 0;
    283
    284	is_dram_addr = hl_is_dram_va(hdev, virt_addr);
    285	mmu_prop = hl_mmu_get_prop(hdev, page_size, is_dram_addr);
    286
    287	pgt_residency = mmu_prop->host_resident ? MMU_HR_PGT : MMU_DR_PGT;
    288	mmu_funcs = hl_mmu_get_funcs(hdev, pgt_residency, is_dram_addr);
    289
    290	rc = hdev->asic_funcs->mmu_get_real_page_size(hdev, mmu_prop, page_size, &real_page_size,
    291							is_dram_addr);
    292	if (rc)
    293		return rc;
    294
    295	/*
    296	 * Verify that the phys and virt addresses are aligned with the
    297	 * MMU page size (in dram this means checking the address and MMU
    298	 * after scrambling)
    299	 */
    300	if ((is_dram_addr &&
    301			((hdev->asic_funcs->scramble_addr(hdev, phys_addr) &
    302				(mmu_prop->page_size - 1)) ||
    303			(hdev->asic_funcs->scramble_addr(hdev, virt_addr) &
    304				(mmu_prop->page_size - 1)))) ||
    305		(!is_dram_addr && ((phys_addr & (real_page_size - 1)) ||
    306				(virt_addr & (real_page_size - 1)))))
    307		dev_crit(hdev->dev,
    308			"Mapping address 0x%llx with virtual address 0x%llx and page size of 0x%x is erroneous! Addresses must be divisible by page size",
    309			phys_addr, virt_addr, real_page_size);
    310
    311	npages = page_size / real_page_size;
    312	real_virt_addr = virt_addr;
    313	real_phys_addr = phys_addr;
    314
    315	for (i = 0 ; i < npages ; i++) {
    316		rc = mmu_funcs->map(ctx, real_virt_addr, real_phys_addr, real_page_size,
    317										is_dram_addr);
    318		if (rc)
    319			goto err;
    320
    321		real_virt_addr += real_page_size;
    322		real_phys_addr += real_page_size;
    323		mapped_cnt++;
    324	}
    325
    326	if (flush_pte)
    327		mmu_funcs->flush(ctx);
    328
    329	return 0;
    330
    331err:
    332	real_virt_addr = virt_addr;
    333	for (i = 0 ; i < mapped_cnt ; i++) {
    334		if (mmu_funcs->unmap(ctx, real_virt_addr, is_dram_addr))
    335			dev_warn_ratelimited(hdev->dev,
    336				"failed to unmap va: 0x%llx\n", real_virt_addr);
    337
    338		real_virt_addr += real_page_size;
    339	}
    340
    341	mmu_funcs->flush(ctx);
    342
    343	return rc;
    344}
    345
    346/*
    347 * hl_mmu_map_contiguous - implements a wrapper for hl_mmu_map_page
    348 *                         for mapping contiguous physical memory
    349 *
    350 * @ctx: pointer to the context structure
    351 * @virt_addr: virt addr to map from
    352 * @phys_addr: phys addr to map to
    353 * @size: size to map
    354 *
    355 */
    356int hl_mmu_map_contiguous(struct hl_ctx *ctx, u64 virt_addr,
    357					u64 phys_addr, u32 size)
    358{
    359	struct hl_device *hdev = ctx->hdev;
    360	struct asic_fixed_properties *prop = &hdev->asic_prop;
    361	u64 curr_va, curr_pa;
    362	u32 page_size;
    363	bool flush_pte;
    364	int rc = 0, off;
    365
    366	if (hl_mem_area_inside_range(virt_addr, size,
    367			prop->dmmu.start_addr, prop->dmmu.end_addr))
    368		page_size = prop->dmmu.page_size;
    369	else if (hl_mem_area_inside_range(virt_addr, size,
    370			prop->pmmu.start_addr, prop->pmmu.end_addr))
    371		page_size = prop->pmmu.page_size;
    372	else if (hl_mem_area_inside_range(virt_addr, size,
    373			prop->pmmu_huge.start_addr, prop->pmmu_huge.end_addr))
    374		page_size = prop->pmmu_huge.page_size;
    375	else
    376		return -EINVAL;
    377
    378	for (off = 0 ; off < size ; off += page_size) {
    379		curr_va = virt_addr + off;
    380		curr_pa = phys_addr + off;
    381		flush_pte = (off + page_size) >= size;
    382		rc = hl_mmu_map_page(ctx, curr_va, curr_pa, page_size,
    383								flush_pte);
    384		if (rc) {
    385			dev_err(hdev->dev,
    386				"Map failed for va 0x%llx to pa 0x%llx\n",
    387				curr_va, curr_pa);
    388			goto unmap;
    389		}
    390	}
    391
    392	return rc;
    393
    394unmap:
    395	for (; off >= 0 ; off -= page_size) {
    396		curr_va = virt_addr + off;
    397		flush_pte = (off - (s32) page_size) < 0;
    398		if (hl_mmu_unmap_page(ctx, curr_va, page_size, flush_pte))
    399			dev_warn_ratelimited(hdev->dev,
    400				"failed to unmap va 0x%llx\n", curr_va);
    401	}
    402
    403	return rc;
    404}
    405
    406/*
    407 * hl_mmu_unmap_contiguous - implements a wrapper for hl_mmu_unmap_page
    408 *                           for unmapping contiguous physical memory
    409 *
    410 * @ctx: pointer to the context structure
    411 * @virt_addr: virt addr to unmap
    412 * @size: size to unmap
    413 *
    414 */
    415int hl_mmu_unmap_contiguous(struct hl_ctx *ctx, u64 virt_addr, u32 size)
    416{
    417	struct hl_device *hdev = ctx->hdev;
    418	struct asic_fixed_properties *prop = &hdev->asic_prop;
    419	u64 curr_va;
    420	u32 page_size;
    421	bool flush_pte;
    422	int rc = 0, off;
    423
    424	if (hl_mem_area_inside_range(virt_addr, size,
    425			prop->dmmu.start_addr, prop->dmmu.end_addr))
    426		page_size = prop->dmmu.page_size;
    427	else if (hl_mem_area_inside_range(virt_addr, size,
    428			prop->pmmu.start_addr, prop->pmmu.end_addr))
    429		page_size = prop->pmmu.page_size;
    430	else if (hl_mem_area_inside_range(virt_addr, size,
    431			prop->pmmu_huge.start_addr, prop->pmmu_huge.end_addr))
    432		page_size = prop->pmmu_huge.page_size;
    433	else
    434		return -EINVAL;
    435
    436	for (off = 0 ; off < size ; off += page_size) {
    437		curr_va = virt_addr + off;
    438		flush_pte = (off + page_size) >= size;
    439		rc = hl_mmu_unmap_page(ctx, curr_va, page_size, flush_pte);
    440		if (rc)
    441			dev_warn_ratelimited(hdev->dev,
    442				"Unmap failed for va 0x%llx\n", curr_va);
    443	}
    444
    445	return rc;
    446}
    447
    448/*
    449 * hl_mmu_swap_out - marks all mapping of the given ctx as swapped out
    450 *
    451 * @ctx: pointer to the context structure
    452 *
    453 */
    454void hl_mmu_swap_out(struct hl_ctx *ctx)
    455{
    456	struct hl_device *hdev = ctx->hdev;
    457
    458	if (!hdev->mmu_enable)
    459		return;
    460
    461	if (hdev->mmu_func[MMU_DR_PGT].swap_out != NULL)
    462		hdev->mmu_func[MMU_DR_PGT].swap_out(ctx);
    463
    464	if (hdev->mmu_func[MMU_HR_PGT].swap_out != NULL)
    465		hdev->mmu_func[MMU_HR_PGT].swap_out(ctx);
    466}
    467
    468/*
    469 * hl_mmu_swap_in - marks all mapping of the given ctx as swapped in
    470 *
    471 * @ctx: pointer to the context structure
    472 *
    473 */
    474void hl_mmu_swap_in(struct hl_ctx *ctx)
    475{
    476	struct hl_device *hdev = ctx->hdev;
    477
    478	if (!hdev->mmu_enable)
    479		return;
    480
    481	if (hdev->mmu_func[MMU_DR_PGT].swap_in != NULL)
    482		hdev->mmu_func[MMU_DR_PGT].swap_in(ctx);
    483
    484	if (hdev->mmu_func[MMU_HR_PGT].swap_in != NULL)
    485		hdev->mmu_func[MMU_HR_PGT].swap_in(ctx);
    486}
    487
    488static void hl_mmu_pa_page_with_offset(struct hl_ctx *ctx, u64 virt_addr,
    489						struct hl_mmu_hop_info *hops,
    490						u64 *phys_addr)
    491{
    492	struct asic_fixed_properties *prop = &ctx->hdev->asic_prop;
    493	u64 offset_mask, addr_mask, hop_shift, tmp_phys_addr;
    494	struct hl_mmu_properties *mmu_prop;
    495
    496	/* last hop holds the phys address and flags */
    497	if (hops->unscrambled_paddr)
    498		tmp_phys_addr = hops->unscrambled_paddr;
    499	else
    500		tmp_phys_addr = hops->hop_info[hops->used_hops - 1].hop_pte_val;
    501
    502	if (hops->range_type == HL_VA_RANGE_TYPE_HOST_HUGE)
    503		mmu_prop = &prop->pmmu_huge;
    504	else if (hops->range_type == HL_VA_RANGE_TYPE_HOST)
    505		mmu_prop = &prop->pmmu;
    506	else /* HL_VA_RANGE_TYPE_DRAM */
    507		mmu_prop = &prop->dmmu;
    508
    509	if ((hops->range_type == HL_VA_RANGE_TYPE_DRAM) &&
    510			!is_power_of_2(prop->dram_page_size)) {
    511		u64 dram_page_size, dram_base, abs_phys_addr, abs_virt_addr,
    512			page_id, page_start;
    513		u32 page_off;
    514
    515		/*
    516		 * Bit arithmetics cannot be used for non power of two page
    517		 * sizes. In addition, since bit arithmetics is not used,
    518		 * we cannot ignore dram base. All that shall be considered.
    519		 */
    520
    521		dram_page_size = prop->dram_page_size;
    522		dram_base = prop->dram_base_address;
    523		abs_phys_addr = tmp_phys_addr - dram_base;
    524		abs_virt_addr = virt_addr - dram_base;
    525		page_id = DIV_ROUND_DOWN_ULL(abs_phys_addr, dram_page_size);
    526		page_start = page_id * dram_page_size;
    527		div_u64_rem(abs_virt_addr, dram_page_size, &page_off);
    528
    529		*phys_addr = page_start + page_off + dram_base;
    530	} else {
    531		/*
    532		 * find the correct hop shift field in hl_mmu_properties
    533		 * structure in order to determine the right masks
    534		 * for the page offset.
    535		 */
    536		hop_shift = mmu_prop->hop_shifts[hops->used_hops - 1];
    537		offset_mask = (1ull << hop_shift) - 1;
    538		addr_mask = ~(offset_mask);
    539		*phys_addr = (tmp_phys_addr & addr_mask) |
    540				(virt_addr & offset_mask);
    541	}
    542}
    543
    544int hl_mmu_va_to_pa(struct hl_ctx *ctx, u64 virt_addr, u64 *phys_addr)
    545{
    546	struct hl_mmu_hop_info hops;
    547	int rc;
    548
    549	memset(&hops, 0, sizeof(hops));
    550
    551	rc = hl_mmu_get_tlb_info(ctx, virt_addr, &hops);
    552	if (rc)
    553		return rc;
    554
    555	hl_mmu_pa_page_with_offset(ctx, virt_addr, &hops,  phys_addr);
    556
    557	return 0;
    558}
    559
    560int hl_mmu_get_tlb_info(struct hl_ctx *ctx, u64 virt_addr,
    561			struct hl_mmu_hop_info *hops)
    562{
    563	struct hl_device *hdev = ctx->hdev;
    564	struct asic_fixed_properties *prop;
    565	struct hl_mmu_properties *mmu_prop;
    566	struct hl_mmu_funcs *mmu_funcs;
    567	int pgt_residency, rc;
    568	bool is_dram_addr;
    569
    570	if (!hdev->mmu_enable)
    571		return -EOPNOTSUPP;
    572
    573	prop = &hdev->asic_prop;
    574	hops->scrambled_vaddr = virt_addr;      /* assume no scrambling */
    575
    576	is_dram_addr = hl_mem_area_inside_range(virt_addr, prop->dmmu.page_size,
    577								prop->dmmu.start_addr,
    578								prop->dmmu.end_addr);
    579
    580	/* host-residency is the same in PMMU and PMMU huge, no need to distinguish here */
    581	mmu_prop = is_dram_addr ? &prop->dmmu : &prop->pmmu;
    582	pgt_residency = mmu_prop->host_resident ? MMU_HR_PGT : MMU_DR_PGT;
    583	mmu_funcs = hl_mmu_get_funcs(hdev, pgt_residency, is_dram_addr);
    584
    585	mutex_lock(&ctx->mmu_lock);
    586	rc = mmu_funcs->get_tlb_info(ctx, virt_addr, hops);
    587	mutex_unlock(&ctx->mmu_lock);
    588
    589	if (rc)
    590		return rc;
    591
    592	/* add page offset to physical address */
    593	if (hops->unscrambled_paddr)
    594		hl_mmu_pa_page_with_offset(ctx, virt_addr, hops, &hops->unscrambled_paddr);
    595
    596	return 0;
    597}
    598
    599int hl_mmu_if_set_funcs(struct hl_device *hdev)
    600{
    601	if (!hdev->mmu_enable)
    602		return 0;
    603
    604	switch (hdev->asic_type) {
    605	case ASIC_GOYA:
    606	case ASIC_GAUDI:
    607	case ASIC_GAUDI_SEC:
    608		hl_mmu_v1_set_funcs(hdev, &hdev->mmu_func[MMU_DR_PGT]);
    609		break;
    610	default:
    611		dev_err(hdev->dev, "Unrecognized ASIC type %d\n",
    612			hdev->asic_type);
    613		return -EOPNOTSUPP;
    614	}
    615
    616	return 0;
    617}
    618
    619/**
    620 * hl_mmu_scramble_addr() - The generic mmu address scrambling routine.
    621 * @hdev: pointer to device data.
    622 * @addr: The address to scramble.
    623 *
    624 * Return: The scrambled address.
    625 */
    626u64 hl_mmu_scramble_addr(struct hl_device *hdev, u64 addr)
    627{
    628	return addr;
    629}
    630
    631/**
    632 * hl_mmu_descramble_addr() - The generic mmu address descrambling
    633 * routine.
    634 * @hdev: pointer to device data.
    635 * @addr: The address to descramble.
    636 *
    637 * Return: The un-scrambled address.
    638 */
    639u64 hl_mmu_descramble_addr(struct hl_device *hdev, u64 addr)
    640{
    641	return addr;
    642}
    643
    644int hl_mmu_invalidate_cache(struct hl_device *hdev, bool is_hard, u32 flags)
    645{
    646	int rc;
    647
    648	rc = hdev->asic_funcs->mmu_invalidate_cache(hdev, is_hard, flags);
    649	if (rc)
    650		dev_err_ratelimited(hdev->dev, "MMU cache invalidation failed\n");
    651
    652	return rc;
    653}
    654
    655int hl_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_hard,
    656					u32 flags, u32 asid, u64 va, u64 size)
    657{
    658	int rc;
    659
    660	rc = hdev->asic_funcs->mmu_invalidate_cache_range(hdev, is_hard, flags,
    661								asid, va, size);
    662	if (rc)
    663		dev_err_ratelimited(hdev->dev, "MMU cache range invalidation failed\n");
    664
    665	return rc;
    666}
    667
    668static void hl_mmu_prefetch_work_function(struct work_struct *work)
    669{
    670	struct hl_prefetch_work *pfw = container_of(work, struct hl_prefetch_work, pf_work);
    671	struct hl_ctx *ctx = pfw->ctx;
    672
    673	if (!hl_device_operational(ctx->hdev, NULL))
    674		goto put_ctx;
    675
    676	mutex_lock(&ctx->mmu_lock);
    677
    678	ctx->hdev->asic_funcs->mmu_prefetch_cache_range(ctx, pfw->flags, pfw->asid,
    679								pfw->va, pfw->size);
    680
    681	mutex_unlock(&ctx->mmu_lock);
    682
    683put_ctx:
    684	/*
    685	 * context was taken in the common mmu prefetch function- see comment there about
    686	 * context handling.
    687	 */
    688	hl_ctx_put(ctx);
    689	kfree(pfw);
    690}
    691
    692int hl_mmu_prefetch_cache_range(struct hl_ctx *ctx, u32 flags, u32 asid, u64 va, u64 size)
    693{
    694	struct hl_prefetch_work *handle_pf_work;
    695
    696	handle_pf_work = kmalloc(sizeof(*handle_pf_work), GFP_KERNEL);
    697	if (!handle_pf_work)
    698		return -ENOMEM;
    699
    700	INIT_WORK(&handle_pf_work->pf_work, hl_mmu_prefetch_work_function);
    701	handle_pf_work->ctx = ctx;
    702	handle_pf_work->va = va;
    703	handle_pf_work->size = size;
    704	handle_pf_work->flags = flags;
    705	handle_pf_work->asid = asid;
    706
    707	/*
    708	 * as actual prefetch is done in a WQ we must get the context (and put it
    709	 * at the end of the work function)
    710	 */
    711	hl_ctx_get(ctx);
    712	queue_work(ctx->hdev->pf_wq, &handle_pf_work->pf_work);
    713
    714	return 0;
    715}
    716
    717u64 hl_mmu_get_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte)
    718{
    719	return (curr_pte & PAGE_PRESENT_MASK) ? (curr_pte & HOP_PHYS_ADDR_MASK) : ULLONG_MAX;
    720}
    721
    722/**
    723 * hl_mmu_get_hop_pte_phys_addr() - extract PTE address from HOP
    724 * @ctx: pointer to the context structure to initialize.
    725 * @mmu_prop: MMU properties.
    726 * @hop_idx: HOP index.
    727 * @hop_addr: HOP address.
    728 * @virt_addr: virtual address fro the translation.
    729 *
    730 * @return the matching PTE value on success, otherwise U64_MAX.
    731 */
    732u64 hl_mmu_get_hop_pte_phys_addr(struct hl_ctx *ctx, struct hl_mmu_properties *mmu_prop,
    733					u8 hop_idx, u64 hop_addr, u64 virt_addr)
    734{
    735	u64 mask, shift;
    736
    737	if (hop_idx >= mmu_prop->num_hops) {
    738		dev_err_ratelimited(ctx->hdev->dev, "Invalid hop index %d\n", hop_idx);
    739		return U64_MAX;
    740	}
    741
    742	shift = mmu_prop->hop_shifts[hop_idx];
    743	mask = mmu_prop->hop_masks[hop_idx];
    744
    745	return hop_addr + ctx->hdev->asic_prop.mmu_pte_size * ((virt_addr & mask) >> shift);
    746}
    747