cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux

rxe_recv.c (8994B)


// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
/*
 * Copyright (c) 2016 Mellanox Technologies Ltd. All rights reserved.
 * Copyright (c) 2015 System Fabric Works, Inc. All rights reserved.
 */

#include <linux/skbuff.h>

#include "rxe.h"
#include "rxe_loc.h"

/* check that QP matches packet opcode type and is in a valid state */
static int check_type_state(struct rxe_dev *rxe, struct rxe_pkt_info *pkt,
			    struct rxe_qp *qp)
{
	unsigned int pkt_type;

	if (unlikely(!qp->valid))
		goto err1;

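	/* the top three bits of the BTH opcode encode the transport type */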
	pkt_type = pkt->opcode & 0xe0;

	switch (qp_type(qp)) {
	case IB_QPT_RC:
		if (unlikely(pkt_type != IB_OPCODE_RC)) {
			pr_warn_ratelimited("bad qp type\n");
			goto err1;
		}
		break;
	case IB_QPT_UC:
		if (unlikely(pkt_type != IB_OPCODE_UC)) {
			pr_warn_ratelimited("bad qp type\n");
			goto err1;
		}
		break;
	case IB_QPT_UD:
	case IB_QPT_GSI:
		if (unlikely(pkt_type != IB_OPCODE_UD)) {
			pr_warn_ratelimited("bad qp type\n");
			goto err1;
		}
		break;
	default:
		pr_warn_ratelimited("unsupported qp type\n");
		goto err1;
	}

	if (pkt->mask & RXE_REQ_MASK) {
		if (unlikely(qp->resp.state != QP_STATE_READY))
			goto err1;
	} else if (unlikely(qp->req.state < QP_STATE_READY ||
				qp->req.state > QP_STATE_DRAINED)) {
		goto err1;
	}

	return 0;

err1:
	return -EINVAL;
}

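/* these port counters are 16-bit values, so saturate at 0xffff
 * rather than wrapping
 */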
static void set_bad_pkey_cntr(struct rxe_port *port)
{
	spin_lock_bh(&port->port_lock);
	port->attr.bad_pkey_cntr = min((u32)0xffff,
				       port->attr.bad_pkey_cntr + 1);
	spin_unlock_bh(&port->port_lock);
}

static void set_qkey_viol_cntr(struct rxe_port *port)
{
	spin_lock_bh(&port->port_lock);
	port->attr.qkey_viol_cntr = min((u32)0xffff,
					port->attr.qkey_viol_cntr + 1);
	spin_unlock_bh(&port->port_lock);
}

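/* check the packet's P_Key and, for UD and GSI QPs, its Q_Key */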
static int check_keys(struct rxe_dev *rxe, struct rxe_pkt_info *pkt,
		      u32 qpn, struct rxe_qp *qp)
{
	struct rxe_port *port = &rxe->port;
	u16 pkey = bth_pkey(pkt);

	pkt->pkey_index = 0;

	if (!pkey_match(pkey, IB_DEFAULT_PKEY_FULL)) {
		pr_warn_ratelimited("bad pkey = 0x%x\n", pkey);
		set_bad_pkey_cntr(port);
		goto err1;
	}

	if (qp_type(qp) == IB_QPT_UD || qp_type(qp) == IB_QPT_GSI) {
		u32 qkey = (qpn == 1) ? GSI_QKEY : qp->attr.qkey;

		if (unlikely(deth_qkey(pkt) != qkey)) {
			pr_warn_ratelimited("bad qkey, got 0x%x expected 0x%x for qpn 0x%x\n",
					    deth_qkey(pkt), qkey, qpn);
			set_qkey_viol_cntr(port);
			goto err1;
		}
	}

	return 0;

err1:
	return -EINVAL;
}

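/* for connected (RC/UC) QPs, check that the packet's source and
 * destination IP addresses match the QP's primary address vector
 */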
static int check_addr(struct rxe_dev *rxe, struct rxe_pkt_info *pkt,
		      struct rxe_qp *qp)
{
	struct sk_buff *skb = PKT_TO_SKB(pkt);

	if (qp_type(qp) != IB_QPT_RC && qp_type(qp) != IB_QPT_UC)
		goto done;

	if (unlikely(pkt->port_num != qp->attr.port_num)) {
		pr_warn_ratelimited("port %d != qp port %d\n",
				    pkt->port_num, qp->attr.port_num);
		goto err1;
	}

	if (skb->protocol == htons(ETH_P_IP)) {
		struct in_addr *saddr =
			&qp->pri_av.sgid_addr._sockaddr_in.sin_addr;
		struct in_addr *daddr =
			&qp->pri_av.dgid_addr._sockaddr_in.sin_addr;

		if (ip_hdr(skb)->daddr != saddr->s_addr) {
			pr_warn_ratelimited("dst addr %pI4 != qp source addr %pI4\n",
					    &ip_hdr(skb)->daddr,
					    &saddr->s_addr);
			goto err1;
		}

		if (ip_hdr(skb)->saddr != daddr->s_addr) {
			pr_warn_ratelimited("source addr %pI4 != qp dst addr %pI4\n",
					    &ip_hdr(skb)->saddr,
					    &daddr->s_addr);
			goto err1;
		}

	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		struct in6_addr *saddr =
			&qp->pri_av.sgid_addr._sockaddr_in6.sin6_addr;
		struct in6_addr *daddr =
			&qp->pri_av.dgid_addr._sockaddr_in6.sin6_addr;

		if (memcmp(&ipv6_hdr(skb)->daddr, saddr, sizeof(*saddr))) {
			pr_warn_ratelimited("dst addr %pI6 != qp source addr %pI6\n",
					    &ipv6_hdr(skb)->daddr, saddr);
			goto err1;
		}

		if (memcmp(&ipv6_hdr(skb)->saddr, daddr, sizeof(*daddr))) {
			pr_warn_ratelimited("source addr %pI6 != qp dst addr %pI6\n",
					    &ipv6_hdr(skb)->saddr, daddr);
			goto err1;
		}
	}

done:
	return 0;

err1:
	return -EINVAL;
}

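/* validate the BTH fields and look up the destination QP; on success
 * pkt->qp holds a reference to the QP (NULL for multicast packets)
 */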
static int hdr_check(struct rxe_pkt_info *pkt)
{
	struct rxe_dev *rxe = pkt->rxe;
	struct rxe_port *port = &rxe->port;
	struct rxe_qp *qp = NULL;
	u32 qpn = bth_qpn(pkt);
	int index;
	int err;

	if (unlikely(bth_tver(pkt) != BTH_TVER)) {
		pr_warn_ratelimited("bad tver\n");
		goto err1;
	}

	if (unlikely(qpn == 0)) {
		pr_warn_once("QP 0 not supported");
		goto err1;
	}

	if (qpn != IB_MULTICAST_QPN) {
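		/* QPN 1 is the GSI QP, which lives at the pool index
		 * saved in port->qp_gsi_index
		 */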
		index = (qpn == 1) ? port->qp_gsi_index : qpn;

		qp = rxe_pool_get_index(&rxe->qp_pool, index);
		if (unlikely(!qp)) {
			pr_warn_ratelimited("no qp matches qpn 0x%x\n", qpn);
			goto err1;
		}

		err = check_type_state(rxe, pkt, qp);
		if (unlikely(err))
			goto err2;

		err = check_addr(rxe, pkt, qp);
		if (unlikely(err))
			goto err2;

		err = check_keys(rxe, pkt, qpn, qp);
		if (unlikely(err))
			goto err2;
	} else {
		if (unlikely((pkt->mask & RXE_GRH_MASK) == 0)) {
			pr_warn_ratelimited("no grh for mcast qpn\n");
			goto err1;
		}
	}

	pkt->qp = qp;
	return 0;

err2:
	rxe_put(qp);
err1:
	return -EINVAL;
}

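/* hand the packet to the responder (requests) or the completer
 * (responses and acks)
 */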
static inline void rxe_rcv_pkt(struct rxe_pkt_info *pkt, struct sk_buff *skb)
{
	if (pkt->mask & RXE_REQ_MASK)
		rxe_resp_queue_pkt(pkt->qp, skb);
	else
		rxe_comp_queue_pkt(pkt->qp, skb);
}

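/* deliver a multicast packet to every QP attached to the mcast group */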
static void rxe_rcv_mcast_pkt(struct rxe_dev *rxe, struct sk_buff *skb)
{
	struct rxe_pkt_info *pkt = SKB_TO_PKT(skb);
	struct rxe_mcg *mcg;
	struct rxe_mca *mca;
	struct rxe_qp *qp;
	union ib_gid dgid;
	int err;

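	/* derive the destination MGID from the packet's IP destination
	 * address; IPv4 addresses are converted to v4-mapped IPv6 GIDs
	 */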
	if (skb->protocol == htons(ETH_P_IP))
		ipv6_addr_set_v4mapped(ip_hdr(skb)->daddr,
				       (struct in6_addr *)&dgid);
	else if (skb->protocol == htons(ETH_P_IPV6))
		memcpy(&dgid, &ipv6_hdr(skb)->daddr, sizeof(dgid));

	/* lookup mcast group corresponding to mgid, takes a ref */
	mcg = rxe_lookup_mcg(rxe, &dgid);
	if (!mcg)
		goto drop;	/* mcast group not registered */

	spin_lock_bh(&rxe->mcg_lock);

	/* this is unreliable datagram service so we let
	 * failures to deliver a multicast packet to a
	 * single QP happen and just move on and try
	 * the rest of them on the list
	 */
	list_for_each_entry(mca, &mcg->qp_list, qp_list) {
		qp = mca->qp;

		/* validate qp for incoming packet */
		err = check_type_state(rxe, pkt, qp);
		if (err)
			continue;

		err = check_keys(rxe, pkt, bth_qpn(pkt), qp);
		if (err)
			continue;

		/* for all but the last QP create a new clone of the
		 * skb and pass to the QP. Pass the original skb to
		 * the last QP in the list.
		 */
		if (mca->qp_list.next != &mcg->qp_list) {
			struct sk_buff *cskb;
			struct rxe_pkt_info *cpkt;

			cskb = skb_clone(skb, GFP_ATOMIC);
			if (unlikely(!cskb))
				continue;

			if (WARN_ON(!ib_device_try_get(&rxe->ib_dev))) {
				kfree_skb(cskb);
				break;
			}

			cpkt = SKB_TO_PKT(cskb);
			cpkt->qp = qp;
			rxe_get(qp);
			rxe_rcv_pkt(cpkt, cskb);
		} else {
			pkt->qp = qp;
			rxe_get(qp);
			rxe_rcv_pkt(pkt, skb);
			skb = NULL;	/* mark consumed */
		}
	}

	spin_unlock_bh(&rxe->mcg_lock);

	kref_put(&mcg->ref_cnt, rxe_cleanup_mcg);

	if (likely(!skb))
		return;

	/* This only occurs if one of the checks fails on the last
	 * QP in the list above
	 */

drop:
	kfree_skb(skb);
	ib_device_put(&rxe->ib_dev);
}

/**
 * rxe_chk_dgid - validate destination IP address
 * @rxe: rxe device that received packet
 * @skb: the received packet buffer
 *
 * Accept any loopback packet. Otherwise extract the destination IP
 * address from the packet and accept it if it is a multicast address
 * or matches an entry in the SGID table.
 */
static int rxe_chk_dgid(struct rxe_dev *rxe, struct sk_buff *skb)
{
	struct rxe_pkt_info *pkt = SKB_TO_PKT(skb);
	const struct ib_gid_attr *gid_attr;
	union ib_gid dgid;
	union ib_gid *pdgid;

	if (pkt->mask & RXE_LOOPBACK_MASK)
		return 0;

	if (skb->protocol == htons(ETH_P_IP)) {
		ipv6_addr_set_v4mapped(ip_hdr(skb)->daddr,
				       (struct in6_addr *)&dgid);
		pdgid = &dgid;
	} else {
		pdgid = (union ib_gid *)&ipv6_hdr(skb)->daddr;
	}

	if (rdma_is_multicast_addr((struct in6_addr *)pdgid))
		return 0;

	gid_attr = rdma_find_gid_by_port(&rxe->ib_dev, pdgid,
					 IB_GID_TYPE_ROCE_UDP_ENCAP,
					 1, skb->dev);
	if (IS_ERR(gid_attr))
		return PTR_ERR(gid_attr);

	rdma_put_gid_attr(gid_attr);
	return 0;
}

/* rxe_rcv is called from the interface driver */
void rxe_rcv(struct sk_buff *skb)
{
	int err;
	struct rxe_pkt_info *pkt = SKB_TO_PKT(skb);
	struct rxe_dev *rxe = pkt->rxe;

	if (unlikely(skb->len < RXE_BTH_BYTES))
		goto drop;

	if (rxe_chk_dgid(rxe, skb) < 0) {
		pr_warn_ratelimited("failed checking dgid\n");
		goto drop;
	}

	pkt->opcode = bth_opcode(pkt);
	pkt->psn = bth_psn(pkt);
	pkt->qp = NULL;
	pkt->mask |= rxe_opcode[pkt->opcode].mask;

	if (unlikely(skb->len < header_size(pkt)))
		goto drop;

	err = hdr_check(pkt);
	if (unlikely(err))
		goto drop;

	err = rxe_icrc_check(skb, pkt);
	if (unlikely(err))
		goto drop;

	rxe_counter_inc(rxe, RXE_CNT_RCVD_PKTS);

	if (unlikely(bth_qpn(pkt) == IB_MULTICAST_QPN))
		rxe_rcv_mcast_pkt(rxe, skb);
	else
		rxe_rcv_pkt(pkt, skb);

	return;

drop:
	if (pkt->qp)
		rxe_put(pkt->qp);

	kfree_skb(skb);
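	/* release the ib device reference held for this packet */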
	ib_device_put(&rxe->ib_dev);
}