cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

rate.c (25978B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * Copyright 2002-2005, Instant802 Networks, Inc.
      4 * Copyright 2005-2006, Devicescape Software, Inc.
      5 * Copyright (c) 2006 Jiri Benc <jbenc@suse.cz>
      6 * Copyright 2017	Intel Deutschland GmbH
      7 */
      8
      9#include <linux/kernel.h>
     10#include <linux/rtnetlink.h>
     11#include <linux/module.h>
     12#include <linux/slab.h>
     13#include "rate.h"
     14#include "ieee80211_i.h"
     15#include "debugfs.h"
     16
     17struct rate_control_alg {
     18	struct list_head list;
     19	const struct rate_control_ops *ops;
     20};
     21
     22static LIST_HEAD(rate_ctrl_algs);
     23static DEFINE_MUTEX(rate_ctrl_mutex);
     24
     25static char *ieee80211_default_rc_algo = CONFIG_MAC80211_RC_DEFAULT;
     26module_param(ieee80211_default_rc_algo, charp, 0644);
     27MODULE_PARM_DESC(ieee80211_default_rc_algo,
     28		 "Default rate control algorithm for mac80211 to use");
     29
     30void rate_control_rate_init(struct sta_info *sta)
     31{
     32	struct ieee80211_local *local = sta->sdata->local;
     33	struct rate_control_ref *ref = sta->rate_ctrl;
     34	struct ieee80211_sta *ista = &sta->sta;
     35	void *priv_sta = sta->rate_ctrl_priv;
     36	struct ieee80211_supported_band *sband;
     37	struct ieee80211_chanctx_conf *chanctx_conf;
     38
     39	ieee80211_sta_set_rx_nss(sta);
     40
     41	if (!ref)
     42		return;
     43
     44	rcu_read_lock();
     45
     46	chanctx_conf = rcu_dereference(sta->sdata->vif.chanctx_conf);
     47	if (WARN_ON(!chanctx_conf)) {
     48		rcu_read_unlock();
     49		return;
     50	}
     51
     52	sband = local->hw.wiphy->bands[chanctx_conf->def.chan->band];
     53
     54	/* TODO: check for minstrel_s1g ? */
     55	if (sband->band == NL80211_BAND_S1GHZ) {
     56		ieee80211_s1g_sta_rate_init(sta);
     57		rcu_read_unlock();
     58		return;
     59	}
     60
     61	spin_lock_bh(&sta->rate_ctrl_lock);
     62	ref->ops->rate_init(ref->priv, sband, &chanctx_conf->def, ista,
     63			    priv_sta);
     64	spin_unlock_bh(&sta->rate_ctrl_lock);
     65	rcu_read_unlock();
     66	set_sta_flag(sta, WLAN_STA_RATE_CONTROL);
     67}
     68
     69void rate_control_tx_status(struct ieee80211_local *local,
     70			    struct ieee80211_supported_band *sband,
     71			    struct ieee80211_tx_status *st)
     72{
     73	struct rate_control_ref *ref = local->rate_ctrl;
     74	struct sta_info *sta = container_of(st->sta, struct sta_info, sta);
     75	void *priv_sta = sta->rate_ctrl_priv;
     76
     77	if (!ref || !test_sta_flag(sta, WLAN_STA_RATE_CONTROL))
     78		return;
     79
     80	spin_lock_bh(&sta->rate_ctrl_lock);
     81	if (ref->ops->tx_status_ext)
     82		ref->ops->tx_status_ext(ref->priv, sband, priv_sta, st);
     83	else if (st->skb)
     84		ref->ops->tx_status(ref->priv, sband, st->sta, priv_sta, st->skb);
     85	else
     86		WARN_ON_ONCE(1);
     87
     88	spin_unlock_bh(&sta->rate_ctrl_lock);
     89}
     90
     91void rate_control_rate_update(struct ieee80211_local *local,
     92				    struct ieee80211_supported_band *sband,
     93				    struct sta_info *sta, u32 changed)
     94{
     95	struct rate_control_ref *ref = local->rate_ctrl;
     96	struct ieee80211_sta *ista = &sta->sta;
     97	void *priv_sta = sta->rate_ctrl_priv;
     98	struct ieee80211_chanctx_conf *chanctx_conf;
     99
    100	if (ref && ref->ops->rate_update) {
    101		rcu_read_lock();
    102
    103		chanctx_conf = rcu_dereference(sta->sdata->vif.chanctx_conf);
    104		if (WARN_ON(!chanctx_conf)) {
    105			rcu_read_unlock();
    106			return;
    107		}
    108
    109		spin_lock_bh(&sta->rate_ctrl_lock);
    110		ref->ops->rate_update(ref->priv, sband, &chanctx_conf->def,
    111				      ista, priv_sta, changed);
    112		spin_unlock_bh(&sta->rate_ctrl_lock);
    113		rcu_read_unlock();
    114	}
    115	drv_sta_rc_update(local, sta->sdata, &sta->sta, changed);
    116}
    117
    118int ieee80211_rate_control_register(const struct rate_control_ops *ops)
    119{
    120	struct rate_control_alg *alg;
    121
    122	if (!ops->name)
    123		return -EINVAL;
    124
    125	mutex_lock(&rate_ctrl_mutex);
    126	list_for_each_entry(alg, &rate_ctrl_algs, list) {
    127		if (!strcmp(alg->ops->name, ops->name)) {
    128			/* don't register an algorithm twice */
    129			WARN_ON(1);
    130			mutex_unlock(&rate_ctrl_mutex);
    131			return -EALREADY;
    132		}
    133	}
    134
    135	alg = kzalloc(sizeof(*alg), GFP_KERNEL);
    136	if (alg == NULL) {
    137		mutex_unlock(&rate_ctrl_mutex);
    138		return -ENOMEM;
    139	}
    140	alg->ops = ops;
    141
    142	list_add_tail(&alg->list, &rate_ctrl_algs);
    143	mutex_unlock(&rate_ctrl_mutex);
    144
    145	return 0;
    146}
    147EXPORT_SYMBOL(ieee80211_rate_control_register);
    148
    149void ieee80211_rate_control_unregister(const struct rate_control_ops *ops)
    150{
    151	struct rate_control_alg *alg;
    152
    153	mutex_lock(&rate_ctrl_mutex);
    154	list_for_each_entry(alg, &rate_ctrl_algs, list) {
    155		if (alg->ops == ops) {
    156			list_del(&alg->list);
    157			kfree(alg);
    158			break;
    159		}
    160	}
    161	mutex_unlock(&rate_ctrl_mutex);
    162}
    163EXPORT_SYMBOL(ieee80211_rate_control_unregister);
    164
    165static const struct rate_control_ops *
    166ieee80211_try_rate_control_ops_get(const char *name)
    167{
    168	struct rate_control_alg *alg;
    169	const struct rate_control_ops *ops = NULL;
    170
    171	if (!name)
    172		return NULL;
    173
    174	mutex_lock(&rate_ctrl_mutex);
    175	list_for_each_entry(alg, &rate_ctrl_algs, list) {
    176		if (!strcmp(alg->ops->name, name)) {
    177			ops = alg->ops;
    178			break;
    179		}
    180	}
    181	mutex_unlock(&rate_ctrl_mutex);
    182	return ops;
    183}
    184
    185/* Get the rate control algorithm. */
    186static const struct rate_control_ops *
    187ieee80211_rate_control_ops_get(const char *name)
    188{
    189	const struct rate_control_ops *ops;
    190	const char *alg_name;
    191
    192	kernel_param_lock(THIS_MODULE);
    193	if (!name)
    194		alg_name = ieee80211_default_rc_algo;
    195	else
    196		alg_name = name;
    197
    198	ops = ieee80211_try_rate_control_ops_get(alg_name);
    199	if (!ops && name)
    200		/* try default if specific alg requested but not found */
    201		ops = ieee80211_try_rate_control_ops_get(ieee80211_default_rc_algo);
    202
    203	/* Note: check for > 0 is intentional to avoid clang warning */
    204	if (!ops && (strlen(CONFIG_MAC80211_RC_DEFAULT) > 0))
    205		/* try built-in one if specific alg requested but not found */
    206		ops = ieee80211_try_rate_control_ops_get(CONFIG_MAC80211_RC_DEFAULT);
    207
    208	kernel_param_unlock(THIS_MODULE);
    209
    210	return ops;
    211}
    212
    213#ifdef CONFIG_MAC80211_DEBUGFS
    214static ssize_t rcname_read(struct file *file, char __user *userbuf,
    215			   size_t count, loff_t *ppos)
    216{
    217	struct rate_control_ref *ref = file->private_data;
    218	int len = strlen(ref->ops->name);
    219
    220	return simple_read_from_buffer(userbuf, count, ppos,
    221				       ref->ops->name, len);
    222}
    223
    224const struct file_operations rcname_ops = {
    225	.read = rcname_read,
    226	.open = simple_open,
    227	.llseek = default_llseek,
    228};
    229#endif
    230
    231static struct rate_control_ref *
    232rate_control_alloc(const char *name, struct ieee80211_local *local)
    233{
    234	struct rate_control_ref *ref;
    235
    236	ref = kmalloc(sizeof(struct rate_control_ref), GFP_KERNEL);
    237	if (!ref)
    238		return NULL;
    239	ref->ops = ieee80211_rate_control_ops_get(name);
    240	if (!ref->ops)
    241		goto free;
    242
    243	ref->priv = ref->ops->alloc(&local->hw);
    244	if (!ref->priv)
    245		goto free;
    246	return ref;
    247
    248free:
    249	kfree(ref);
    250	return NULL;
    251}
    252
    253static void rate_control_free(struct ieee80211_local *local,
    254			      struct rate_control_ref *ctrl_ref)
    255{
    256	ctrl_ref->ops->free(ctrl_ref->priv);
    257
    258#ifdef CONFIG_MAC80211_DEBUGFS
    259	debugfs_remove_recursive(local->debugfs.rcdir);
    260	local->debugfs.rcdir = NULL;
    261#endif
    262
    263	kfree(ctrl_ref);
    264}
    265
    266void ieee80211_check_rate_mask(struct ieee80211_sub_if_data *sdata)
    267{
    268	struct ieee80211_local *local = sdata->local;
    269	struct ieee80211_supported_band *sband;
    270	u32 user_mask, basic_rates = sdata->vif.bss_conf.basic_rates;
    271	enum nl80211_band band;
    272
    273	if (WARN_ON(!sdata->vif.bss_conf.chandef.chan))
    274		return;
    275
    276	band = sdata->vif.bss_conf.chandef.chan->band;
    277	if (band == NL80211_BAND_S1GHZ) {
    278		/* TODO */
    279		return;
    280	}
    281
    282	if (WARN_ON_ONCE(!basic_rates))
    283		return;
    284
    285	user_mask = sdata->rc_rateidx_mask[band];
    286	sband = local->hw.wiphy->bands[band];
    287
    288	if (user_mask & basic_rates)
    289		return;
    290
    291	sdata_dbg(sdata,
    292		  "no overlap between basic rates (0x%x) and user mask (0x%x on band %d) - clearing the latter",
    293		  basic_rates, user_mask, band);
    294	sdata->rc_rateidx_mask[band] = (1 << sband->n_bitrates) - 1;
    295}
    296
    297static bool rc_no_data_or_no_ack_use_min(struct ieee80211_tx_rate_control *txrc)
    298{
    299	struct sk_buff *skb = txrc->skb;
    300	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb);
    301
    302	return (info->flags & (IEEE80211_TX_CTL_NO_ACK |
    303			       IEEE80211_TX_CTL_USE_MINRATE)) ||
    304		!ieee80211_is_tx_data(skb);
    305}
    306
    307static void rc_send_low_basicrate(struct ieee80211_tx_rate *rate,
    308				  u32 basic_rates,
    309				  struct ieee80211_supported_band *sband)
    310{
    311	u8 i;
    312
    313	if (sband->band == NL80211_BAND_S1GHZ) {
    314		/* TODO */
    315		rate->flags |= IEEE80211_TX_RC_S1G_MCS;
    316		rate->idx = 0;
    317		return;
    318	}
    319
    320	if (basic_rates == 0)
    321		return; /* assume basic rates unknown and accept rate */
    322	if (rate->idx < 0)
    323		return;
    324	if (basic_rates & (1 << rate->idx))
    325		return; /* selected rate is a basic rate */
    326
    327	for (i = rate->idx + 1; i <= sband->n_bitrates; i++) {
    328		if (basic_rates & (1 << i)) {
    329			rate->idx = i;
    330			return;
    331		}
    332	}
    333
    334	/* could not find a basic rate; use original selection */
    335}
    336
    337static void __rate_control_send_low(struct ieee80211_hw *hw,
    338				    struct ieee80211_supported_band *sband,
    339				    struct ieee80211_sta *sta,
    340				    struct ieee80211_tx_info *info,
    341				    u32 rate_mask)
    342{
    343	int i;
    344	u32 rate_flags =
    345		ieee80211_chandef_rate_flags(&hw->conf.chandef);
    346
    347	if (sband->band == NL80211_BAND_S1GHZ) {
    348		info->control.rates[0].flags |= IEEE80211_TX_RC_S1G_MCS;
    349		info->control.rates[0].idx = 0;
    350		return;
    351	}
    352
    353	if ((sband->band == NL80211_BAND_2GHZ) &&
    354	    (info->flags & IEEE80211_TX_CTL_NO_CCK_RATE))
    355		rate_flags |= IEEE80211_RATE_ERP_G;
    356
    357	info->control.rates[0].idx = 0;
    358	for (i = 0; i < sband->n_bitrates; i++) {
    359		if (!(rate_mask & BIT(i)))
    360			continue;
    361
    362		if ((rate_flags & sband->bitrates[i].flags) != rate_flags)
    363			continue;
    364
    365		if (!rate_supported(sta, sband->band, i))
    366			continue;
    367
    368		info->control.rates[0].idx = i;
    369		break;
    370	}
    371	WARN_ONCE(i == sband->n_bitrates,
    372		  "no supported rates for sta %pM (0x%x, band %d) in rate_mask 0x%x with flags 0x%x\n",
    373		  sta ? sta->addr : NULL,
    374		  sta ? sta->deflink.supp_rates[sband->band] : -1,
    375		  sband->band,
    376		  rate_mask, rate_flags);
    377
    378	info->control.rates[0].count =
    379		(info->flags & IEEE80211_TX_CTL_NO_ACK) ?
    380		1 : hw->max_rate_tries;
    381
    382	info->control.skip_table = 1;
    383}
    384
    385
    386static bool rate_control_send_low(struct ieee80211_sta *pubsta,
    387				  struct ieee80211_tx_rate_control *txrc)
    388{
    389	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb);
    390	struct ieee80211_supported_band *sband = txrc->sband;
    391	struct sta_info *sta;
    392	int mcast_rate;
    393	bool use_basicrate = false;
    394
    395	if (!pubsta || rc_no_data_or_no_ack_use_min(txrc)) {
    396		__rate_control_send_low(txrc->hw, sband, pubsta, info,
    397					txrc->rate_idx_mask);
    398
    399		if (!pubsta && txrc->bss) {
    400			mcast_rate = txrc->bss_conf->mcast_rate[sband->band];
    401			if (mcast_rate > 0) {
    402				info->control.rates[0].idx = mcast_rate - 1;
    403				return true;
    404			}
    405			use_basicrate = true;
    406		} else if (pubsta) {
    407			sta = container_of(pubsta, struct sta_info, sta);
    408			if (ieee80211_vif_is_mesh(&sta->sdata->vif))
    409				use_basicrate = true;
    410		}
    411
    412		if (use_basicrate)
    413			rc_send_low_basicrate(&info->control.rates[0],
    414					      txrc->bss_conf->basic_rates,
    415					      sband);
    416
    417		return true;
    418	}
    419	return false;
    420}
    421
    422static bool rate_idx_match_legacy_mask(s8 *rate_idx, int n_bitrates, u32 mask)
    423{
    424	int j;
    425
    426	/* See whether the selected rate or anything below it is allowed. */
    427	for (j = *rate_idx; j >= 0; j--) {
    428		if (mask & (1 << j)) {
    429			/* Okay, found a suitable rate. Use it. */
    430			*rate_idx = j;
    431			return true;
    432		}
    433	}
    434
    435	/* Try to find a higher rate that would be allowed */
    436	for (j = *rate_idx + 1; j < n_bitrates; j++) {
    437		if (mask & (1 << j)) {
    438			/* Okay, found a suitable rate. Use it. */
    439			*rate_idx = j;
    440			return true;
    441		}
    442	}
    443	return false;
    444}
    445
    446static bool rate_idx_match_mcs_mask(s8 *rate_idx, u8 *mcs_mask)
    447{
    448	int i, j;
    449	int ridx, rbit;
    450
    451	ridx = *rate_idx / 8;
    452	rbit = *rate_idx % 8;
    453
    454	/* sanity check */
    455	if (ridx < 0 || ridx >= IEEE80211_HT_MCS_MASK_LEN)
    456		return false;
    457
    458	/* See whether the selected rate or anything below it is allowed. */
    459	for (i = ridx; i >= 0; i--) {
    460		for (j = rbit; j >= 0; j--)
    461			if (mcs_mask[i] & BIT(j)) {
    462				*rate_idx = i * 8 + j;
    463				return true;
    464			}
    465		rbit = 7;
    466	}
    467
    468	/* Try to find a higher rate that would be allowed */
    469	ridx = (*rate_idx + 1) / 8;
    470	rbit = (*rate_idx + 1) % 8;
    471
    472	for (i = ridx; i < IEEE80211_HT_MCS_MASK_LEN; i++) {
    473		for (j = rbit; j < 8; j++)
    474			if (mcs_mask[i] & BIT(j)) {
    475				*rate_idx = i * 8 + j;
    476				return true;
    477			}
    478		rbit = 0;
    479	}
    480	return false;
    481}
    482
    483static bool rate_idx_match_vht_mcs_mask(s8 *rate_idx, u16 *vht_mask)
    484{
    485	int i, j;
    486	int ridx, rbit;
    487
    488	ridx = *rate_idx >> 4;
    489	rbit = *rate_idx & 0xf;
    490
    491	if (ridx < 0 || ridx >= NL80211_VHT_NSS_MAX)
    492		return false;
    493
    494	/* See whether the selected rate or anything below it is allowed. */
    495	for (i = ridx; i >= 0; i--) {
    496		for (j = rbit; j >= 0; j--) {
    497			if (vht_mask[i] & BIT(j)) {
    498				*rate_idx = (i << 4) | j;
    499				return true;
    500			}
    501		}
    502		rbit = 15;
    503	}
    504
    505	/* Try to find a higher rate that would be allowed */
    506	ridx = (*rate_idx + 1) >> 4;
    507	rbit = (*rate_idx + 1) & 0xf;
    508
    509	for (i = ridx; i < NL80211_VHT_NSS_MAX; i++) {
    510		for (j = rbit; j < 16; j++) {
    511			if (vht_mask[i] & BIT(j)) {
    512				*rate_idx = (i << 4) | j;
    513				return true;
    514			}
    515		}
    516		rbit = 0;
    517	}
    518	return false;
    519}
    520
    521static void rate_idx_match_mask(s8 *rate_idx, u16 *rate_flags,
    522				struct ieee80211_supported_band *sband,
    523				enum nl80211_chan_width chan_width,
    524				u32 mask,
    525				u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN],
    526				u16 vht_mask[NL80211_VHT_NSS_MAX])
    527{
    528	if (*rate_flags & IEEE80211_TX_RC_VHT_MCS) {
    529		/* handle VHT rates */
    530		if (rate_idx_match_vht_mcs_mask(rate_idx, vht_mask))
    531			return;
    532
    533		*rate_idx = 0;
    534		/* keep protection flags */
    535		*rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS |
    536				IEEE80211_TX_RC_USE_CTS_PROTECT |
    537				IEEE80211_TX_RC_USE_SHORT_PREAMBLE);
    538
    539		*rate_flags |= IEEE80211_TX_RC_MCS;
    540		if (chan_width == NL80211_CHAN_WIDTH_40)
    541			*rate_flags |= IEEE80211_TX_RC_40_MHZ_WIDTH;
    542
    543		if (rate_idx_match_mcs_mask(rate_idx, mcs_mask))
    544			return;
    545
    546		/* also try the legacy rates. */
    547		*rate_flags &= ~(IEEE80211_TX_RC_MCS |
    548				 IEEE80211_TX_RC_40_MHZ_WIDTH);
    549		if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates,
    550					       mask))
    551			return;
    552	} else if (*rate_flags & IEEE80211_TX_RC_MCS) {
    553		/* handle HT rates */
    554		if (rate_idx_match_mcs_mask(rate_idx, mcs_mask))
    555			return;
    556
    557		/* also try the legacy rates. */
    558		*rate_idx = 0;
    559		/* keep protection flags */
    560		*rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS |
    561				IEEE80211_TX_RC_USE_CTS_PROTECT |
    562				IEEE80211_TX_RC_USE_SHORT_PREAMBLE);
    563		if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates,
    564					       mask))
    565			return;
    566	} else {
    567		/* handle legacy rates */
    568		if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates,
    569					       mask))
    570			return;
    571
    572		/* if HT BSS, and we handle a data frame, also try HT rates */
    573		switch (chan_width) {
    574		case NL80211_CHAN_WIDTH_20_NOHT:
    575		case NL80211_CHAN_WIDTH_5:
    576		case NL80211_CHAN_WIDTH_10:
    577			return;
    578		default:
    579			break;
    580		}
    581
    582		*rate_idx = 0;
    583		/* keep protection flags */
    584		*rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS |
    585				IEEE80211_TX_RC_USE_CTS_PROTECT |
    586				IEEE80211_TX_RC_USE_SHORT_PREAMBLE);
    587
    588		*rate_flags |= IEEE80211_TX_RC_MCS;
    589
    590		if (chan_width == NL80211_CHAN_WIDTH_40)
    591			*rate_flags |= IEEE80211_TX_RC_40_MHZ_WIDTH;
    592
    593		if (rate_idx_match_mcs_mask(rate_idx, mcs_mask))
    594			return;
    595	}
    596
    597	/*
    598	 * Uh.. No suitable rate exists. This should not really happen with
    599	 * sane TX rate mask configurations. However, should someone manage to
    600	 * configure supported rates and TX rate mask in incompatible way,
    601	 * allow the frame to be transmitted with whatever the rate control
    602	 * selected.
    603	 */
    604}
    605
    606static void rate_fixup_ratelist(struct ieee80211_vif *vif,
    607				struct ieee80211_supported_band *sband,
    608				struct ieee80211_tx_info *info,
    609				struct ieee80211_tx_rate *rates,
    610				int max_rates)
    611{
    612	struct ieee80211_rate *rate;
    613	bool inval = false;
    614	int i;
    615
    616	/*
    617	 * Set up the RTS/CTS rate as the fastest basic rate
    618	 * that is not faster than the data rate unless there
    619	 * is no basic rate slower than the data rate, in which
    620	 * case we pick the slowest basic rate
    621	 *
    622	 * XXX: Should this check all retry rates?
    623	 */
    624	if (!(rates[0].flags &
    625	      (IEEE80211_TX_RC_MCS | IEEE80211_TX_RC_VHT_MCS))) {
    626		u32 basic_rates = vif->bss_conf.basic_rates;
    627		s8 baserate = basic_rates ? ffs(basic_rates) - 1 : 0;
    628
    629		rate = &sband->bitrates[rates[0].idx];
    630
    631		for (i = 0; i < sband->n_bitrates; i++) {
    632			/* must be a basic rate */
    633			if (!(basic_rates & BIT(i)))
    634				continue;
    635			/* must not be faster than the data rate */
    636			if (sband->bitrates[i].bitrate > rate->bitrate)
    637				continue;
    638			/* maximum */
    639			if (sband->bitrates[baserate].bitrate <
    640			     sband->bitrates[i].bitrate)
    641				baserate = i;
    642		}
    643
    644		info->control.rts_cts_rate_idx = baserate;
    645	}
    646
    647	for (i = 0; i < max_rates; i++) {
    648		/*
    649		 * make sure there's no valid rate following
    650		 * an invalid one, just in case drivers don't
    651		 * take the API seriously to stop at -1.
    652		 */
    653		if (inval) {
    654			rates[i].idx = -1;
    655			continue;
    656		}
    657		if (rates[i].idx < 0) {
    658			inval = true;
    659			continue;
    660		}
    661
    662		/*
    663		 * For now assume MCS is already set up correctly, this
    664		 * needs to be fixed.
    665		 */
    666		if (rates[i].flags & IEEE80211_TX_RC_MCS) {
    667			WARN_ON(rates[i].idx > 76);
    668
    669			if (!(rates[i].flags & IEEE80211_TX_RC_USE_RTS_CTS) &&
    670			    info->control.use_cts_prot)
    671				rates[i].flags |=
    672					IEEE80211_TX_RC_USE_CTS_PROTECT;
    673			continue;
    674		}
    675
    676		if (rates[i].flags & IEEE80211_TX_RC_VHT_MCS) {
    677			WARN_ON(ieee80211_rate_get_vht_mcs(&rates[i]) > 9);
    678			continue;
    679		}
    680
    681		/* set up RTS protection if desired */
    682		if (info->control.use_rts) {
    683			rates[i].flags |= IEEE80211_TX_RC_USE_RTS_CTS;
    684			info->control.use_cts_prot = false;
    685		}
    686
    687		/* RC is busted */
    688		if (WARN_ON_ONCE(rates[i].idx >= sband->n_bitrates)) {
    689			rates[i].idx = -1;
    690			continue;
    691		}
    692
    693		rate = &sband->bitrates[rates[i].idx];
    694
    695		/* set up short preamble */
    696		if (info->control.short_preamble &&
    697		    rate->flags & IEEE80211_RATE_SHORT_PREAMBLE)
    698			rates[i].flags |= IEEE80211_TX_RC_USE_SHORT_PREAMBLE;
    699
    700		/* set up G protection */
    701		if (!(rates[i].flags & IEEE80211_TX_RC_USE_RTS_CTS) &&
    702		    info->control.use_cts_prot &&
    703		    rate->flags & IEEE80211_RATE_ERP_G)
    704			rates[i].flags |= IEEE80211_TX_RC_USE_CTS_PROTECT;
    705	}
    706}
    707
    708
    709static void rate_control_fill_sta_table(struct ieee80211_sta *sta,
    710					struct ieee80211_tx_info *info,
    711					struct ieee80211_tx_rate *rates,
    712					int max_rates)
    713{
    714	struct ieee80211_sta_rates *ratetbl = NULL;
    715	int i;
    716
    717	if (sta && !info->control.skip_table)
    718		ratetbl = rcu_dereference(sta->rates);
    719
    720	/* Fill remaining rate slots with data from the sta rate table. */
    721	max_rates = min_t(int, max_rates, IEEE80211_TX_RATE_TABLE_SIZE);
    722	for (i = 0; i < max_rates; i++) {
    723		if (i < ARRAY_SIZE(info->control.rates) &&
    724		    info->control.rates[i].idx >= 0 &&
    725		    info->control.rates[i].count) {
    726			if (rates != info->control.rates)
    727				rates[i] = info->control.rates[i];
    728		} else if (ratetbl) {
    729			rates[i].idx = ratetbl->rate[i].idx;
    730			rates[i].flags = ratetbl->rate[i].flags;
    731			if (info->control.use_rts)
    732				rates[i].count = ratetbl->rate[i].count_rts;
    733			else if (info->control.use_cts_prot)
    734				rates[i].count = ratetbl->rate[i].count_cts;
    735			else
    736				rates[i].count = ratetbl->rate[i].count;
    737		} else {
    738			rates[i].idx = -1;
    739			rates[i].count = 0;
    740		}
    741
    742		if (rates[i].idx < 0 || !rates[i].count)
    743			break;
    744	}
    745}
    746
    747static bool rate_control_cap_mask(struct ieee80211_sub_if_data *sdata,
    748				  struct ieee80211_supported_band *sband,
    749				  struct ieee80211_sta *sta, u32 *mask,
    750				  u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN],
    751				  u16 vht_mask[NL80211_VHT_NSS_MAX])
    752{
    753	u32 i, flags;
    754
    755	*mask = sdata->rc_rateidx_mask[sband->band];
    756	flags = ieee80211_chandef_rate_flags(&sdata->vif.bss_conf.chandef);
    757	for (i = 0; i < sband->n_bitrates; i++) {
    758		if ((flags & sband->bitrates[i].flags) != flags)
    759			*mask &= ~BIT(i);
    760	}
    761
    762	if (*mask == (1 << sband->n_bitrates) - 1 &&
    763	    !sdata->rc_has_mcs_mask[sband->band] &&
    764	    !sdata->rc_has_vht_mcs_mask[sband->band])
    765		return false;
    766
    767	if (sdata->rc_has_mcs_mask[sband->band])
    768		memcpy(mcs_mask, sdata->rc_rateidx_mcs_mask[sband->band],
    769		       IEEE80211_HT_MCS_MASK_LEN);
    770	else
    771		memset(mcs_mask, 0xff, IEEE80211_HT_MCS_MASK_LEN);
    772
    773	if (sdata->rc_has_vht_mcs_mask[sband->band])
    774		memcpy(vht_mask, sdata->rc_rateidx_vht_mcs_mask[sband->band],
    775		       sizeof(u16) * NL80211_VHT_NSS_MAX);
    776	else
    777		memset(vht_mask, 0xff, sizeof(u16) * NL80211_VHT_NSS_MAX);
    778
    779	if (sta) {
    780		__le16 sta_vht_cap;
    781		u16 sta_vht_mask[NL80211_VHT_NSS_MAX];
    782
    783		/* Filter out rates that the STA does not support */
    784		*mask &= sta->deflink.supp_rates[sband->band];
    785		for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++)
    786			mcs_mask[i] &= sta->deflink.ht_cap.mcs.rx_mask[i];
    787
    788		sta_vht_cap = sta->deflink.vht_cap.vht_mcs.rx_mcs_map;
    789		ieee80211_get_vht_mask_from_cap(sta_vht_cap, sta_vht_mask);
    790		for (i = 0; i < NL80211_VHT_NSS_MAX; i++)
    791			vht_mask[i] &= sta_vht_mask[i];
    792	}
    793
    794	return true;
    795}
    796
    797static void
    798rate_control_apply_mask_ratetbl(struct sta_info *sta,
    799				struct ieee80211_supported_band *sband,
    800				struct ieee80211_sta_rates *rates)
    801{
    802	int i;
    803	u32 mask;
    804	u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN];
    805	u16 vht_mask[NL80211_VHT_NSS_MAX];
    806	enum nl80211_chan_width chan_width;
    807
    808	if (!rate_control_cap_mask(sta->sdata, sband, &sta->sta, &mask,
    809				   mcs_mask, vht_mask))
    810		return;
    811
    812	chan_width = sta->sdata->vif.bss_conf.chandef.width;
    813	for (i = 0; i < IEEE80211_TX_RATE_TABLE_SIZE; i++) {
    814		if (rates->rate[i].idx < 0)
    815			break;
    816
    817		rate_idx_match_mask(&rates->rate[i].idx, &rates->rate[i].flags,
    818				    sband, chan_width, mask, mcs_mask,
    819				    vht_mask);
    820	}
    821}
    822
    823static void rate_control_apply_mask(struct ieee80211_sub_if_data *sdata,
    824				    struct ieee80211_sta *sta,
    825				    struct ieee80211_supported_band *sband,
    826				    struct ieee80211_tx_rate *rates,
    827				    int max_rates)
    828{
    829	enum nl80211_chan_width chan_width;
    830	u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN];
    831	u32 mask;
    832	u16 rate_flags, vht_mask[NL80211_VHT_NSS_MAX];
    833	int i;
    834
    835	/*
    836	 * Try to enforce the rateidx mask the user wanted. skip this if the
    837	 * default mask (allow all rates) is used to save some processing for
    838	 * the common case.
    839	 */
    840	if (!rate_control_cap_mask(sdata, sband, sta, &mask, mcs_mask,
    841				   vht_mask))
    842		return;
    843
    844	/*
    845	 * Make sure the rate index selected for each TX rate is
    846	 * included in the configured mask and change the rate indexes
    847	 * if needed.
    848	 */
    849	chan_width = sdata->vif.bss_conf.chandef.width;
    850	for (i = 0; i < max_rates; i++) {
    851		/* Skip invalid rates */
    852		if (rates[i].idx < 0)
    853			break;
    854
    855		rate_flags = rates[i].flags;
    856		rate_idx_match_mask(&rates[i].idx, &rate_flags, sband,
    857				    chan_width, mask, mcs_mask, vht_mask);
    858		rates[i].flags = rate_flags;
    859	}
    860}
    861
    862void ieee80211_get_tx_rates(struct ieee80211_vif *vif,
    863			    struct ieee80211_sta *sta,
    864			    struct sk_buff *skb,
    865			    struct ieee80211_tx_rate *dest,
    866			    int max_rates)
    867{
    868	struct ieee80211_sub_if_data *sdata;
    869	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb);
    870	struct ieee80211_supported_band *sband;
    871
    872	rate_control_fill_sta_table(sta, info, dest, max_rates);
    873
    874	if (!vif)
    875		return;
    876
    877	sdata = vif_to_sdata(vif);
    878	sband = sdata->local->hw.wiphy->bands[info->band];
    879
    880	if (ieee80211_is_tx_data(skb))
    881		rate_control_apply_mask(sdata, sta, sband, dest, max_rates);
    882
    883	if (dest[0].idx < 0)
    884		__rate_control_send_low(&sdata->local->hw, sband, sta, info,
    885					sdata->rc_rateidx_mask[info->band]);
    886
    887	if (sta)
    888		rate_fixup_ratelist(vif, sband, info, dest, max_rates);
    889}
    890EXPORT_SYMBOL(ieee80211_get_tx_rates);
    891
    892void rate_control_get_rate(struct ieee80211_sub_if_data *sdata,
    893			   struct sta_info *sta,
    894			   struct ieee80211_tx_rate_control *txrc)
    895{
    896	struct rate_control_ref *ref = sdata->local->rate_ctrl;
    897	void *priv_sta = NULL;
    898	struct ieee80211_sta *ista = NULL;
    899	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb);
    900	int i;
    901
    902	for (i = 0; i < IEEE80211_TX_MAX_RATES; i++) {
    903		info->control.rates[i].idx = -1;
    904		info->control.rates[i].flags = 0;
    905		info->control.rates[i].count = 0;
    906	}
    907
    908	if (rate_control_send_low(sta ? &sta->sta : NULL, txrc))
    909		return;
    910
    911	if (ieee80211_hw_check(&sdata->local->hw, HAS_RATE_CONTROL))
    912		return;
    913
    914	if (sta && test_sta_flag(sta, WLAN_STA_RATE_CONTROL)) {
    915		ista = &sta->sta;
    916		priv_sta = sta->rate_ctrl_priv;
    917	}
    918
    919	if (ista) {
    920		spin_lock_bh(&sta->rate_ctrl_lock);
    921		ref->ops->get_rate(ref->priv, ista, priv_sta, txrc);
    922		spin_unlock_bh(&sta->rate_ctrl_lock);
    923	} else {
    924		rate_control_send_low(NULL, txrc);
    925	}
    926
    927	if (ieee80211_hw_check(&sdata->local->hw, SUPPORTS_RC_TABLE))
    928		return;
    929
    930	ieee80211_get_tx_rates(&sdata->vif, ista, txrc->skb,
    931			       info->control.rates,
    932			       ARRAY_SIZE(info->control.rates));
    933}
    934
    935int rate_control_set_rates(struct ieee80211_hw *hw,
    936			   struct ieee80211_sta *pubsta,
    937			   struct ieee80211_sta_rates *rates)
    938{
    939	struct sta_info *sta = container_of(pubsta, struct sta_info, sta);
    940	struct ieee80211_sta_rates *old;
    941	struct ieee80211_supported_band *sband;
    942
    943	sband = ieee80211_get_sband(sta->sdata);
    944	if (!sband)
    945		return -EINVAL;
    946	rate_control_apply_mask_ratetbl(sta, sband, rates);
    947	/*
    948	 * mac80211 guarantees that this function will not be called
    949	 * concurrently, so the following RCU access is safe, even without
    950	 * extra locking. This can not be checked easily, so we just set
    951	 * the condition to true.
    952	 */
    953	old = rcu_dereference_protected(pubsta->rates, true);
    954	rcu_assign_pointer(pubsta->rates, rates);
    955	if (old)
    956		kfree_rcu(old, rcu_head);
    957
    958	if (sta->uploaded)
    959		drv_sta_rate_tbl_update(hw_to_local(hw), sta->sdata, pubsta);
    960
    961	ieee80211_sta_set_expected_throughput(pubsta, sta_get_expected_throughput(sta));
    962
    963	return 0;
    964}
    965EXPORT_SYMBOL(rate_control_set_rates);
    966
    967int ieee80211_init_rate_ctrl_alg(struct ieee80211_local *local,
    968				 const char *name)
    969{
    970	struct rate_control_ref *ref;
    971
    972	ASSERT_RTNL();
    973
    974	if (local->open_count)
    975		return -EBUSY;
    976
    977	if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) {
    978		if (WARN_ON(!local->ops->set_rts_threshold))
    979			return -EINVAL;
    980		return 0;
    981	}
    982
    983	ref = rate_control_alloc(name, local);
    984	if (!ref) {
    985		wiphy_warn(local->hw.wiphy,
    986			   "Failed to select rate control algorithm\n");
    987		return -ENOENT;
    988	}
    989
    990	WARN_ON(local->rate_ctrl);
    991	local->rate_ctrl = ref;
    992
    993	wiphy_debug(local->hw.wiphy, "Selected rate control algorithm '%s'\n",
    994		    ref->ops->name);
    995
    996	return 0;
    997}
    998
    999void rate_control_deinitialize(struct ieee80211_local *local)
   1000{
   1001	struct rate_control_ref *ref;
   1002
   1003	ref = local->rate_ctrl;
   1004
   1005	if (!ref)
   1006		return;
   1007
   1008	local->rate_ctrl = NULL;
   1009	rate_control_free(local, ref);
   1010}