cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

keembay-ocs-ecc.c (27061B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/*
      3 * Intel Keem Bay OCS ECC Crypto Driver.
      4 *
      5 * Copyright (C) 2019-2021 Intel Corporation
      6 */
      7
      8#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
      9
     10#include <linux/clk.h>
     11#include <linux/completion.h>
     12#include <linux/crypto.h>
     13#include <linux/delay.h>
     14#include <linux/fips.h>
     15#include <linux/interrupt.h>
     16#include <linux/io.h>
     17#include <linux/iopoll.h>
     18#include <linux/irq.h>
     19#include <linux/module.h>
     20#include <linux/of.h>
     21#include <linux/platform_device.h>
     22#include <linux/scatterlist.h>
     23#include <linux/slab.h>
     24#include <linux/types.h>
     25
     26#include <crypto/ecc_curve.h>
     27#include <crypto/ecdh.h>
     28#include <crypto/engine.h>
     29#include <crypto/kpp.h>
     30#include <crypto/rng.h>
     31
     32#include <crypto/internal/ecc.h>
     33#include <crypto/internal/kpp.h>
     34
     35#define DRV_NAME			"keembay-ocs-ecc"
     36
     37#define KMB_OCS_ECC_PRIORITY		350
     38
     39#define HW_OFFS_OCS_ECC_COMMAND		0x00000000
     40#define HW_OFFS_OCS_ECC_STATUS		0x00000004
     41#define HW_OFFS_OCS_ECC_DATA_IN		0x00000080
     42#define HW_OFFS_OCS_ECC_CX_DATA_OUT	0x00000100
     43#define HW_OFFS_OCS_ECC_CY_DATA_OUT	0x00000180
     44#define HW_OFFS_OCS_ECC_ISR		0x00000400
     45#define HW_OFFS_OCS_ECC_IER		0x00000404
     46
     47#define HW_OCS_ECC_ISR_INT_STATUS_DONE	BIT(0)
     48#define HW_OCS_ECC_COMMAND_INS_BP	BIT(0)
     49
     50#define HW_OCS_ECC_COMMAND_START_VAL	BIT(0)
     51
     52#define OCS_ECC_OP_SIZE_384		BIT(8)
     53#define OCS_ECC_OP_SIZE_256		0
     54
     55/* ECC Instruction : for ECC_COMMAND */
     56#define OCS_ECC_INST_WRITE_AX		(0x1 << HW_OCS_ECC_COMMAND_INS_BP)
     57#define OCS_ECC_INST_WRITE_AY		(0x2 << HW_OCS_ECC_COMMAND_INS_BP)
     58#define OCS_ECC_INST_WRITE_BX_D		(0x3 << HW_OCS_ECC_COMMAND_INS_BP)
     59#define OCS_ECC_INST_WRITE_BY_L		(0x4 << HW_OCS_ECC_COMMAND_INS_BP)
     60#define OCS_ECC_INST_WRITE_P		(0x5 << HW_OCS_ECC_COMMAND_INS_BP)
     61#define OCS_ECC_INST_WRITE_A		(0x6 << HW_OCS_ECC_COMMAND_INS_BP)
     62#define OCS_ECC_INST_CALC_D_IDX_A	(0x8 << HW_OCS_ECC_COMMAND_INS_BP)
     63#define OCS_ECC_INST_CALC_A_POW_B_MODP	(0xB << HW_OCS_ECC_COMMAND_INS_BP)
     64#define OCS_ECC_INST_CALC_A_MUL_B_MODP	(0xC  << HW_OCS_ECC_COMMAND_INS_BP)
     65#define OCS_ECC_INST_CALC_A_ADD_B_MODP	(0xD << HW_OCS_ECC_COMMAND_INS_BP)
     66
     67#define ECC_ENABLE_INTR			1
     68
     69#define POLL_USEC			100
     70#define TIMEOUT_USEC			10000
     71
     72#define KMB_ECC_VLI_MAX_DIGITS		ECC_CURVE_NIST_P384_DIGITS
     73#define KMB_ECC_VLI_MAX_BYTES		(KMB_ECC_VLI_MAX_DIGITS \
     74					 << ECC_DIGITS_TO_BYTES_SHIFT)
     75
     76#define POW_CUBE			3
     77
/**
 * struct ocs_ecc_dev - ECC device context
 * @list: Node linking this device into the global ocs_ecc.dev_list.
 * @dev: OCS ECC device (the platform device's struct device).
 * @base_reg: IO base address of OCS ECC; all HW_OFFS_* offsets are relative
 *	      to it.
 * @engine: Crypto engine for the device; requests are queued to it and
 *	    finalized through it.
 * @irq_done: IRQ done completion. Re-armed and waited on by
 *	      ocs_ecc_trigger_op(), completed by ocs_ecc_irq_handler().
 * @irq: IRQ number
 */
struct ocs_ecc_dev {
	struct list_head list;
	struct device *dev;
	void __iomem *base_reg;
	struct crypto_engine *engine;
	struct completion irq_done;
	int irq;
};
     95
/**
 * struct ocs_ecc_ctx - Transformation context.
 * @engine_ctx:	 Crypto engine ctx.
 *		 NOTE(review): presumably must remain the first member, as the
 *		 crypto engine looks it up from the tfm context - confirm
 *		 against the crypto engine documentation.
 * @ecc_dev:	 The ECC driver associated with this context.
 * @curve:	 The elliptic curve used by this transformation. Set to NULL by
 *		 kmb_ocs_ecdh_set_secret() when setting the secret fails.
 * @private_key: The private key, sized for the largest supported curve
 *		 (KMB_ECC_VLI_MAX_DIGITS digits).
 */
