cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

dfs_cache.c (40611B)


      1// SPDX-License-Identifier: GPL-2.0
      2/*
      3 * DFS referral cache routines
      4 *
      5 * Copyright (c) 2018-2019 Paulo Alcantara <palcantara@suse.de>
      6 */
      7
      8#include <linux/jhash.h>
      9#include <linux/ktime.h>
     10#include <linux/slab.h>
     11#include <linux/proc_fs.h>
     12#include <linux/nls.h>
     13#include <linux/workqueue.h>
     14#include <linux/uuid.h>
     15#include "cifsglob.h"
     16#include "smb2pdu.h"
     17#include "smb2proto.h"
     18#include "cifsproto.h"
     19#include "cifs_debug.h"
     20#include "cifs_unicode.h"
     21#include "smb2glob.h"
     22#include "dns_resolve.h"
     23
     24#include "dfs_cache.h"
     25
     26#define CACHE_HTABLE_SIZE 32
     27#define CACHE_MAX_ENTRIES 64
     28#define CACHE_MIN_TTL 120 /* 2 minutes */
     29
     30#define IS_DFS_INTERLINK(v) (((v) & DFSREF_REFERRAL_SERVER) && !((v) & DFSREF_STORAGE_SERVER))
     31
     32struct cache_dfs_tgt {
     33	char *name;
     34	int path_consumed;
     35	struct list_head list;
     36};
     37
     38struct cache_entry {
     39	struct hlist_node hlist;
     40	const char *path;
     41	int hdr_flags; /* RESP_GET_DFS_REFERRAL.ReferralHeaderFlags */
     42	int ttl; /* DFS_REREFERRAL_V3.TimeToLive */
     43	int srvtype; /* DFS_REREFERRAL_V3.ServerType */
     44	int ref_flags; /* DFS_REREFERRAL_V3.ReferralEntryFlags */
     45	struct timespec64 etime;
     46	int path_consumed; /* RESP_GET_DFS_REFERRAL.PathConsumed */
     47	int numtgts;
     48	struct list_head tlist;
     49	struct cache_dfs_tgt *tgthint;
     50};
     51
     52/* List of referral server sessions per dfs mount */
     53struct mount_group {
     54	struct list_head list;
     55	uuid_t id;
     56	struct cifs_ses *sessions[CACHE_MAX_ENTRIES];
     57	int num_sessions;
     58	spinlock_t lock;
     59	struct list_head refresh_list;
     60	struct kref refcount;
     61};
     62
     63static struct kmem_cache *cache_slab __read_mostly;
     64static struct workqueue_struct *dfscache_wq __read_mostly;
     65
     66static int cache_ttl;
     67static DEFINE_SPINLOCK(cache_ttl_lock);
     68
     69static struct nls_table *cache_cp;
     70
     71/*
     72 * Number of entries in the cache
     73 */
     74static atomic_t cache_count;
     75
     76static struct hlist_head cache_htable[CACHE_HTABLE_SIZE];
     77static DECLARE_RWSEM(htable_rw_lock);
     78
     79static LIST_HEAD(mount_group_list);
     80static DEFINE_MUTEX(mount_group_list_lock);
     81
     82static void refresh_cache_worker(struct work_struct *work);
     83
     84static DECLARE_DELAYED_WORK(refresh_task, refresh_cache_worker);
     85
     86static void get_ipc_unc(const char *ref_path, char *ipc, size_t ipclen)
     87{
     88	const char *host;
     89	size_t len;
     90
     91	extract_unc_hostname(ref_path, &host, &len);
     92	scnprintf(ipc, ipclen, "\\\\%.*s\\IPC$", (int)len, host);
     93}
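
/*
 * An illustrative sketch, with a hypothetical server name.  For
 * ref_path "\srv1\dfsroot\link", extract_unc_hostname() yields host
 * "srv1", so:
 *
 *	char ipc[64];
 *
 *	get_ipc_unc("\\srv1\\dfsroot\\link", ipc, sizeof(ipc));
 *
 * leaves ipc containing \\srv1\IPC$, the UNC of the referral server's
 * IPC$ share.
 */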
     94
     95static struct cifs_ses *find_ipc_from_server_path(struct cifs_ses **ses, const char *path)
     96{
     97	char unc[SERVER_NAME_LENGTH + sizeof("//x/IPC$")] = {0};
     98
     99	get_ipc_unc(path, unc, sizeof(unc));
    100	for (; *ses; ses++) {
    101		if (!strcasecmp(unc, (*ses)->tcon_ipc->treeName))
    102			return *ses;
    103	}
    104	return ERR_PTR(-ENOENT);
    105}
    106
    107static void __mount_group_release(struct mount_group *mg)
    108{
    109	int i;
    110
    111	for (i = 0; i < mg->num_sessions; i++)
    112		cifs_put_smb_ses(mg->sessions[i]);
    113	kfree(mg);
    114}
    115
    116static void mount_group_release(struct kref *kref)
    117{
    118	struct mount_group *mg = container_of(kref, struct mount_group, refcount);
    119
    120	mutex_lock(&mount_group_list_lock);
    121	list_del(&mg->list);
    122	mutex_unlock(&mount_group_list_lock);
    123	__mount_group_release(mg);
    124}
    125
    126static struct mount_group *find_mount_group_locked(const uuid_t *id)
    127{
    128	struct mount_group *mg;
    129
    130	list_for_each_entry(mg, &mount_group_list, list) {
    131		if (uuid_equal(&mg->id, id))
    132			return mg;
    133	}
    134	return ERR_PTR(-ENOENT);
    135}
    136
    137static struct mount_group *__get_mount_group_locked(const uuid_t *id)
    138{
    139	struct mount_group *mg;
    140
    141	mg = find_mount_group_locked(id);
    142	if (!IS_ERR(mg))
    143		return mg;
    144
    145	mg = kmalloc(sizeof(*mg), GFP_KERNEL);
    146	if (!mg)
    147		return ERR_PTR(-ENOMEM);
    148	kref_init(&mg->refcount);
    149	uuid_copy(&mg->id, id);
    150	mg->num_sessions = 0;
    151	spin_lock_init(&mg->lock);
    152	list_add(&mg->list, &mount_group_list);
    153	return mg;
    154}
    155
    156static struct mount_group *get_mount_group(const uuid_t *id)
    157{
    158	struct mount_group *mg;
    159
    160	mutex_lock(&mount_group_list_lock);
    161	mg = __get_mount_group_locked(id);
    162	if (!IS_ERR(mg))
    163		kref_get(&mg->refcount);
    164	mutex_unlock(&mount_group_list_lock);
    165
    166	return mg;
    167}
    168
    169static void free_mount_group_list(void)
    170{
    171	struct mount_group *mg, *tmp_mg;
    172
    173	list_for_each_entry_safe(mg, tmp_mg, &mount_group_list, list) {
    174		list_del_init(&mg->list);
    175		__mount_group_release(mg);
    176	}
    177}
    178
    179/**
    180 * dfs_cache_canonical_path - get a canonical DFS path
    181 *
    182 * @path: DFS path
    183 * @cp: codepage
    184 * @remap: mapping type
    185 *
    186 * Return the canonical path on success, otherwise an error pointer.
    187 */
    188char *dfs_cache_canonical_path(const char *path, const struct nls_table *cp, int remap)
    189{
    190	char *tmp;
    191	int plen = 0;
    192	char *npath;
    193
    194	if (!path || strlen(path) < 3 || (*path != '\\' && *path != '/'))
    195		return ERR_PTR(-EINVAL);
    196
    197	if (unlikely(strcmp(cp->charset, cache_cp->charset))) {
    198		tmp = (char *)cifs_strndup_to_utf16(path, strlen(path), &plen, cp, remap);
    199		if (!tmp) {
    200			cifs_dbg(VFS, "%s: failed to convert path to utf16\n", __func__);
    201			return ERR_PTR(-EINVAL);
    202		}
    203
    204		npath = cifs_strndup_from_utf16(tmp, plen, true, cache_cp);
    205		kfree(tmp);
    206
    207		if (!npath) {
    208			cifs_dbg(VFS, "%s: failed to convert path from utf16\n", __func__);
    209			return ERR_PTR(-EINVAL);
    210		}
    211	} else {
    212		npath = kstrdup(path, GFP_KERNEL);
    213		if (!npath)
    214			return ERR_PTR(-ENOMEM);
    215	}
    216	convert_delimiter(npath, '\\');
    217	return npath;
    218}
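
/*
 * A minimal usage sketch, assuming a caller holding a valid codepage
 * @cp and a hypothetical DFS path; the returned buffer is owned by the
 * caller:
 *
 *	char *npath;
 *
 *	npath = dfs_cache_canonical_path("/dom/dfsroot/link", cp,
 *					 NO_MAP_UNI_RSVD);
 *	if (IS_ERR(npath))
 *		return PTR_ERR(npath);
 *	... use npath, now "\dom\dfsroot\link" ...
 *	kfree(npath);
 */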
    219
    220static inline bool cache_entry_expired(const struct cache_entry *ce)
    221{
    222	struct timespec64 ts;
    223
    224	ktime_get_coarse_real_ts64(&ts);
    225	return timespec64_compare(&ts, &ce->etime) >= 0;
    226}
    227
    228static inline void free_tgts(struct cache_entry *ce)
    229{
    230	struct cache_dfs_tgt *t, *n;
    231
    232	list_for_each_entry_safe(t, n, &ce->tlist, list) {
    233		list_del(&t->list);
    234		kfree(t->name);
    235		kfree(t);
    236	}
    237}
    238
    239static inline void flush_cache_ent(struct cache_entry *ce)
    240{
    241	hlist_del_init(&ce->hlist);
    242	kfree(ce->path);
    243	free_tgts(ce);
    244	atomic_dec(&cache_count);
    245	kmem_cache_free(cache_slab, ce);
    246}
    247
    248static void flush_cache_ents(void)
    249{
    250	int i;
    251
    252	for (i = 0; i < CACHE_HTABLE_SIZE; i++) {
    253		struct hlist_head *l = &cache_htable[i];
    254		struct hlist_node *n;
    255		struct cache_entry *ce;
    256
    257		hlist_for_each_entry_safe(ce, n, l, hlist) {
    258			if (!hlist_unhashed(&ce->hlist))
    259				flush_cache_ent(ce);
    260		}
    261	}
    262}
    263
    264/*
    265 * dfs cache /proc file
    266 */
    267static int dfscache_proc_show(struct seq_file *m, void *v)
    268{
    269	int i;
    270	struct cache_entry *ce;
    271	struct cache_dfs_tgt *t;
    272
    273	seq_puts(m, "DFS cache\n---------\n");
    274
    275	down_read(&htable_rw_lock);
    276	for (i = 0; i < CACHE_HTABLE_SIZE; i++) {
    277		struct hlist_head *l = &cache_htable[i];
    278
    279		hlist_for_each_entry(ce, l, hlist) {
    280			if (hlist_unhashed(&ce->hlist))
    281				continue;
    282
    283			seq_printf(m,
    284				   "cache entry: path=%s,type=%s,ttl=%d,etime=%ld,hdr_flags=0x%x,ref_flags=0x%x,interlink=%s,path_consumed=%d,expired=%s\n",
    285				   ce->path, ce->srvtype == DFS_TYPE_ROOT ? "root" : "link",
    286				   ce->ttl, ce->etime.tv_nsec, ce->hdr_flags, ce->ref_flags,
    287				   IS_DFS_INTERLINK(ce->hdr_flags) ? "yes" : "no",
    288				   ce->path_consumed, cache_entry_expired(ce) ? "yes" : "no");
    289
    290			list_for_each_entry(t, &ce->tlist, list) {
    291				seq_printf(m, "  %s%s\n",
    292					   t->name,
    293					   ce->tgthint == t ? " (target hint)" : "");
    294			}
    295		}
    296	}
    297	up_read(&htable_rw_lock);
    298
    299	return 0;
    300}
    301
    302static ssize_t dfscache_proc_write(struct file *file, const char __user *buffer,
    303				   size_t count, loff_t *ppos)
    304{
    305	char c;
    306	int rc;
    307
    308	rc = get_user(c, buffer);
    309	if (rc)
    310		return rc;
    311
    312	if (c != '0')
    313		return -EINVAL;
    314
    315	cifs_dbg(FYI, "clearing dfs cache\n");
    316
    317	down_write(&htable_rw_lock);
    318	flush_cache_ents();
    319	up_write(&htable_rw_lock);
    320
    321	return count;
    322}
    323
    324static int dfscache_proc_open(struct inode *inode, struct file *file)
    325{
    326	return single_open(file, dfscache_proc_show, NULL);
    327}
    328
    329const struct proc_ops dfscache_proc_ops = {
    330	.proc_open	= dfscache_proc_open,
    331	.proc_read	= seq_read,
    332	.proc_lseek	= seq_lseek,
    333	.proc_release	= single_release,
    334	.proc_write	= dfscache_proc_write,
    335};
    336
    337#ifdef CONFIG_CIFS_DEBUG2
    338static inline void dump_tgts(const struct cache_entry *ce)
    339{
    340	struct cache_dfs_tgt *t;
    341
    342	cifs_dbg(FYI, "target list:\n");
    343	list_for_each_entry(t, &ce->tlist, list) {
    344		cifs_dbg(FYI, "  %s%s\n", t->name,
    345			 ce->tgthint == t ? " (target hint)" : "");
    346	}
    347}
    348
    349static inline void dump_ce(const struct cache_entry *ce)
    350{
    351	cifs_dbg(FYI, "cache entry: path=%s,type=%s,ttl=%d,etime=%ld,hdr_flags=0x%x,ref_flags=0x%x,interlink=%s,path_consumed=%d,expired=%s\n",
    352		 ce->path,
    353		 ce->srvtype == DFS_TYPE_ROOT ? "root" : "link", ce->ttl,
    354		 ce->etime.tv_nsec,
    355		 ce->hdr_flags, ce->ref_flags,
    356		 IS_DFS_INTERLINK(ce->hdr_flags) ? "yes" : "no",
    357		 ce->path_consumed,
    358		 cache_entry_expired(ce) ? "yes" : "no");
    359	dump_tgts(ce);
    360}
    361
    362static inline void dump_refs(const struct dfs_info3_param *refs, int numrefs)
    363{
    364	int i;
    365
    366	cifs_dbg(FYI, "DFS referrals returned by the server:\n");
    367	for (i = 0; i < numrefs; i++) {
    368		const struct dfs_info3_param *ref = &refs[i];
    369
    370		cifs_dbg(FYI,
    371			 "\n"
    372			 "flags:         0x%x\n"
    373			 "path_consumed: %d\n"
    374			 "server_type:   0x%x\n"
    375			 "ref_flag:      0x%x\n"
    376			 "path_name:     %s\n"
    377			 "node_name:     %s\n"
    378			 "ttl:           %d (%dm)\n",
    379			 ref->flags, ref->path_consumed, ref->server_type,
    380			 ref->ref_flag, ref->path_name, ref->node_name,
    381			 ref->ttl, ref->ttl / 60);
    382	}
    383}
    384#else
    385#define dump_tgts(e)
    386#define dump_ce(e)
    387#define dump_refs(r, n)
    388#endif
    389
    390/**
    391 * dfs_cache_init - Initialize DFS referral cache.
    392 *
    393 * Return zero if initialized successfully, otherwise non-zero.
    394 */
    395int dfs_cache_init(void)
    396{
    397	int rc;
    398	int i;
    399
    400	dfscache_wq = alloc_workqueue("cifs-dfscache", WQ_FREEZABLE | WQ_UNBOUND, 1);
    401	if (!dfscache_wq)
    402		return -ENOMEM;
    403
    404	cache_slab = kmem_cache_create("cifs_dfs_cache",
    405				       sizeof(struct cache_entry), 0,
    406				       SLAB_HWCACHE_ALIGN, NULL);
    407	if (!cache_slab) {
    408		rc = -ENOMEM;
    409		goto out_destroy_wq;
    410	}
    411
    412	for (i = 0; i < CACHE_HTABLE_SIZE; i++)
    413		INIT_HLIST_HEAD(&cache_htable[i]);
    414
    415	atomic_set(&cache_count, 0);
    416	cache_cp = load_nls("utf8");
    417	if (!cache_cp)
    418		cache_cp = load_nls_default();
    419
    420	cifs_dbg(FYI, "%s: initialized DFS referral cache\n", __func__);
    421	return 0;
    422
    423out_destroy_wq:
    424	destroy_workqueue(dfscache_wq);
    425	return rc;
    426}
    427
    428static int cache_entry_hash(const void *data, int size, unsigned int *hash)
    429{
    430	int i, clen;
    431	const unsigned char *s = data;
    432	wchar_t c;
    433	unsigned int h = 0;
    434
    435	for (i = 0; i < size; i += clen) {
    436		clen = cache_cp->char2uni(&s[i], size - i, &c);
    437		if (unlikely(clen < 0)) {
    438			cifs_dbg(VFS, "%s: can't convert char\n", __func__);
    439			return clen;
    440		}
    441		c = cifs_toupper(c);
    442		h = jhash(&c, sizeof(c), h);
    443	}
    444	*hash = h % CACHE_HTABLE_SIZE;
    445	return 0;
    446}
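
/*
 * An illustrative property of the hash above, with hypothetical paths.
 * Every character is converted through @cache_cp and upcased before
 * being folded into jhash(), so case variants of a path select the
 * same bucket:
 *
 *	unsigned int h1, h2;
 *
 *	cache_entry_hash("\\DOM\\dfsroot", 12, &h1);
 *	cache_entry_hash("\\dom\\DFSROOT", 12, &h2);
 *
 * yields h1 == h2, both within [0, CACHE_HTABLE_SIZE).
 */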
    447
    448/* Return target hint of a DFS cache entry */
    449static inline char *get_tgt_name(const struct cache_entry *ce)
    450{
    451	struct cache_dfs_tgt *t = ce->tgthint;
    452
    453	return t ? t->name : ERR_PTR(-ENOENT);
    454}
    455
    456/* Return expire time out of a new entry's TTL */
    457static inline struct timespec64 get_expire_time(int ttl)
    458{
    459	struct timespec64 ts = {
    460		.tv_sec = ttl,
    461		.tv_nsec = 0,
    462	};
    463	struct timespec64 now;
    464
    465	ktime_get_coarse_real_ts64(&now);
    466	return timespec64_add(now, ts);
    467}
    468
    469/* Allocate a new DFS target */
    470static struct cache_dfs_tgt *alloc_target(const char *name, int path_consumed)
    471{
    472	struct cache_dfs_tgt *t;
    473
    474	t = kmalloc(sizeof(*t), GFP_ATOMIC);
    475	if (!t)
    476		return ERR_PTR(-ENOMEM);
    477	t->name = kstrdup(name, GFP_ATOMIC);
    478	if (!t->name) {
    479		kfree(t);
    480		return ERR_PTR(-ENOMEM);
    481	}
    482	t->path_consumed = path_consumed;
    483	INIT_LIST_HEAD(&t->list);
    484	return t;
    485}
    486
    487/*
    488 * Copy DFS referral information to a cache entry and conditionally update
    489 * target hint.
    490 */
    491static int copy_ref_data(const struct dfs_info3_param *refs, int numrefs,
    492			 struct cache_entry *ce, const char *tgthint)
    493{
    494	int i;
    495
    496	ce->ttl = max_t(int, refs[0].ttl, CACHE_MIN_TTL);
    497	ce->etime = get_expire_time(ce->ttl);
    498	ce->srvtype = refs[0].server_type;
    499	ce->hdr_flags = refs[0].flags;
    500	ce->ref_flags = refs[0].ref_flag;
    501	ce->path_consumed = refs[0].path_consumed;
    502
    503	for (i = 0; i < numrefs; i++) {
    504		struct cache_dfs_tgt *t;
    505
    506		t = alloc_target(refs[i].node_name, refs[i].path_consumed);
    507		if (IS_ERR(t)) {
    508			free_tgts(ce);
    509			return PTR_ERR(t);
    510		}
    511		if (tgthint && !strcasecmp(t->name, tgthint)) {
    512			list_add(&t->list, &ce->tlist);
    513			tgthint = NULL;
    514		} else {
    515			list_add_tail(&t->list, &ce->tlist);
    516		}
    517		ce->numtgts++;
    518	}
    519
    520	ce->tgthint = list_first_entry_or_null(&ce->tlist,
    521					       struct cache_dfs_tgt, list);
    522
    523	return 0;
    524}
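
/*
 * A concrete reading of the @tgthint handling above, with hypothetical
 * target names.  If the old hint was "\srv2\share" and the new referral
 * lists "\srv1\share" and "\srv2\share", the matching target is
 * list_add()'ed to the head of @ce->tlist and is therefore picked as
 * the new tgthint; all other targets keep their referral order.
 */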
    525
    526/* Allocate a new cache entry */
    527static struct cache_entry *alloc_cache_entry(struct dfs_info3_param *refs, int numrefs)
    528{
    529	struct cache_entry *ce;
    530	int rc;
    531
    532	ce = kmem_cache_zalloc(cache_slab, GFP_KERNEL);
    533	if (!ce)
    534		return ERR_PTR(-ENOMEM);
    535
    536	ce->path = refs[0].path_name;
    537	refs[0].path_name = NULL;
    538
    539	INIT_HLIST_NODE(&ce->hlist);
    540	INIT_LIST_HEAD(&ce->tlist);
    541
    542	rc = copy_ref_data(refs, numrefs, ce, NULL);
    543	if (rc) {
    544		kfree(ce->path);
    545		kmem_cache_free(cache_slab, ce);
    546		ce = ERR_PTR(rc);
    547	}
    548	return ce;
    549}
    550
    551static void remove_oldest_entry_locked(void)
    552{
    553	int i;
    554	struct cache_entry *ce;
    555	struct cache_entry *to_del = NULL;
    556
    557	WARN_ON(!rwsem_is_locked(&htable_rw_lock));
    558
    559	for (i = 0; i < CACHE_HTABLE_SIZE; i++) {
    560		struct hlist_head *l = &cache_htable[i];
    561
    562		hlist_for_each_entry(ce, l, hlist) {
    563			if (hlist_unhashed(&ce->hlist))
    564				continue;
    565			if (!to_del || timespec64_compare(&ce->etime,
    566							  &to_del->etime) < 0)
    567				to_del = ce;
    568		}
    569	}
    570
    571	if (!to_del) {
    572		cifs_dbg(FYI, "%s: no entry to remove\n", __func__);
    573		return;
    574	}
    575
    576	cifs_dbg(FYI, "%s: removing entry\n", __func__);
    577	dump_ce(to_del);
    578	flush_cache_ent(to_del);
    579}
    580
    581/* Add a new DFS cache entry */
    582static int add_cache_entry_locked(struct dfs_info3_param *refs, int numrefs)
    583{
    584	int rc;
    585	struct cache_entry *ce;
    586	unsigned int hash;
    587
    588	WARN_ON(!rwsem_is_locked(&htable_rw_lock));
    589
    590	if (atomic_read(&cache_count) >= CACHE_MAX_ENTRIES) {
    591		cifs_dbg(FYI, "%s: reached max cache size (%d)\n", __func__, CACHE_MAX_ENTRIES);
    592		remove_oldest_entry_locked();
    593	}
    594
    595	rc = cache_entry_hash(refs[0].path_name, strlen(refs[0].path_name), &hash);
    596	if (rc)
    597		return rc;
    598
    599	ce = alloc_cache_entry(refs, numrefs);
    600	if (IS_ERR(ce))
    601		return PTR_ERR(ce);
    602
    603	spin_lock(&cache_ttl_lock);
    604	if (!cache_ttl) {
    605		cache_ttl = ce->ttl;
    606		queue_delayed_work(dfscache_wq, &refresh_task, cache_ttl * HZ);
    607	} else {
    608		cache_ttl = min_t(int, cache_ttl, ce->ttl);
    609		mod_delayed_work(dfscache_wq, &refresh_task, cache_ttl * HZ);
    610	}
    611	spin_unlock(&cache_ttl_lock);
    612
    613	hlist_add_head(&ce->hlist, &cache_htable[hash]);
    614	dump_ce(ce);
    615
    616	atomic_inc(&cache_count);
    617
    618	return 0;
    619}
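
/*
 * Illustration of the TTL bookkeeping above: @cache_ttl tracks the
 * smallest TTL seen so far, so after adding entries with TTLs of 300s
 * and 900s, @refresh_task is (re)armed to fire every 300 seconds.
 */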
    620
    621/* Check if two DFS paths are equal.  @s1 and @s2 are expected to be in @cache_cp's charset */
    622static bool dfs_path_equal(const char *s1, int len1, const char *s2, int len2)
    623{
    624	int i, l1, l2;
    625	wchar_t c1, c2;
    626
    627	if (len1 != len2)
    628		return false;
    629
    630	for (i = 0; i < len1; i += l1) {
    631		l1 = cache_cp->char2uni(&s1[i], len1 - i, &c1);
    632		l2 = cache_cp->char2uni(&s2[i], len2 - i, &c2);
    633		if (unlikely(l1 < 0 && l2 < 0)) {
    634			if (s1[i] != s2[i])
    635				return false;
    636			l1 = 1;
    637			continue;
    638		}
    639		if (l1 != l2)
    640			return false;
    641		if (cifs_toupper(c1) != cifs_toupper(c2))
    642			return false;
    643	}
    644	return true;
    645}
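
/*
 * Illustrative behavior, with hypothetical paths: the comparison is
 * case-insensitive in @cache_cp's charset but requires equal byte
 * lengths up front:
 *
 *	dfs_path_equal("\\DOM\\Share", 10, "\\dom\\share", 10)   true
 *	dfs_path_equal("\\dom\\share", 10, "\\dom\\shares", 11)  false
 */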
    646
    647static struct cache_entry *__lookup_cache_entry(const char *path, unsigned int hash, int len)
    648{
    649	struct cache_entry *ce;
    650
    651	hlist_for_each_entry(ce, &cache_htable[hash], hlist) {
    652		if (dfs_path_equal(ce->path, strlen(ce->path), path, len)) {
    653			dump_ce(ce);
    654			return ce;
    655		}
    656	}
    657	return ERR_PTR(-ENOENT);
    658}
    659
    660/*
    661 * Find a DFS cache entry in hash table and optionally check prefix path against normalized @path.
    662 *
    663 * Use whole path components in the match.  Must be called with htable_rw_lock held.
    664 *
    665 * Return ERR_PTR(-ENOENT) if the entry is not found.
    666 */
    667static struct cache_entry *lookup_cache_entry(const char *path)
    668{
    669	struct cache_entry *ce;
    670	int cnt = 0;
    671	const char *s = path, *e;
    672	char sep = *s;
    673	unsigned int hash;
    674	int rc;
    675
    676	while ((s = strchr(s, sep)) && ++cnt < 3)
    677		s++;
    678
    679	if (cnt < 3) {
    680		rc = cache_entry_hash(path, strlen(path), &hash);
    681		if (rc)
    682			return ERR_PTR(rc);
    683		return __lookup_cache_entry(path, hash, strlen(path));
    684	}
    685	/*
    686	 * Handle paths that have more than two path components and are a complete prefix of the DFS
    687	 * referral request path (@path).
    688	 *
    689	 * See MS-DFSC 3.2.5.5 "Receiving a Root Referral Request or Link Referral Request".
    690	 */
    691	e = path + strlen(path) - 1;
    692	while (e > s) {
    693		int len;
    694
    695		/* skip separators */
    696		while (e > s && *e == sep)
    697			e--;
    698		if (e == s)
    699			break;
    700
    701		len = e + 1 - path;
    702		rc = cache_entry_hash(path, len, &hash);
    703		if (rc)
    704			return ERR_PTR(rc);
    705		ce = __lookup_cache_entry(path, hash, len);
    706		if (!IS_ERR(ce))
    707			return ce;
    708
    709		/* backward until separator */
    710		while (e > s && *e != sep)
    711			e--;
    712	}
    713	return ERR_PTR(-ENOENT);
    714}
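
/*
 * A walk-through of the prefix matching above, with a hypothetical
 * path.  For @path = "\dom\root\a\b", prefixes are tried from longest
 * to shortest, never dropping below three path components:
 *
 *	"\dom\root\a\b"  - miss
 *	"\dom\root\a"    - hit (returned), else ERR_PTR(-ENOENT)
 *
 * A two-component path such as "\dom\root" is looked up directly via
 * the cnt < 3 branch instead.
 */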
    715
    716/**
    717 * dfs_cache_destroy - destroy DFS referral cache
    718 */
    719void dfs_cache_destroy(void)
    720{
    721	cancel_delayed_work_sync(&refresh_task);
    722	unload_nls(cache_cp);
    723	free_mount_group_list();
    724	flush_cache_ents();
    725	kmem_cache_destroy(cache_slab);
    726	destroy_workqueue(dfscache_wq);
    727
    728	cifs_dbg(FYI, "%s: destroyed DFS referral cache\n", __func__);
    729}
    730
    731/* Update a cache entry with the new referral in @refs */
    732static int update_cache_entry_locked(struct cache_entry *ce, const struct dfs_info3_param *refs,
    733				     int numrefs)
    734{
    735	int rc;
    736	char *s, *th = NULL;
    737
    738	WARN_ON(!rwsem_is_locked(&htable_rw_lock));
    739
    740	if (ce->tgthint) {
    741		s = ce->tgthint->name;
    742		th = kstrdup(s, GFP_ATOMIC);
    743		if (!th)
    744			return -ENOMEM;
    745	}
    746
    747	free_tgts(ce);
    748	ce->numtgts = 0;
    749
    750	rc = copy_ref_data(refs, numrefs, ce, th);
    751
    752	kfree(th);
    753
    754	return rc;
    755}
    756
    757static int get_dfs_referral(const unsigned int xid, struct cifs_ses *ses, const char *path,
    758			    struct dfs_info3_param **refs, int *numrefs)
    759{
    760	int rc;
    761	int i;
    762
    763	cifs_dbg(FYI, "%s: get a DFS referral for %s\n", __func__, path);
    764
    765	*refs = NULL;
    766	*numrefs = 0;
    767
    768	if (!ses || !ses->server || !ses->server->ops->get_dfs_refer)
    769		return -EOPNOTSUPP;
    770	if (unlikely(!cache_cp))
    771		return -EINVAL;
    772
    773	rc =  ses->server->ops->get_dfs_refer(xid, ses, path, refs, numrefs, cache_cp,
    774					      NO_MAP_UNI_RSVD);
    775	if (!rc) {
    776		struct dfs_info3_param *ref = *refs;
    777
    778		for (i = 0; i < *numrefs; i++)
    779			convert_delimiter(ref[i].path_name, '\\');
    780	}
    781	return rc;
    782}
    783
    784/*
    785 * Find, create or update a DFS cache entry.
    786 *
    787 * If the entry wasn't found, it will create a new one; if it was found but
    788 * has expired, it will update the entry accordingly.
    789 *
    790 * For interlinks, cifs_mount() and expand_dfs_referral() are supposed to
    791 * handle them properly.
    792 */
    793static int cache_refresh_path(const unsigned int xid, struct cifs_ses *ses, const char *path)
    794{
    795	int rc;
    796	struct cache_entry *ce;
    797	struct dfs_info3_param *refs = NULL;
    798	int numrefs = 0;
    799	bool newent = false;
    800
    801	cifs_dbg(FYI, "%s: search path: %s\n", __func__, path);
    802
    803	down_write(&htable_rw_lock);
    804
    805	ce = lookup_cache_entry(path);
    806	if (!IS_ERR(ce)) {
    807		if (!cache_entry_expired(ce)) {
    808			dump_ce(ce);
    809			up_write(&htable_rw_lock);
    810			return 0;
    811		}
    812	} else {
    813		newent = true;
    814	}
    815
    816	/*
    817	 * Either the entry was not found, or it is expired.
    818	 * Request a new DFS referral in order to create or update a cache entry.
    819	 */
    820	rc = get_dfs_referral(xid, ses, path, &refs, &numrefs);
    821	if (rc)
    822		goto out_unlock;
    823
    824	dump_refs(refs, numrefs);
    825
    826	if (!newent) {
    827		rc = update_cache_entry_locked(ce, refs, numrefs);
    828		goto out_unlock;
    829	}
    830
    831	rc = add_cache_entry_locked(refs, numrefs);
    832
    833out_unlock:
    834	up_write(&htable_rw_lock);
    835	free_dfs_info_array(refs, numrefs);
    836	return rc;
    837}
    838
    839/*
    840 * Set up a DFS referral from a given cache entry.
    841 *
    842 * Must be called with htable_rw_lock held.
    843 */
    844static int setup_referral(const char *path, struct cache_entry *ce,
    845			  struct dfs_info3_param *ref, const char *target)
    846{
    847	int rc;
    848
    849	cifs_dbg(FYI, "%s: set up new ref\n", __func__);
    850
    851	memset(ref, 0, sizeof(*ref));
    852
    853	ref->path_name = kstrdup(path, GFP_ATOMIC);
    854	if (!ref->path_name)
    855		return -ENOMEM;
    856
    857	ref->node_name = kstrdup(target, GFP_ATOMIC);
    858	if (!ref->node_name) {
    859		rc = -ENOMEM;
    860		goto err_free_path;
    861	}
    862
    863	ref->path_consumed = ce->path_consumed;
    864	ref->ttl = ce->ttl;
    865	ref->server_type = ce->srvtype;
    866	ref->ref_flag = ce->ref_flags;
    867	ref->flags = ce->hdr_flags;
    868
    869	return 0;
    870
    871err_free_path:
    872	kfree(ref->path_name);
    873	ref->path_name = NULL;
    874	return rc;
    875}
    876
    877/* Return target list of a DFS cache entry */
    878static int get_targets(struct cache_entry *ce, struct dfs_cache_tgt_list *tl)
    879{
    880	int rc;
    881	struct list_head *head = &tl->tl_list;
    882	struct cache_dfs_tgt *t;
    883	struct dfs_cache_tgt_iterator *it, *nit;
    884
    885	memset(tl, 0, sizeof(*tl));
    886	INIT_LIST_HEAD(head);
    887
    888	list_for_each_entry(t, &ce->tlist, list) {
    889		it = kzalloc(sizeof(*it), GFP_ATOMIC);
    890		if (!it) {
    891			rc = -ENOMEM;
    892			goto err_free_it;
    893		}
    894
    895		it->it_name = kstrdup(t->name, GFP_ATOMIC);
    896		if (!it->it_name) {
    897			kfree(it);
    898			rc = -ENOMEM;
    899			goto err_free_it;
    900		}
    901		it->it_path_consumed = t->path_consumed;
    902
    903		if (ce->tgthint == t)
    904			list_add(&it->it_list, head);
    905		else
    906			list_add_tail(&it->it_list, head);
    907	}
    908
    909	tl->tl_numtgts = ce->numtgts;
    910
    911	return 0;
    912
    913err_free_it:
    914	list_for_each_entry_safe(it, nit, head, it_list) {
    915		list_del(&it->it_list);
    916		kfree(it->it_name);
    917		kfree(it);
    918	}
    919	return rc;
    920}
    921
    922/**
    923 * dfs_cache_find - find a DFS cache entry
    924 *
    925 * If it doesn't find the cache entry, then it will get a DFS referral
    926 * for @path and create a new entry.
    927 *
    928 * If the cache entry exists but has expired, it will get a DFS referral
    929 * for @path and then update the respective cache entry.
    930 *
    931 * These parameters are passed down to the get_dfs_refer() call if it
    932 * needs to be issued:
    933 * @xid: syscall xid
    934 * @ses: smb session to issue the request on
    935 * @cp: codepage
    936 * @remap: path character remapping type
    937 * @path: path to lookup in DFS referral cache.
    938 *
    939 * @ref: when non-NULL, store single DFS referral result in it.
    940 * @tgt_list: when non-NULL, store complete DFS target list in it.
    941 *
    942 * Return zero if the target was found, otherwise non-zero.
    943 */
    944int dfs_cache_find(const unsigned int xid, struct cifs_ses *ses, const struct nls_table *cp,
    945		   int remap, const char *path, struct dfs_info3_param *ref,
    946		   struct dfs_cache_tgt_list *tgt_list)
    947{
    948	int rc;
    949	const char *npath;
    950	struct cache_entry *ce;
    951
    952	npath = dfs_cache_canonical_path(path, cp, remap);
    953	if (IS_ERR(npath))
    954		return PTR_ERR(npath);
    955
    956	rc = cache_refresh_path(xid, ses, npath);
    957	if (rc)
    958		goto out_free_path;
    959
    960	down_read(&htable_rw_lock);
    961
    962	ce = lookup_cache_entry(npath);
    963	if (IS_ERR(ce)) {
    964		up_read(&htable_rw_lock);
    965		rc = PTR_ERR(ce);
    966		goto out_free_path;
    967	}
    968
    969	if (ref)
    970		rc = setup_referral(path, ce, ref, get_tgt_name(ce));
    971	else
    972		rc = 0;
    973	if (!rc && tgt_list)
    974		rc = get_targets(ce, tgt_list);
    975
    976	up_read(&htable_rw_lock);
    977
    978out_free_path:
    979	kfree(npath);
    980	return rc;
    981}
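
/*
 * A minimal caller sketch, assuming a valid @xid and an IPC-connected
 * @ses as during mount, with a hypothetical path.  Both output
 * arguments are optional and caller-owned on success:
 *
 *	struct dfs_cache_tgt_list tl = DFS_CACHE_TGT_LIST_INIT(tl);
 *	struct dfs_info3_param ref = {0};
 *	int rc;
 *
 *	rc = dfs_cache_find(xid, ses, cifs_sb->local_nls,
 *			    cifs_remap(cifs_sb), "\\dom\\dfsroot\\link",
 *			    &ref, &tl);
 *	if (!rc) {
 *		... iterate targets with dfs_cache_get_tgt_iterator() ...
 *		free_dfs_info_param(&ref);
 *		dfs_cache_free_tgts(&tl);
 *	}
 */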
    982
    983/**
    984 * dfs_cache_noreq_find - find a DFS cache entry without sending any requests to
    985 * the currently connected server.
    986 *
    987 * NOTE: This function will neither update a cache entry if it has expired,
    988 * nor create a new cache entry if @path hasn't been found. It heavily
    989 * relies on an existing cache entry.
    990 *
    991 * @path: canonical DFS path to lookup in the DFS referral cache.
    992 * @ref: when non-NULL, store single DFS referral result in it.
    993 * @tgt_list: when non-NULL, store complete DFS target list in it.
    994 *
    995 * Return 0 if successful.
    996 * Return -ENOENT if the entry was not found.
    997 * Return non-zero for other errors.
    998 */
    999int dfs_cache_noreq_find(const char *path, struct dfs_info3_param *ref,
   1000			 struct dfs_cache_tgt_list *tgt_list)
   1001{
   1002	int rc;
   1003	struct cache_entry *ce;
   1004
   1005	cifs_dbg(FYI, "%s: path: %s\n", __func__, path);
   1006
   1007	down_read(&htable_rw_lock);
   1008
   1009	ce = lookup_cache_entry(path);
   1010	if (IS_ERR(ce)) {
   1011		rc = PTR_ERR(ce);
   1012		goto out_unlock;
   1013	}
   1014
   1015	if (ref)
   1016		rc = setup_referral(path, ce, ref, get_tgt_name(ce));
   1017	else
   1018		rc = 0;
   1019	if (!rc && tgt_list)
   1020		rc = get_targets(ce, tgt_list);
   1021
   1022out_unlock:
   1023	up_read(&htable_rw_lock);
   1024	return rc;
   1025}
   1026
   1027/**
   1028 * dfs_cache_update_tgthint - update target hint of a DFS cache entry
   1029 *
   1030 * If it doesn't find the cache entry, then it will get a DFS referral for @path
   1031 * and create a new entry.
   1032 *
   1033 * If the cache entry exists but has expired, it will get a DFS referral
   1034 * for @path and then update the respective cache entry.
   1035 *
   1036 * @xid: syscall id
   1037 * @ses: smb session
   1038 * @cp: codepage
   1039 * @remap: type of character remapping for paths
   1040 * @path: path to lookup in DFS referral cache
   1041 * @it: DFS target iterator
   1042 *
   1043 * Return zero if the target hint was updated successfully, otherwise non-zero.
   1044 */
   1045int dfs_cache_update_tgthint(const unsigned int xid, struct cifs_ses *ses,
   1046			     const struct nls_table *cp, int remap, const char *path,
   1047			     const struct dfs_cache_tgt_iterator *it)
   1048{
   1049	int rc;
   1050	const char *npath;
   1051	struct cache_entry *ce;
   1052	struct cache_dfs_tgt *t;
   1053
   1054	npath = dfs_cache_canonical_path(path, cp, remap);
   1055	if (IS_ERR(npath))
   1056		return PTR_ERR(npath);
   1057
   1058	cifs_dbg(FYI, "%s: update target hint - path: %s\n", __func__, npath);
   1059
   1060	rc = cache_refresh_path(xid, ses, npath);
   1061	if (rc)
   1062		goto out_free_path;
   1063
   1064	down_write(&htable_rw_lock);
   1065
   1066	ce = lookup_cache_entry(npath);
   1067	if (IS_ERR(ce)) {
   1068		rc = PTR_ERR(ce);
   1069		goto out_unlock;
   1070	}
   1071
   1072	t = ce->tgthint;
   1073
   1074	if (likely(!strcasecmp(it->it_name, t->name)))
   1075		goto out_unlock;
   1076
   1077	list_for_each_entry(t, &ce->tlist, list) {
   1078		if (!strcasecmp(t->name, it->it_name)) {
   1079			ce->tgthint = t;
   1080			cifs_dbg(FYI, "%s: new target hint: %s\n", __func__,
   1081				 it->it_name);
   1082			break;
   1083		}
   1084	}
   1085
   1086out_unlock:
   1087	up_write(&htable_rw_lock);
   1088out_free_path:
   1089	kfree(npath);
   1090	return rc;
   1091}
   1092
   1093/**
   1094 * dfs_cache_noreq_update_tgthint - update target hint of a DFS cache entry
   1095 * without sending any requests to the currently connected server.
   1096 *
   1097 * NOTE: This function will neither update a cache entry if it has expired,
   1098 * nor create a new cache entry if @path hasn't been found. It heavily
   1099 * relies on an existing cache entry.
   1100 *
   1101 * @path: canonical DFS path to lookup in DFS referral cache.
   1102 * @it: target iterator which contains the target hint to update the cache
   1103 * entry with.
   1104 *
   1105 * Return zero if the target hint was updated successfully, otherwise non-zero.
   1106 */
   1107int dfs_cache_noreq_update_tgthint(const char *path, const struct dfs_cache_tgt_iterator *it)
   1108{
   1109	int rc;
   1110	struct cache_entry *ce;
   1111	struct cache_dfs_tgt *t;
   1112
   1113	if (!it)
   1114		return -EINVAL;
   1115
   1116	cifs_dbg(FYI, "%s: path: %s\n", __func__, path);
   1117
   1118	down_write(&htable_rw_lock);
   1119
   1120	ce = lookup_cache_entry(path);
   1121	if (IS_ERR(ce)) {
   1122		rc = PTR_ERR(ce);
   1123		goto out_unlock;
   1124	}
   1125
   1126	rc = 0;
   1127	t = ce->tgthint;
   1128
   1129	if (unlikely(!strcasecmp(it->it_name, t->name)))
   1130		goto out_unlock;
   1131
   1132	list_for_each_entry(t, &ce->tlist, list) {
   1133		if (!strcasecmp(t->name, it->it_name)) {
   1134			ce->tgthint = t;
   1135			cifs_dbg(FYI, "%s: new target hint: %s\n", __func__,
   1136				 it->it_name);
   1137			break;
   1138		}
   1139	}
   1140
   1141out_unlock:
   1142	up_write(&htable_rw_lock);
   1143	return rc;
   1144}
   1145
   1146/**
   1147 * dfs_cache_get_tgt_referral - returns a DFS referral (@ref) from a given
   1148 * target iterator (@it).
   1149 *
   1150 * @path: canonical DFS path to lookup in DFS referral cache.
   1151 * @it: DFS target iterator.
   1152 * @ref: DFS referral pointer to set up the gathered information.
   1153 *
   1154 * Return zero if the DFS referral was set up correctly, otherwise non-zero.
   1155 */
   1156int dfs_cache_get_tgt_referral(const char *path, const struct dfs_cache_tgt_iterator *it,
   1157			       struct dfs_info3_param *ref)
   1158{
   1159	int rc;
   1160	struct cache_entry *ce;
   1161
   1162	if (!it || !ref)
   1163		return -EINVAL;
   1164
   1165	cifs_dbg(FYI, "%s: path: %s\n", __func__, path);
   1166
   1167	down_read(&htable_rw_lock);
   1168
   1169	ce = lookup_cache_entry(path);
   1170	if (IS_ERR(ce)) {
   1171		rc = PTR_ERR(ce);
   1172		goto out_unlock;
   1173	}
   1174
   1175	cifs_dbg(FYI, "%s: target name: %s\n", __func__, it->it_name);
   1176
   1177	rc = setup_referral(path, ce, ref, it->it_name);
   1178
   1179out_unlock:
   1180	up_read(&htable_rw_lock);
   1181	return rc;
   1182}
   1183
   1184/**
   1185 * dfs_cache_add_refsrv_session - add SMB session of referral server
   1186 *
   1187 * @mount_id: mount group uuid to lookup.
   1188 * @ses: reference counted SMB session of referral server.
   1189 */
   1190void dfs_cache_add_refsrv_session(const uuid_t *mount_id, struct cifs_ses *ses)
   1191{
   1192	struct mount_group *mg;
   1193
   1194	if (WARN_ON_ONCE(!mount_id || uuid_is_null(mount_id) || !ses))
   1195		return;
   1196
   1197	mg = get_mount_group(mount_id);
   1198	if (WARN_ON_ONCE(IS_ERR(mg)))
   1199		return;
   1200
   1201	spin_lock(&mg->lock);
   1202	if (mg->num_sessions < ARRAY_SIZE(mg->sessions))
   1203		mg->sessions[mg->num_sessions++] = ses;
   1204	spin_unlock(&mg->lock);
   1205	kref_put(&mg->refcount, mount_group_release);
   1206}
   1207
   1208/**
   1209 * dfs_cache_put_refsrv_sessions - put all referral server sessions
   1210 *
   1211 * Put all SMB sessions from the given mount group id.
   1212 *
   1213 * @mount_id: mount group uuid to lookup.
   1214 */
   1215void dfs_cache_put_refsrv_sessions(const uuid_t *mount_id)
   1216{
   1217	struct mount_group *mg;
   1218
   1219	if (!mount_id || uuid_is_null(mount_id))
   1220		return;
   1221
   1222	mutex_lock(&mount_group_list_lock);
   1223	mg = find_mount_group_locked(mount_id);
   1224	if (IS_ERR(mg)) {
   1225		mutex_unlock(&mount_group_list_lock);
   1226		return;
   1227	}
   1228	mutex_unlock(&mount_group_list_lock);
   1229	kref_put(&mg->refcount, mount_group_release);
   1230}
   1231
   1232/* Extract share from DFS target and return a pointer to its prefix path, or an ERR_PTR on error */
   1233static const char *parse_target_share(const char *target, char **share)
   1234{
   1235	const char *s, *seps = "/\\";
   1236	size_t len;
   1237
   1238	s = strpbrk(target + 1, seps);
   1239	if (!s)
   1240		return ERR_PTR(-EINVAL);
   1241
   1242	len = strcspn(s + 1, seps);
   1243	if (!len)
   1244		return ERR_PTR(-EINVAL);
   1245	s += len;
   1246
   1247	len = s - target + 1;
   1248	*share = kstrndup(target, len, GFP_KERNEL);
   1249	if (!*share)
   1250		return ERR_PTR(-ENOMEM);
   1251
   1252	s = target + len;
   1253	return s + strspn(s, seps);
   1254}
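
/*
 * An illustrative parse, with a hypothetical target.  For
 * target = "\srv\share\a\b":
 *
 *	*share  -> "\srv\share" (kstrndup'ed tree name, caller frees)
 *	return  -> pointer into @target at "a\b" (the prefix path), or
 *		   at the trailing NUL when the target has no prefix
 */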
   1255
   1256/**
   1257 * dfs_cache_get_tgt_share - parse a DFS target
   1258 *
   1259 * @path: DFS full path
   1260 * @it: DFS target iterator.
   1261 * @share: tree name.
   1262 * @prefix: prefix path.
   1263 *
   1264 * Return zero if target was parsed correctly, otherwise non-zero.
   1265 */
   1266int dfs_cache_get_tgt_share(char *path, const struct dfs_cache_tgt_iterator *it, char **share,
   1267			    char **prefix)
   1268{
   1269	char sep;
   1270	char *target_share;
   1271	char *ppath = NULL;
   1272	const char *target_ppath, *dfsref_ppath;
   1273	size_t target_pplen, dfsref_pplen;
   1274	size_t len, c;
   1275
   1276	if (!it || !path || !share || !prefix || strlen(path) < it->it_path_consumed)
   1277		return -EINVAL;
   1278
   1279	sep = it->it_name[0];
   1280	if (sep != '\\' && sep != '/')
   1281		return -EINVAL;
   1282
   1283	target_ppath = parse_target_share(it->it_name, &target_share);
   1284	if (IS_ERR(target_ppath))
   1285		return PTR_ERR(target_ppath);
   1286
   1287	/* point to prefix in DFS referral path */
   1288	dfsref_ppath = path + it->it_path_consumed;
   1289	dfsref_ppath += strspn(dfsref_ppath, "/\\");
   1290
   1291	target_pplen = strlen(target_ppath);
   1292	dfsref_pplen = strlen(dfsref_ppath);
   1293
   1294	/* merge prefix paths from DFS referral path and target node */
   1295	if (target_pplen || dfsref_pplen) {
   1296		len = target_pplen + dfsref_pplen + 2;
   1297		ppath = kzalloc(len, GFP_KERNEL);
   1298		if (!ppath) {
   1299			kfree(target_share);
   1300			return -ENOMEM;
   1301		}
   1302		c = strscpy(ppath, target_ppath, len);
   1303		if (c && dfsref_pplen)
   1304			ppath[c] = sep;
   1305		strlcat(ppath, dfsref_ppath, len);
   1306	}
   1307	*share = target_share;
   1308	*prefix = ppath;
   1309	return 0;
   1310}
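
/*
 * An illustrative merge, with hypothetical names.  Given
 * @path = "\dom\dfs\link\d1" where @it->it_path_consumed covers
 * "\dom\dfs\link", and a target of "\srv\share\tp":
 *
 *	*share  -> "\srv\share"
 *	*prefix -> "tp\d1" (target prefix first, then the part of @path
 *			    past PathConsumed)
 */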
   1311
   1312static bool target_share_equal(struct TCP_Server_Info *server, const char *s1, const char *s2)
   1313{
   1314	char unc[sizeof("\\\\") + SERVER_NAME_LENGTH] = {0};
   1315	const char *host;
   1316	size_t hostlen;
   1317	char *ip = NULL;
   1318	struct sockaddr sa;
   1319	bool match;
   1320	int rc;
   1321
   1322	if (strcasecmp(s1, s2))
   1323		return false;
   1324
   1325	/*
   1326	 * Resolve the share's hostname and check if the server address matches; otherwise just
   1327	 * ignore it, as we may lack an upcall to resolve the hostname or fail to convert the ip address.
   1328	 */
   1329	match = true;
   1330	extract_unc_hostname(s1, &host, &hostlen);
   1331	scnprintf(unc, sizeof(unc), "\\\\%.*s", (int)hostlen, host);
   1332
   1333	rc = dns_resolve_server_name_to_ip(unc, &ip, NULL);
   1334	if (rc < 0) {
   1335		cifs_dbg(FYI, "%s: could not resolve %.*s. assuming server address matches.\n",
   1336			 __func__, (int)hostlen, host);
   1337		return true;
   1338	}
   1339
   1340	if (!cifs_convert_address(&sa, ip, strlen(ip))) {
   1341		cifs_dbg(VFS, "%s: failed to convert address \'%s\'. skip address matching.\n",
   1342			 __func__, ip);
   1343	} else {
   1344		cifs_server_lock(server);
   1345		match = cifs_match_ipaddr((struct sockaddr *)&server->dstaddr, &sa);
   1346		cifs_server_unlock(server);
   1347	}
   1348
   1349	kfree(ip);
   1350	return match;
   1351}
   1352
   1353/*
   1354 * Mark dfs tcon for reconnecting when the currently connected tcon does not match any of the new
   1355 * target shares in @refs.
   1356 */
   1357static void mark_for_reconnect_if_needed(struct cifs_tcon *tcon, struct dfs_cache_tgt_list *tl,
   1358					 const struct dfs_info3_param *refs, int numrefs)
   1359{
   1360	struct dfs_cache_tgt_iterator *it;
   1361	int i;
   1362
   1363	for (it = dfs_cache_get_tgt_iterator(tl); it; it = dfs_cache_get_next_tgt(tl, it)) {
   1364		for (i = 0; i < numrefs; i++) {
   1365			if (target_share_equal(tcon->ses->server, dfs_cache_get_tgt_name(it),
   1366					       refs[i].node_name))
   1367				return;
   1368		}
   1369	}
   1370
   1371	cifs_dbg(FYI, "%s: no cached or matched targets. mark dfs share for reconnect.\n", __func__);
   1372	cifs_signal_cifsd_for_reconnect(tcon->ses->server, true);
   1373}
   1374
   1375/* Refresh dfs referral of tcon and mark it for reconnect if needed */
   1376static int __refresh_tcon(const char *path, struct cifs_ses **sessions, struct cifs_tcon *tcon,
   1377			  bool force_refresh)
   1378{
   1379	struct cifs_ses *ses;
   1380	struct cache_entry *ce;
   1381	struct dfs_info3_param *refs = NULL;
   1382	int numrefs = 0;
   1383	bool needs_refresh = false;
   1384	struct dfs_cache_tgt_list tl = DFS_CACHE_TGT_LIST_INIT(tl);
   1385	int rc = 0;
   1386	unsigned int xid;
   1387
   1388	ses = find_ipc_from_server_path(sessions, path);
   1389	if (IS_ERR(ses)) {
   1390		cifs_dbg(FYI, "%s: could not find ipc session\n", __func__);
   1391		return PTR_ERR(ses);
   1392	}
   1393
   1394	down_read(&htable_rw_lock);
   1395	ce = lookup_cache_entry(path);
   1396	needs_refresh = force_refresh || IS_ERR(ce) || cache_entry_expired(ce);
   1397	if (!IS_ERR(ce)) {
   1398		rc = get_targets(ce, &tl);
   1399		if (rc)
   1400			cifs_dbg(FYI, "%s: could not get dfs targets: %d\n", __func__, rc);
   1401	}
   1402	up_read(&htable_rw_lock);
   1403
   1404	if (!needs_refresh) {
   1405		rc = 0;
   1406		goto out;
   1407	}
   1408
   1409	xid = get_xid();
   1410	rc = get_dfs_referral(xid, ses, path, &refs, &numrefs);
   1411	free_xid(xid);
   1412
   1413	/* Create or update a cache entry with the new referral */
   1414	if (!rc) {
   1415		dump_refs(refs, numrefs);
   1416
   1417		down_write(&htable_rw_lock);
   1418		ce = lookup_cache_entry(path);
   1419		if (IS_ERR(ce))
   1420			add_cache_entry_locked(refs, numrefs);
   1421		else if (force_refresh || cache_entry_expired(ce))
   1422			update_cache_entry_locked(ce, refs, numrefs);
   1423		up_write(&htable_rw_lock);
   1424
   1425		mark_for_reconnect_if_needed(tcon, &tl, refs, numrefs);
   1426	}
   1427
   1428out:
   1429	dfs_cache_free_tgts(&tl);
   1430	free_dfs_info_array(refs, numrefs);
   1431	return rc;
   1432}
   1433
   1434static int refresh_tcon(struct cifs_ses **sessions, struct cifs_tcon *tcon, bool force_refresh)
   1435{
   1436	struct TCP_Server_Info *server = tcon->ses->server;
   1437
   1438	mutex_lock(&server->refpath_lock);
   1439	if (server->origin_fullpath) {
   1440		if (server->leaf_fullpath && strcasecmp(server->leaf_fullpath,
   1441							server->origin_fullpath))
   1442			__refresh_tcon(server->leaf_fullpath + 1, sessions, tcon, force_refresh);
   1443		__refresh_tcon(server->origin_fullpath + 1, sessions, tcon, force_refresh);
   1444	}
   1445	mutex_unlock(&server->refpath_lock);
   1446
   1447	return 0;
   1448}
   1449
   1450/**
   1451 * dfs_cache_remount_fs - remount a DFS share
   1452 *
   1453 * Reconfigure dfs mount by forcing a new DFS referral and if the currently cached targets do not
   1454 * match any of the new targets, mark it for reconnect.
   1455 *
   1456 * @cifs_sb: cifs superblock.
   1457 *
   1458 * Return zero if remounted, otherwise non-zero.
   1459 */
   1460int dfs_cache_remount_fs(struct cifs_sb_info *cifs_sb)
   1461{
   1462	struct cifs_tcon *tcon;
   1463	struct TCP_Server_Info *server;
   1464	struct mount_group *mg;
   1465	struct cifs_ses *sessions[CACHE_MAX_ENTRIES + 1] = {NULL};
   1466	int rc;
   1467
   1468	if (!cifs_sb || !cifs_sb->master_tlink)
   1469		return -EINVAL;
   1470
   1471	tcon = cifs_sb_master_tcon(cifs_sb);
   1472	server = tcon->ses->server;
   1473
   1474	if (!server->origin_fullpath) {
   1475		cifs_dbg(FYI, "%s: not a dfs mount\n", __func__);
   1476		return 0;
   1477	}
   1478
   1479	if (uuid_is_null(&cifs_sb->dfs_mount_id)) {
   1480		cifs_dbg(FYI, "%s: no dfs mount group id\n", __func__);
   1481		return -EINVAL;
   1482	}
   1483
   1484	mutex_lock(&mount_group_list_lock);
   1485	mg = find_mount_group_locked(&cifs_sb->dfs_mount_id);
   1486	if (IS_ERR(mg)) {
   1487		mutex_unlock(&mount_group_list_lock);
   1488		cifs_dbg(FYI, "%s: no ipc session for refreshing referral\n", __func__);
   1489		return PTR_ERR(mg);
   1490	}
   1491	kref_get(&mg->refcount);
   1492	mutex_unlock(&mount_group_list_lock);
   1493
   1494	spin_lock(&mg->lock);
   1495	memcpy(&sessions, mg->sessions, mg->num_sessions * sizeof(mg->sessions[0]));
   1496	spin_unlock(&mg->lock);
   1497
   1498	/*
   1499	 * After reconnecting to a different server, unique ids won't match anymore, so we disable
   1500	 * serverino. This prevents dentry revalidation from considering the dentries stale (ESTALE).
   1501	 */
   1502	cifs_autodisable_serverino(cifs_sb);
   1503	/*
   1504	 * Force the use of prefix path to support failover on DFS paths that resolve to targets
   1505	 * that have different prefix paths.
   1506	 */
   1507	cifs_sb->mnt_cifs_flags |= CIFS_MOUNT_USE_PREFIX_PATH;
   1508	rc = refresh_tcon(sessions, tcon, true);
   1509
   1510	kref_put(&mg->refcount, mount_group_release);
   1511	return rc;
   1512}
   1513
   1514/*
   1515 * Refresh all active dfs mounts regardless of whether they are in the cache
   1516 * or not, since the cache may have been cleared (e.g. via /proc).
   1517 */
   1518static void refresh_mounts(struct cifs_ses **sessions)
   1519{
   1520	struct TCP_Server_Info *server;
   1521	struct cifs_ses *ses;
   1522	struct cifs_tcon *tcon, *ntcon;
   1523	struct list_head tcons;
   1524
   1525	INIT_LIST_HEAD(&tcons);
   1526
   1527	spin_lock(&cifs_tcp_ses_lock);
   1528	list_for_each_entry(server, &cifs_tcp_ses_list, tcp_ses_list) {
   1529		if (!server->is_dfs_conn)
   1530			continue;
   1531
   1532		list_for_each_entry(ses, &server->smb_ses_list, smb_ses_list) {
   1533			list_for_each_entry(tcon, &ses->tcon_list, tcon_list) {
   1534				if (!tcon->ipc && !tcon->need_reconnect) {
   1535					tcon->tc_count++;
   1536					list_add_tail(&tcon->ulist, &tcons);
   1537				}
   1538			}
   1539		}
   1540	}
   1541	spin_unlock(&cifs_tcp_ses_lock);
   1542
   1543	list_for_each_entry_safe(tcon, ntcon, &tcons, ulist) {
   1544		struct TCP_Server_Info *server = tcon->ses->server;
   1545
   1546		list_del_init(&tcon->ulist);
   1547
   1548		mutex_lock(&server->refpath_lock);
   1549		if (server->origin_fullpath) {
   1550			if (server->leaf_fullpath && strcasecmp(server->leaf_fullpath,
   1551								server->origin_fullpath))
   1552				__refresh_tcon(server->leaf_fullpath + 1, sessions, tcon, false);
   1553			__refresh_tcon(server->origin_fullpath + 1, sessions, tcon, false);
   1554		}
   1555		mutex_unlock(&server->refpath_lock);
   1556
   1557		cifs_put_tcon(tcon);
   1558	}
   1559}
   1560
   1561static void refresh_cache(struct cifs_ses **sessions)
   1562{
   1563	int i;
   1564	struct cifs_ses *ses;
   1565	unsigned int xid;
   1566	char *ref_paths[CACHE_MAX_ENTRIES];
   1567	int count = 0;
   1568	struct cache_entry *ce;
   1569
   1570	/*
   1571	 * Refresh all cached entries.  Get all new referrals outside critical section to avoid
   1572	 * starvation while performing SMB2 IOCTL on broken or slow connections.
   1573	 *
   1574	 * The cache entries may cover more paths than the active mounts
   1575	 * (e.g. domain-based DFS referrals or multi-tier DFS setups).
   1576	 */
   1577	down_read(&htable_rw_lock);
   1578	for (i = 0; i < CACHE_HTABLE_SIZE; i++) {
   1579		struct hlist_head *l = &cache_htable[i];
   1580
   1581		hlist_for_each_entry(ce, l, hlist) {
   1582			if (count == ARRAY_SIZE(ref_paths))
   1583				goto out_unlock;
   1584			if (hlist_unhashed(&ce->hlist) || !cache_entry_expired(ce) ||
   1585			    IS_ERR(find_ipc_from_server_path(sessions, ce->path)))
   1586				continue;
   1587			ref_paths[count++] = kstrdup(ce->path, GFP_ATOMIC);
   1588		}
   1589	}
   1590
   1591out_unlock:
   1592	up_read(&htable_rw_lock);
   1593
   1594	for (i = 0; i < count; i++) {
   1595		char *path = ref_paths[i];
   1596		struct dfs_info3_param *refs = NULL;
   1597		int numrefs = 0;
   1598		int rc = 0;
   1599
   1600		if (!path)
   1601			continue;
   1602
   1603		ses = find_ipc_from_server_path(sessions, path);
   1604		if (IS_ERR(ses))
   1605			goto next_referral;
   1606
   1607		xid = get_xid();
   1608		rc = get_dfs_referral(xid, ses, path, &refs, &numrefs);
   1609		free_xid(xid);
   1610
   1611		if (!rc) {
   1612			down_write(&htable_rw_lock);
   1613			ce = lookup_cache_entry(path);
   1614			/*
   1615			 * We need to re-check it because other tasks might have it deleted or
   1616			 * updated.
   1617			 */
   1618			if (!IS_ERR(ce) && cache_entry_expired(ce))
   1619				update_cache_entry_locked(ce, refs, numrefs);
   1620			up_write(&htable_rw_lock);
   1621		}
   1622
   1623next_referral:
   1624		kfree(path);
   1625		free_dfs_info_array(refs, numrefs);
   1626	}
   1627}
   1628
   1629/*
   1630 * Worker that will refresh DFS cache and active mounts based on lowest TTL value from a DFS
   1631 * referral.
   1632 */
   1633static void refresh_cache_worker(struct work_struct *work)
   1634{
   1635	struct list_head mglist;
   1636	struct mount_group *mg, *tmp_mg;
   1637	struct cifs_ses *sessions[CACHE_MAX_ENTRIES + 1] = {NULL};
   1638	int max_sessions = ARRAY_SIZE(sessions) - 1;
   1639	int i = 0, count;
   1640
   1641	INIT_LIST_HEAD(&mglist);
   1642
   1643	/* Get references to mount groups */
   1644	mutex_lock(&mount_group_list_lock);
   1645	list_for_each_entry(mg, &mount_group_list, list) {
   1646		kref_get(&mg->refcount);
   1647		list_add(&mg->refresh_list, &mglist);
   1648	}
   1649	mutex_unlock(&mount_group_list_lock);
   1650
   1651	/* Fill in the local array with a NULL-terminated list of all referral server sessions */
   1652	list_for_each_entry(mg, &mglist, refresh_list) {
   1653		if (i >= max_sessions)
   1654			break;
   1655
   1656		spin_lock(&mg->lock);
   1657		if (i + mg->num_sessions > max_sessions)
   1658			count = max_sessions - i;
   1659		else
   1660			count = mg->num_sessions;
   1661		memcpy(&sessions[i], mg->sessions, count * sizeof(mg->sessions[0]));
   1662		spin_unlock(&mg->lock);
   1663		i += count;
   1664	}
   1665
   1666	if (sessions[0]) {
   1667		/* Refresh all active mounts and cached entries */
   1668		refresh_mounts(sessions);
   1669		refresh_cache(sessions);
   1670	}
   1671
   1672	list_for_each_entry_safe(mg, tmp_mg, &mglist, refresh_list) {
   1673		list_del_init(&mg->refresh_list);
   1674		kref_put(&mg->refcount, mount_group_release);
   1675	}
   1676
   1677	spin_lock(&cache_ttl_lock);
   1678	queue_delayed_work(dfscache_wq, &refresh_task, cache_ttl * HZ);
   1679	spin_unlock(&cache_ttl_lock);
   1680}