struct ocs_ecc_ctx {
	struct crypto_engine_ctx engine_ctx;
	struct ocs_ecc_dev *ecc_dev;
	const struct ecc_curve *curve;
	u64 private_key[KMB_ECC_VLI_MAX_DIGITS];
};
    109
/*
 * Driver data: the list of probed OCS ECC devices (struct ocs_ecc_dev,
 * linked through their 'list' member) and the lock guarding that list.
 */
struct ocs_ecc_drv {
	struct list_head dev_list;
	spinlock_t lock;	/* Protects dev_list. */
};
    115
/*
 * Global variable holding the list of OCS ECC devices (only one expected).
 * Statically initialized to an empty list with an unlocked spinlock; devices
 * are added in kmb_ocs_ecc_probe() and removed in kmb_ocs_ecc_remove().
 */
static struct ocs_ecc_drv ocs_ecc = {
	.dev_list = LIST_HEAD_INIT(ocs_ecc.dev_list),
	.lock = __SPIN_LOCK_UNLOCKED(ocs_ecc.lock),
};
    121
    122/* Get OCS ECC tfm context from kpp_request. */
    123static inline struct ocs_ecc_ctx *kmb_ocs_ecc_tctx(struct kpp_request *req)
    124{
    125	return kpp_tfm_ctx(crypto_kpp_reqtfm(req));
    126}
    127
    128/* Converts number of digits to number of bytes. */
    129static inline unsigned int digits_to_bytes(unsigned int n)
    130{
    131	return n << ECC_DIGITS_TO_BYTES_SHIFT;
    132}
    133
    134/*
    135 * Wait for ECC idle i.e when an operation (other than write operations)
    136 * is done.
    137 */
    138static inline int ocs_ecc_wait_idle(struct ocs_ecc_dev *dev)
    139{
    140	u32 value;
    141
    142	return readl_poll_timeout((dev->base_reg + HW_OFFS_OCS_ECC_STATUS),
    143				  value,
    144				  !(value & HW_OCS_ECC_ISR_INT_STATUS_DONE),
    145				  POLL_USEC, TIMEOUT_USEC);
    146}
    147
    148static void ocs_ecc_cmd_start(struct ocs_ecc_dev *ecc_dev, u32 op_size)
    149{
    150	iowrite32(op_size | HW_OCS_ECC_COMMAND_START_VAL,
    151		  ecc_dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
    152}
    153
    154/* Direct write of u32 buffer to ECC engine with associated instruction. */
    155static void ocs_ecc_write_cmd_and_data(struct ocs_ecc_dev *dev,
    156				       u32 op_size,
    157				       u32 inst,
    158				       const void *data_in,
    159				       size_t data_size)
    160{
    161	iowrite32(op_size | inst, dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
    162
    163	/* MMIO Write src uint32 to dst. */
    164	memcpy_toio(dev->base_reg + HW_OFFS_OCS_ECC_DATA_IN, data_in,
    165		    data_size);
    166}
    167
    168/* Start OCS ECC operation and wait for its completion. */
    169static int ocs_ecc_trigger_op(struct ocs_ecc_dev *ecc_dev, u32 op_size,
    170			      u32 inst)
    171{
    172	reinit_completion(&ecc_dev->irq_done);
    173
    174	iowrite32(ECC_ENABLE_INTR, ecc_dev->base_reg + HW_OFFS_OCS_ECC_IER);
    175	iowrite32(op_size | inst, ecc_dev->base_reg + HW_OFFS_OCS_ECC_COMMAND);
    176
    177	return wait_for_completion_interruptible(&ecc_dev->irq_done);
    178}
    179
    180/**
    181 * ocs_ecc_read_cx_out() - Read the CX data output buffer.
    182 * @dev:	The OCS ECC device to read from.
    183 * @cx_out:	The buffer where to store the CX value. Must be at least
    184 *		@byte_count byte long.
    185 * @byte_count:	The amount of data to read.
    186 */
    187static inline void ocs_ecc_read_cx_out(struct ocs_ecc_dev *dev, void *cx_out,
    188				       size_t byte_count)
    189{
    190	memcpy_fromio(cx_out, dev->base_reg + HW_OFFS_OCS_ECC_CX_DATA_OUT,
    191		      byte_count);
    192}
    193
    194/**
    195 * ocs_ecc_read_cy_out() - Read the CX data output buffer.
    196 * @dev:	The OCS ECC device to read from.
    197 * @cy_out:	The buffer where to store the CY value. Must be at least
    198 *		@byte_count byte long.
    199 * @byte_count:	The amount of data to read.
    200 */
    201static inline void ocs_ecc_read_cy_out(struct ocs_ecc_dev *dev, void *cy_out,
    202				       size_t byte_count)
    203{
    204	memcpy_fromio(cy_out, dev->base_reg + HW_OFFS_OCS_ECC_CY_DATA_OUT,
    205		      byte_count);
    206}
    207
    208static struct ocs_ecc_dev *kmb_ocs_ecc_find_dev(struct ocs_ecc_ctx *tctx)
    209{
    210	if (tctx->ecc_dev)
    211		return tctx->ecc_dev;
    212
    213	spin_lock(&ocs_ecc.lock);
    214
    215	/* Only a single OCS device available. */
    216	tctx->ecc_dev = list_first_entry(&ocs_ecc.dev_list, struct ocs_ecc_dev,
    217					 list);
    218
    219	spin_unlock(&ocs_ecc.lock);
    220
    221	return tctx->ecc_dev;
    222}
    223
    224/* Do point multiplication using OCS ECC HW. */
    225static int kmb_ecc_point_mult(struct ocs_ecc_dev *ecc_dev,
    226			      struct ecc_point *result,
    227			      const struct ecc_point *point,
    228			      u64 *scalar,
    229			      const struct ecc_curve *curve)
    230{
    231	u8 sca[KMB_ECC_VLI_MAX_BYTES]; /* Use the maximum data size. */
    232	u32 op_size = (curve->g.ndigits > ECC_CURVE_NIST_P256_DIGITS) ?
    233		      OCS_ECC_OP_SIZE_384 : OCS_ECC_OP_SIZE_256;
    234	size_t nbytes = digits_to_bytes(curve->g.ndigits);
    235	int rc = 0;
    236
    237	/* Generate random nbytes for Simple and Differential SCA protection. */
    238	rc = crypto_get_default_rng();
    239	if (rc)
    240		return rc;
    241
    242	rc = crypto_rng_get_bytes(crypto_default_rng, sca, nbytes);
    243	crypto_put_default_rng();
    244	if (rc)
    245		return rc;
    246
    247	/* Wait engine to be idle before starting new operation. */
    248	rc = ocs_ecc_wait_idle(ecc_dev);
    249	if (rc)
    250		return rc;
    251
    252	/* Send ecc_start pulse as well as indicating operation size. */
    253	ocs_ecc_cmd_start(ecc_dev, op_size);
    254
    255	/* Write ax param; Base point (Gx). */
    256	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AX,
    257				   point->x, nbytes);
    258
    259	/* Write ay param; Base point (Gy). */
    260	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AY,
    261				   point->y, nbytes);
    262
    263	/*
    264	 * Write the private key into DATA_IN reg.
    265	 *
    266	 * Since DATA_IN register is used to write different values during the
    267	 * computation private Key value is overwritten with
    268	 * side-channel-resistance value.
    269	 */
    270	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_BX_D,
    271				   scalar, nbytes);
    272
    273	/* Write operand by/l. */
    274	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_BY_L,
    275				   sca, nbytes);
    276	memzero_explicit(sca, sizeof(sca));
    277
    278	/* Write p = curve prime(GF modulus). */
    279	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_P,
    280				   curve->p, nbytes);
    281
    282	/* Write a = curve coefficient. */
    283	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_A,
    284				   curve->a, nbytes);
    285
    286	/* Make hardware perform the multiplication. */
    287	rc = ocs_ecc_trigger_op(ecc_dev, op_size, OCS_ECC_INST_CALC_D_IDX_A);
    288	if (rc)
    289		return rc;
    290
    291	/* Read result. */
    292	ocs_ecc_read_cx_out(ecc_dev, result->x, nbytes);
    293	ocs_ecc_read_cy_out(ecc_dev, result->y, nbytes);
    294
    295	return 0;
    296}
    297
    298/**
    299 * kmb_ecc_do_scalar_op() - Perform Scalar operation using OCS ECC HW.
    300 * @ecc_dev:	The OCS ECC device to use.
    301 * @scalar_out:	Where to store the output scalar.
    302 * @scalar_a:	Input scalar operand 'a'.
    303 * @scalar_b:	Input scalar operand 'b'
    304 * @curve:	The curve on which the operation is performed.
    305 * @ndigits:	The size of the operands (in digits).
    306 * @inst:	The operation to perform (as an OCS ECC instruction).
    307 *
    308 * Return:	0 on success, negative error code otherwise.
    309 */
    310static int kmb_ecc_do_scalar_op(struct ocs_ecc_dev *ecc_dev, u64 *scalar_out,
    311				const u64 *scalar_a, const u64 *scalar_b,
    312				const struct ecc_curve *curve,
    313				unsigned int ndigits, const u32 inst)
    314{
    315	u32 op_size = (ndigits > ECC_CURVE_NIST_P256_DIGITS) ?
    316		      OCS_ECC_OP_SIZE_384 : OCS_ECC_OP_SIZE_256;
    317	size_t nbytes = digits_to_bytes(ndigits);
    318	int rc;
    319
    320	/* Wait engine to be idle before starting new operation. */
    321	rc = ocs_ecc_wait_idle(ecc_dev);
    322	if (rc)
    323		return rc;
    324
    325	/* Send ecc_start pulse as well as indicating operation size. */
    326	ocs_ecc_cmd_start(ecc_dev, op_size);
    327
    328	/* Write ax param (Base point (Gx).*/
    329	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AX,
    330				   scalar_a, nbytes);
    331
    332	/* Write ay param Base point (Gy).*/
    333	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_AY,
    334				   scalar_b, nbytes);
    335
    336	/* Write p = curve prime(GF modulus).*/
    337	ocs_ecc_write_cmd_and_data(ecc_dev, op_size, OCS_ECC_INST_WRITE_P,
    338				   curve->p, nbytes);
    339
    340	/* Give instruction A.B or A+B to ECC engine. */
    341	rc = ocs_ecc_trigger_op(ecc_dev, op_size, inst);
    342	if (rc)
    343		return rc;
    344
    345	ocs_ecc_read_cx_out(ecc_dev, scalar_out, nbytes);
    346
    347	if (vli_is_zero(scalar_out, ndigits))
    348		return -EINVAL;
    349
    350	return 0;
    351}
    352
    353/* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */
    354static int kmb_ocs_ecc_is_pubkey_valid_partial(struct ocs_ecc_dev *ecc_dev,
    355					       const struct ecc_curve *curve,
    356					       struct ecc_point *pk)
    357{
    358	u64 xxx[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
    359	u64 yy[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
    360	u64 w[KMB_ECC_VLI_MAX_DIGITS] = { 0 };
    361	int rc;
    362
    363	if (WARN_ON(pk->ndigits != curve->g.ndigits))
    364		return -EINVAL;
    365
    366	/* Check 1: Verify key is not the zero point. */
    367	if (ecc_point_is_zero(pk))
    368		return -EINVAL;
    369
    370	/* Check 2: Verify key is in the range [0, p-1]. */
    371	if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1)
    372		return -EINVAL;
    373
    374	if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1)
    375		return -EINVAL;
    376
    377	/* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */
    378
    379	 /* y^2 */
    380	/* Compute y^2 -> store in yy */
    381	rc = kmb_ecc_do_scalar_op(ecc_dev, yy, pk->y, pk->y, curve, pk->ndigits,
    382				  OCS_ECC_INST_CALC_A_MUL_B_MODP);
    383	if (rc)
    384		goto exit;
    385
    386	/* x^3 */
    387	/* Assigning w = 3, used for calculating x^3. */
    388	w[0] = POW_CUBE;
    389	/* Load the next stage.*/
    390	rc = kmb_ecc_do_scalar_op(ecc_dev, xxx, pk->x, w, curve, pk->ndigits,
    391				  OCS_ECC_INST_CALC_A_POW_B_MODP);
    392	if (rc)
    393		goto exit;
    394
    395	/* Do a*x -> store in w. */
    396	rc = kmb_ecc_do_scalar_op(ecc_dev, w, curve->a, pk->x, curve,
    397				  pk->ndigits,
    398				  OCS_ECC_INST_CALC_A_MUL_B_MODP);
    399	if (rc)
    400		goto exit;
    401
    402	/* Do ax + b == w + b; store in w. */
    403	rc = kmb_ecc_do_scalar_op(ecc_dev, w, w, curve->b, curve,
    404				  pk->ndigits,
    405				  OCS_ECC_INST_CALC_A_ADD_B_MODP);
    406	if (rc)
    407		goto exit;
    408
    409	/* x^3 + ax + b == x^3 + w -> store in w. */
    410	rc = kmb_ecc_do_scalar_op(ecc_dev, w, xxx, w, curve, pk->ndigits,
    411				  OCS_ECC_INST_CALC_A_ADD_B_MODP);
    412	if (rc)
    413		goto exit;
    414
    415	/* Compare y^2 == x^3 + a·x + b. */
    416	rc = vli_cmp(yy, w, pk->ndigits);
    417	if (rc)
    418		rc = -EINVAL;
    419
    420exit:
    421	memzero_explicit(xxx, sizeof(xxx));
    422	memzero_explicit(yy, sizeof(yy));
    423	memzero_explicit(w, sizeof(w));
    424
    425	return rc;
    426}
    427
    428/* SP800-56A section 5.6.2.3.3 full verification */
    429static int kmb_ocs_ecc_is_pubkey_valid_full(struct ocs_ecc_dev *ecc_dev,
    430					    const struct ecc_curve *curve,
    431					    struct ecc_point *pk)
    432{
    433	struct ecc_point *nQ;
    434	int rc;
    435
    436	/* Checks 1 through 3 */
    437	rc = kmb_ocs_ecc_is_pubkey_valid_partial(ecc_dev, curve, pk);
    438	if (rc)
    439		return rc;
    440
    441	/* Check 4: Verify that nQ is the zero point. */
    442	nQ = ecc_alloc_point(pk->ndigits);
    443	if (!nQ)
    444		return -ENOMEM;
    445
    446	rc = kmb_ecc_point_mult(ecc_dev, nQ, pk, curve->n, curve);
    447	if (rc)
    448		goto exit;
    449
    450	if (!ecc_point_is_zero(nQ))
    451		rc = -EINVAL;
    452
    453exit:
    454	ecc_free_point(nQ);
    455
    456	return rc;
    457}
    458
    459static int kmb_ecc_is_key_valid(const struct ecc_curve *curve,
    460				const u64 *private_key, size_t private_key_len)
    461{
    462	size_t ndigits = curve->g.ndigits;
    463	u64 one[KMB_ECC_VLI_MAX_DIGITS] = {1};
    464	u64 res[KMB_ECC_VLI_MAX_DIGITS];
    465
    466	if (private_key_len != digits_to_bytes(ndigits))
    467		return -EINVAL;
    468
    469	if (!private_key)
    470		return -EINVAL;
    471
    472	/* Make sure the private key is in the range [2, n-3]. */
    473	if (vli_cmp(one, private_key, ndigits) != -1)
    474		return -EINVAL;
    475
    476	vli_sub(res, curve->n, one, ndigits);
    477	vli_sub(res, res, one, ndigits);
    478	if (vli_cmp(res, private_key, ndigits) != 1)
    479		return -EINVAL;
    480
    481	return 0;
    482}
    483
    484/*
    485 * ECC private keys are generated using the method of extra random bits,
    486 * equivalent to that described in FIPS 186-4, Appendix B.4.1.
    487 *
    488 * d = (c mod(n–1)) + 1    where c is a string of random bits, 64 bits longer
    489 *                         than requested
    490 * 0 <= c mod(n-1) <= n-2  and implies that
    491 * 1 <= d <= n-1
    492 *
    493 * This method generates a private key uniformly distributed in the range
    494 * [1, n-1].
    495 */
    496static int kmb_ecc_gen_privkey(const struct ecc_curve *curve, u64 *privkey)
    497{
    498	size_t nbytes = digits_to_bytes(curve->g.ndigits);
    499	u64 priv[KMB_ECC_VLI_MAX_DIGITS];
    500	size_t nbits;
    501	int rc;
    502
    503	nbits = vli_num_bits(curve->n, curve->g.ndigits);
    504
    505	/* Check that N is included in Table 1 of FIPS 186-4, section 6.1.1 */
    506	if (nbits < 160 || curve->g.ndigits > ARRAY_SIZE(priv))
    507		return -EINVAL;
    508
    509	/*
    510	 * FIPS 186-4 recommends that the private key should be obtained from a
    511	 * RBG with a security strength equal to or greater than the security
    512	 * strength associated with N.
    513	 *
    514	 * The maximum security strength identified by NIST SP800-57pt1r4 for
    515	 * ECC is 256 (N >= 512).
    516	 *
    517	 * This condition is met by the default RNG because it selects a favored
    518	 * DRBG with a security strength of 256.
    519	 */
    520	if (crypto_get_default_rng())
    521		return -EFAULT;
    522
    523	rc = crypto_rng_get_bytes(crypto_default_rng, (u8 *)priv, nbytes);
    524	crypto_put_default_rng();
    525	if (rc)
    526		goto cleanup;
    527
    528	rc = kmb_ecc_is_key_valid(curve, priv, nbytes);
    529	if (rc)
    530		goto cleanup;
    531
    532	ecc_swap_digits(priv, privkey, curve->g.ndigits);
    533
    534cleanup:
    535	memzero_explicit(&priv, sizeof(priv));
    536
    537	return rc;
    538}
    539
    540static int kmb_ocs_ecdh_set_secret(struct crypto_kpp *tfm, const void *buf,
    541				   unsigned int len)
    542{
    543	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
    544	struct ecdh params;
    545	int rc = 0;
    546
    547	rc = crypto_ecdh_decode_key(buf, len, &params);
    548	if (rc)
    549		goto cleanup;
    550
    551	/* Ensure key size is not bigger then expected. */
    552	if (params.key_size > digits_to_bytes(tctx->curve->g.ndigits)) {
    553		rc = -EINVAL;
    554		goto cleanup;
    555	}
    556
    557	/* Auto-generate private key is not provided. */
    558	if (!params.key || !params.key_size) {
    559		rc = kmb_ecc_gen_privkey(tctx->curve, tctx->private_key);
    560		goto cleanup;
    561	}
    562
    563	rc = kmb_ecc_is_key_valid(tctx->curve, (const u64 *)params.key,
    564				  params.key_size);
    565	if (rc)
    566		goto cleanup;
    567
    568	ecc_swap_digits((const u64 *)params.key, tctx->private_key,
    569			tctx->curve->g.ndigits);
    570cleanup:
    571	memzero_explicit(&params, sizeof(params));
    572
    573	if (rc)
    574		tctx->curve = NULL;
    575
    576	return rc;
    577}
    578
    579/* Compute shared secret. */
    580static int kmb_ecc_do_shared_secret(struct ocs_ecc_ctx *tctx,
    581				    struct kpp_request *req)
    582{
    583	struct ocs_ecc_dev *ecc_dev = tctx->ecc_dev;
    584	const struct ecc_curve *curve = tctx->curve;
    585	u64 shared_secret[KMB_ECC_VLI_MAX_DIGITS];
    586	u64 pubk_buf[KMB_ECC_VLI_MAX_DIGITS * 2];
    587	size_t copied, nbytes, pubk_len;
    588	struct ecc_point *pk, *result;
    589	int rc;
    590
    591	nbytes = digits_to_bytes(curve->g.ndigits);
    592
    593	/* Public key is a point, thus it has two coordinates */
    594	pubk_len = 2 * nbytes;
    595
    596	/* Copy public key from SG list to pubk_buf. */
    597	copied = sg_copy_to_buffer(req->src,
    598				   sg_nents_for_len(req->src, pubk_len),
    599				   pubk_buf, pubk_len);
    600	if (copied != pubk_len)
    601		return -EINVAL;
    602
    603	/* Allocate and initialize public key point. */
    604	pk = ecc_alloc_point(curve->g.ndigits);
    605	if (!pk)
    606		return -ENOMEM;
    607
    608	ecc_swap_digits(pubk_buf, pk->x, curve->g.ndigits);
    609	ecc_swap_digits(&pubk_buf[curve->g.ndigits], pk->y, curve->g.ndigits);
    610
    611	/*
    612	 * Check the public key for following
    613	 * Check 1: Verify key is not the zero point.
    614	 * Check 2: Verify key is in the range [1, p-1].
    615	 * Check 3: Verify that y^2 == (x^3 + a·x + b) mod p
    616	 */
    617	rc = kmb_ocs_ecc_is_pubkey_valid_partial(ecc_dev, curve, pk);
    618	if (rc)
    619		goto exit_free_pk;
    620
    621	/* Allocate point for storing computed shared secret. */
    622	result = ecc_alloc_point(pk->ndigits);
    623	if (!result) {
    624		rc = -ENOMEM;
    625		goto exit_free_pk;
    626	}
    627
    628	/* Calculate the shared secret.*/
    629	rc = kmb_ecc_point_mult(ecc_dev, result, pk, tctx->private_key, curve);
    630	if (rc)
    631		goto exit_free_result;
    632
    633	if (ecc_point_is_zero(result)) {
    634		rc = -EFAULT;
    635		goto exit_free_result;
    636	}
    637
    638	/* Copy shared secret from point to buffer. */
    639	ecc_swap_digits(result->x, shared_secret, result->ndigits);
    640
    641	/* Request might ask for less bytes than what we have. */
    642	nbytes = min_t(size_t, nbytes, req->dst_len);
    643
    644	copied = sg_copy_from_buffer(req->dst,
    645				     sg_nents_for_len(req->dst, nbytes),
    646				     shared_secret, nbytes);
    647
    648	if (copied != nbytes)
    649		rc = -EINVAL;
    650
    651	memzero_explicit(shared_secret, sizeof(shared_secret));
    652
    653exit_free_result:
    654	ecc_free_point(result);
    655
    656exit_free_pk:
    657	ecc_free_point(pk);
    658
    659	return rc;
    660}
    661
    662/* Compute public key. */
    663static int kmb_ecc_do_public_key(struct ocs_ecc_ctx *tctx,
    664				 struct kpp_request *req)
    665{
    666	const struct ecc_curve *curve = tctx->curve;
    667	u64 pubk_buf[KMB_ECC_VLI_MAX_DIGITS * 2];
    668	struct ecc_point *pk;
    669	size_t pubk_len;
    670	size_t copied;
    671	int rc;
    672
    673	/* Public key is a point, so it has double the digits. */
    674	pubk_len = 2 * digits_to_bytes(curve->g.ndigits);
    675
    676	pk = ecc_alloc_point(curve->g.ndigits);
    677	if (!pk)
    678		return -ENOMEM;
    679
    680	/* Public Key(pk) = priv * G. */
    681	rc = kmb_ecc_point_mult(tctx->ecc_dev, pk, &curve->g, tctx->private_key,
    682				curve);
    683	if (rc)
    684		goto exit;
    685
    686	/* SP800-56A rev 3 5.6.2.1.3 key check */
    687	if (kmb_ocs_ecc_is_pubkey_valid_full(tctx->ecc_dev, curve, pk)) {
    688		rc = -EAGAIN;
    689		goto exit;
    690	}
    691
    692	/* Copy public key from point to buffer. */
    693	ecc_swap_digits(pk->x, pubk_buf, pk->ndigits);
    694	ecc_swap_digits(pk->y, &pubk_buf[pk->ndigits], pk->ndigits);
    695
    696	/* Copy public key to req->dst. */
    697	copied = sg_copy_from_buffer(req->dst,
    698				     sg_nents_for_len(req->dst, pubk_len),
    699				     pubk_buf, pubk_len);
    700
    701	if (copied != pubk_len)
    702		rc = -EINVAL;
    703
    704exit:
    705	ecc_free_point(pk);
    706
    707	return rc;
    708}
    709
    710static int kmb_ocs_ecc_do_one_request(struct crypto_engine *engine,
    711				      void *areq)
    712{
    713	struct kpp_request *req = container_of(areq, struct kpp_request, base);
    714	struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
    715	struct ocs_ecc_dev *ecc_dev = tctx->ecc_dev;
    716	int rc;
    717
    718	if (req->src)
    719		rc = kmb_ecc_do_shared_secret(tctx, req);
    720	else
    721		rc = kmb_ecc_do_public_key(tctx, req);
    722
    723	crypto_finalize_kpp_request(ecc_dev->engine, req, rc);
    724
    725	return 0;
    726}
    727
    728static int kmb_ocs_ecdh_generate_public_key(struct kpp_request *req)
    729{
    730	struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
    731	const struct ecc_curve *curve = tctx->curve;
    732
    733	/* Ensure kmb_ocs_ecdh_set_secret() has been successfully called. */
    734	if (!tctx->curve)
    735		return -EINVAL;
    736
    737	/* Ensure dst is present. */
    738	if (!req->dst)
    739		return -EINVAL;
    740
    741	/* Check the request dst is big enough to hold the public key. */
    742	if (req->dst_len < (2 * digits_to_bytes(curve->g.ndigits)))
    743		return -EINVAL;
    744
    745	/* 'src' is not supposed to be present when generate pubk is called. */
    746	if (req->src)
    747		return -EINVAL;
    748
    749	return crypto_transfer_kpp_request_to_engine(tctx->ecc_dev->engine,
    750						     req);
    751}
    752
    753static int kmb_ocs_ecdh_compute_shared_secret(struct kpp_request *req)
    754{
    755	struct ocs_ecc_ctx *tctx = kmb_ocs_ecc_tctx(req);
    756	const struct ecc_curve *curve = tctx->curve;
    757
    758	/* Ensure kmb_ocs_ecdh_set_secret() has been successfully called. */
    759	if (!tctx->curve)
    760		return -EINVAL;
    761
    762	/* Ensure dst is present. */
    763	if (!req->dst)
    764		return -EINVAL;
    765
    766	/* Ensure src is present. */
    767	if (!req->src)
    768		return -EINVAL;
    769
    770	/*
    771	 * req->src is expected to the (other-side) public key, so its length
    772	 * must be 2 * coordinate size (in bytes).
    773	 */
    774	if (req->src_len != 2 * digits_to_bytes(curve->g.ndigits))
    775		return -EINVAL;
    776
    777	return crypto_transfer_kpp_request_to_engine(tctx->ecc_dev->engine,
    778						     req);
    779}
    780
    781static int kmb_ecc_tctx_init(struct ocs_ecc_ctx *tctx, unsigned int curve_id)
    782{
    783	memset(tctx, 0, sizeof(*tctx));
    784
    785	tctx->ecc_dev = kmb_ocs_ecc_find_dev(tctx);
    786
    787	if (IS_ERR(tctx->ecc_dev)) {
    788		pr_err("Failed to find the device : %ld\n",
    789		       PTR_ERR(tctx->ecc_dev));
    790		return PTR_ERR(tctx->ecc_dev);
    791	}
    792
    793	tctx->curve = ecc_get_curve(curve_id);
    794	if (!tctx->curve)
    795		return -EOPNOTSUPP;
    796
    797	tctx->engine_ctx.op.prepare_request = NULL;
    798	tctx->engine_ctx.op.do_one_request = kmb_ocs_ecc_do_one_request;
    799	tctx->engine_ctx.op.unprepare_request = NULL;
    800
    801	return 0;
    802}
    803
    804static int kmb_ocs_ecdh_nist_p256_init_tfm(struct crypto_kpp *tfm)
    805{
    806	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
    807
    808	return kmb_ecc_tctx_init(tctx, ECC_CURVE_NIST_P256);
    809}
    810
    811static int kmb_ocs_ecdh_nist_p384_init_tfm(struct crypto_kpp *tfm)
    812{
    813	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
    814
    815	return kmb_ecc_tctx_init(tctx, ECC_CURVE_NIST_P384);
    816}
    817
    818static void kmb_ocs_ecdh_exit_tfm(struct crypto_kpp *tfm)
    819{
    820	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
    821
    822	memzero_explicit(tctx->private_key, sizeof(*tctx->private_key));
    823}
    824
    825static unsigned int kmb_ocs_ecdh_max_size(struct crypto_kpp *tfm)
    826{
    827	struct ocs_ecc_ctx *tctx = kpp_tfm_ctx(tfm);
    828
    829	/* Public key is made of two coordinates, so double the digits. */
    830	return digits_to_bytes(tctx->curve->g.ndigits) * 2;
    831}
    832
    833static struct kpp_alg ocs_ecdh_p256 = {
    834	.set_secret = kmb_ocs_ecdh_set_secret,
    835	.generate_public_key = kmb_ocs_ecdh_generate_public_key,
    836	.compute_shared_secret = kmb_ocs_ecdh_compute_shared_secret,
    837	.init = kmb_ocs_ecdh_nist_p256_init_tfm,
    838	.exit = kmb_ocs_ecdh_exit_tfm,
    839	.max_size = kmb_ocs_ecdh_max_size,
    840	.base = {
    841		.cra_name = "ecdh-nist-p256",
    842		.cra_driver_name = "ecdh-nist-p256-keembay-ocs",
    843		.cra_priority = KMB_OCS_ECC_PRIORITY,
    844		.cra_module = THIS_MODULE,
    845		.cra_ctxsize = sizeof(struct ocs_ecc_ctx),
    846	},
    847};
    848
    849static struct kpp_alg ocs_ecdh_p384 = {
    850	.set_secret = kmb_ocs_ecdh_set_secret,
    851	.generate_public_key = kmb_ocs_ecdh_generate_public_key,
    852	.compute_shared_secret = kmb_ocs_ecdh_compute_shared_secret,
    853	.init = kmb_ocs_ecdh_nist_p384_init_tfm,
    854	.exit = kmb_ocs_ecdh_exit_tfm,
    855	.max_size = kmb_ocs_ecdh_max_size,
    856	.base = {
    857		.cra_name = "ecdh-nist-p384",
    858		.cra_driver_name = "ecdh-nist-p384-keembay-ocs",
    859		.cra_priority = KMB_OCS_ECC_PRIORITY,
    860		.cra_module = THIS_MODULE,
    861		.cra_ctxsize = sizeof(struct ocs_ecc_ctx),
    862	},
    863};
    864
    865static irqreturn_t ocs_ecc_irq_handler(int irq, void *dev_id)
    866{
    867	struct ocs_ecc_dev *ecc_dev = dev_id;
    868	u32 status;
    869
    870	/*
    871	 * Read the status register and write it back to clear the
    872	 * DONE_INT_STATUS bit.
    873	 */
    874	status = ioread32(ecc_dev->base_reg + HW_OFFS_OCS_ECC_ISR);
    875	iowrite32(status, ecc_dev->base_reg + HW_OFFS_OCS_ECC_ISR);
    876
    877	if (!(status & HW_OCS_ECC_ISR_INT_STATUS_DONE))
    878		return IRQ_NONE;
    879
    880	complete(&ecc_dev->irq_done);
    881
    882	return IRQ_HANDLED;
    883}
    884
    885static int kmb_ocs_ecc_probe(struct platform_device *pdev)
    886{
    887	struct device *dev = &pdev->dev;
    888	struct ocs_ecc_dev *ecc_dev;
    889	int rc;
    890
    891	ecc_dev = devm_kzalloc(dev, sizeof(*ecc_dev), GFP_KERNEL);
    892	if (!ecc_dev)
    893		return -ENOMEM;
    894
    895	ecc_dev->dev = dev;
    896
    897	platform_set_drvdata(pdev, ecc_dev);
    898
    899	INIT_LIST_HEAD(&ecc_dev->list);
    900	init_completion(&ecc_dev->irq_done);
    901
    902	/* Get base register address. */
    903	ecc_dev->base_reg = devm_platform_ioremap_resource(pdev, 0);
    904	if (IS_ERR(ecc_dev->base_reg)) {
    905		dev_err(dev, "Failed to get base address\n");
    906		rc = PTR_ERR(ecc_dev->base_reg);
    907		goto list_del;
    908	}
    909
    910	/* Get and request IRQ */
    911	ecc_dev->irq = platform_get_irq(pdev, 0);
    912	if (ecc_dev->irq < 0) {
    913		rc = ecc_dev->irq;
    914		goto list_del;
    915	}
    916
    917	rc = devm_request_threaded_irq(dev, ecc_dev->irq, ocs_ecc_irq_handler,
    918				       NULL, 0, "keembay-ocs-ecc", ecc_dev);
    919	if (rc < 0) {
    920		dev_err(dev, "Could not request IRQ\n");
    921		goto list_del;
    922	}
    923
    924	/* Add device to the list of OCS ECC devices. */
    925	spin_lock(&ocs_ecc.lock);
    926	list_add_tail(&ecc_dev->list, &ocs_ecc.dev_list);
    927	spin_unlock(&ocs_ecc.lock);
    928
    929	/* Initialize crypto engine. */
    930	ecc_dev->engine = crypto_engine_alloc_init(dev, 1);
    931	if (!ecc_dev->engine) {
    932		dev_err(dev, "Could not allocate crypto engine\n");
    933		rc = -ENOMEM;
    934		goto list_del;
    935	}
    936
    937	rc = crypto_engine_start(ecc_dev->engine);
    938	if (rc) {
    939		dev_err(dev, "Could not start crypto engine\n");
    940		goto cleanup;
    941	}
    942
    943	/* Register the KPP algo. */
    944	rc = crypto_register_kpp(&ocs_ecdh_p256);
    945	if (rc) {
    946		dev_err(dev,
    947			"Could not register OCS algorithms with Crypto API\n");
    948		goto cleanup;
    949	}
    950
    951	rc = crypto_register_kpp(&ocs_ecdh_p384);
    952	if (rc) {
    953		dev_err(dev,
    954			"Could not register OCS algorithms with Crypto API\n");
    955		goto ocs_ecdh_p384_error;
    956	}
    957
    958	return 0;
    959
    960ocs_ecdh_p384_error:
    961	crypto_unregister_kpp(&ocs_ecdh_p256);
    962
    963cleanup:
    964	crypto_engine_exit(ecc_dev->engine);
    965
    966list_del:
    967	spin_lock(&ocs_ecc.lock);
    968	list_del(&ecc_dev->list);
    969	spin_unlock(&ocs_ecc.lock);
    970
    971	return rc;
    972}
    973
    974static int kmb_ocs_ecc_remove(struct platform_device *pdev)
    975{
    976	struct ocs_ecc_dev *ecc_dev;
    977
    978	ecc_dev = platform_get_drvdata(pdev);
    979	if (!ecc_dev)
    980		return -ENODEV;
    981
    982	crypto_unregister_kpp(&ocs_ecdh_p384);
    983	crypto_unregister_kpp(&ocs_ecdh_p256);
    984
    985	spin_lock(&ocs_ecc.lock);
    986	list_del(&ecc_dev->list);
    987	spin_unlock(&ocs_ecc.lock);
    988
    989	crypto_engine_exit(ecc_dev->engine);
    990
    991	return 0;
    992}
    993
    994/* Device tree driver match. */
    995static const struct of_device_id kmb_ocs_ecc_of_match[] = {
    996	{
    997		.compatible = "intel,keembay-ocs-ecc",
    998	},
    999	{}
   1000};
   1001
   1002/* The OCS driver is a platform device. */
   1003static struct platform_driver kmb_ocs_ecc_driver = {
   1004	.probe = kmb_ocs_ecc_probe,
   1005	.remove = kmb_ocs_ecc_remove,
   1006	.driver = {
   1007			.name = DRV_NAME,
   1008			.of_match_table = kmb_ocs_ecc_of_match,
   1009		},
   1010};
   1011module_platform_driver(kmb_ocs_ecc_driver);
   1012
   1013MODULE_LICENSE("GPL");
   1014MODULE_DESCRIPTION("Intel Keem Bay OCS ECC Driver");
   1015MODULE_ALIAS_CRYPTO("ecdh-nist-p256");
   1016MODULE_ALIAS_CRYPTO("ecdh-nist-p384");
   1017MODULE_ALIAS_CRYPTO("ecdh-nist-p256-keembay-ocs");
   1018MODULE_ALIAS_CRYPTO("ecdh-nist-p384-keembay-ocs");