libwd/0028-uadk-cipher-isa_ce-support-SM4-cbc_cts-mode.patch
JangShui Yang e072f742a4 libwd: update the source code
(cherry picked from commit dc42b3a676205c1a1c922628a993887e1ad2988f)
2024-04-07 18:59:45 +08:00

From 8c23969dacd7b1ae1b77c1118a8f895bec6fd165 Mon Sep 17 00:00:00 2001
From: Yang Shen <shenyang39@huawei.com>
Date: Wed, 20 Mar 2024 16:15:00 +0800
Subject: [PATCH 28/44] uadk/cipher: isa_ce - support SM4 cbc_cts mode

This patch uses the SM4 CE instructions to implement the CBC_CS1, CBC_CS2
and CBC_CS3 (CBC with ciphertext stealing) modes.

Signed-off-by: Yang Shen <shenyang39@huawei.com>
Signed-off-by: Qi Tao <taoqi10@huawei.com>
---
drv/isa_ce_sm4.c | 91 +++++++++++++++++++++++++++-
drv/isa_ce_sm4.h | 24 +++++---
drv/isa_ce_sm4_armv8.S | 133 +++++++++++++++++++++++++++++++++++++++++
3 files changed, 238 insertions(+), 10 deletions(-)
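
Note (illustration only, not part of the applied patch): the sketch below mirrors
the plain-CBC fallback check added in sm4_cts_cbc_instead() and the tail-size
computation in sm4_cts_cbc_crypt(). The enum tags and the helper names
use_plain_cbc() and cts_tail_bytes() are made up for this example; only the
arithmetic follows the patch.

#include <stdbool.h>
#include <stdio.h>

#define SM4_BLOCK_SIZE	16

/* Illustrative mode tags; not the UADK wd_cipher_mode enum values. */
enum cts_mode { CBC_CS1, CBC_CS2, CBC_CS3 };

/* Mirrors sm4_cts_cbc_instead(): these inputs can be handled by plain CBC. */
static bool use_plain_cbc(unsigned int in_bytes, enum cts_mode mode)
{
	if (in_bytes == SM4_BLOCK_SIZE)
		return true;

	/* Block-aligned CS1/CS2 output equals plain CBC; CS3 still swaps the last two blocks. */
	if (!(in_bytes % SM4_BLOCK_SIZE) && mode != CBC_CS3)
		return true;

	return false;
}

/* Mirrors the tail-size computation in sm4_cts_cbc_crypt(): the CTS tail is
 * the final partial block plus one full block, or the last two full blocks
 * when the input is block aligned. */
static unsigned int cts_tail_bytes(unsigned int in_bytes)
{
	unsigned int tail = in_bytes % SM4_BLOCK_SIZE + SM4_BLOCK_SIZE;

	if (tail == SM4_BLOCK_SIZE)
		tail += SM4_BLOCK_SIZE;

	return tail;
}

int main(void)
{
	unsigned int lens[] = { 16, 20, 32, 47, 64 };
	unsigned int i, tail;

	for (i = 0; i < sizeof(lens) / sizeof(lens[0]); i++) {
		if (use_plain_cbc(lens[i], CBC_CS3)) {
			printf("in_bytes=%2u: plain CBC path\n", lens[i]);
			continue;
		}
		tail = cts_tail_bytes(lens[i]);
		printf("in_bytes=%2u: cbc part=%2u bytes, cts tail=%2u bytes\n",
		       lens[i], lens[i] - tail, tail);
	}

	return 0;
}

The split matters because everything before the tail can reuse the existing
sm4_v8_cbc_encrypt() path, and only the final one or two blocks go through the
new sm4_v8_cbc_cts_encrypt()/sm4_v8_cbc_cts_decrypt() routines.
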
diff --git a/drv/isa_ce_sm4.c b/drv/isa_ce_sm4.c
index ccab8fb..6961471 100644
--- a/drv/isa_ce_sm4.c
+++ b/drv/isa_ce_sm4.c
@@ -128,6 +128,82 @@ static void sm4_cbc_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rke
sm4_v8_cbc_encrypt(msg->in, msg->out, msg->in_bytes, rkey_dec, msg->iv, SM4_DECRYPT);
}

+/*
+ * In some cases, the CTS modes can fall back to plain CBC mode to improve performance.
+ */
+static int sm4_cts_cbc_instead(struct wd_cipher_msg *msg)
+{
+ if (msg->in_bytes == SM4_BLOCK_SIZE)
+ return true;
+
+ if (!(msg->in_bytes % SM4_BLOCK_SIZE) && msg->mode != WD_CIPHER_CBC_CS3)
+ return true;
+
+ return false;
+}
+
+static void sm4_cts_cs1_mode_adapt(__u8 *cts_in, __u8 *cts_out,
+ const __u32 cts_bytes, const int enc)
+{
+ __u32 rsv_bytes = cts_bytes % SM4_BLOCK_SIZE;
+ __u8 blocks[SM4_BLOCK_SIZE] = {0};
+
+ if (enc == SM4_ENCRYPT) {
+ memcpy(blocks, cts_out, SM4_BLOCK_SIZE);
+ memcpy(cts_out, cts_out + SM4_BLOCK_SIZE, rsv_bytes);
+ memcpy(cts_out + rsv_bytes, blocks, SM4_BLOCK_SIZE);
+ } else {
+ memcpy(blocks, cts_in + rsv_bytes, SM4_BLOCK_SIZE);
+ memcpy(cts_in + SM4_BLOCK_SIZE, cts_in, rsv_bytes);
+ memcpy(cts_in, blocks, SM4_BLOCK_SIZE);
+ }
+}
+
+static void sm4_cts_cbc_crypt(struct wd_cipher_msg *msg,
+ const struct SM4_KEY *rkey_enc, const int enc)
+{
+ enum wd_cipher_mode mode = msg->mode;
+ __u32 in_bytes = msg->in_bytes;
+ __u8 *cts_in, *cts_out;
+ __u32 cts_bytes;
+
+ if (sm4_cts_cbc_instead(msg))
+ return sm4_v8_cbc_encrypt(msg->in, msg->out, in_bytes, rkey_enc, msg->iv, enc);
+
+ cts_bytes = in_bytes % SM4_BLOCK_SIZE + SM4_BLOCK_SIZE;
+ if (cts_bytes == SM4_BLOCK_SIZE)
+ cts_bytes += SM4_BLOCK_SIZE;
+
+ in_bytes -= cts_bytes;
+ if (in_bytes)
+ sm4_v8_cbc_encrypt(msg->in, msg->out, in_bytes, rkey_enc, msg->iv, enc);
+
+ cts_in = msg->in + in_bytes;
+ cts_out = msg->out + in_bytes;
+
+ if (enc == SM4_ENCRYPT) {
+ sm4_v8_cbc_cts_encrypt(cts_in, cts_out, cts_bytes, rkey_enc, msg->iv);
+
+ if (mode == WD_CIPHER_CBC_CS1)
+ sm4_cts_cs1_mode_adapt(cts_in, cts_out, cts_bytes, enc);
+ } else {
+ if (mode == WD_CIPHER_CBC_CS1)
+ sm4_cts_cs1_mode_adapt(cts_in, cts_out, cts_bytes, enc);
+
+ sm4_v8_cbc_cts_decrypt(cts_in, cts_out, cts_bytes, rkey_enc, msg->iv);
+ }
+}
+
+static void sm4_cbc_cts_encrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
+{
+ sm4_cts_cbc_crypt(msg, rkey_enc, SM4_ENCRYPT);
+}
+
+static void sm4_cbc_cts_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
+{
+ sm4_cts_cbc_crypt(msg, rkey_enc, SM4_DECRYPT);
+}
+
static void sm4_ecb_encrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
{
sm4_v8_ecb_encrypt(msg->in, msg->out, msg->in_bytes, rkey_enc, SM4_ENCRYPT);
@@ -138,12 +214,12 @@ static void sm4_ecb_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rke
sm4_v8_ecb_encrypt(msg->in, msg->out, msg->in_bytes, rkey_dec, SM4_DECRYPT);
}

-void sm4_set_encrypt_key(const __u8 *userKey, struct SM4_KEY *key)
+static void sm4_set_encrypt_key(const __u8 *userKey, struct SM4_KEY *key)
{
sm4_v8_set_encrypt_key(userKey, key);
}

-void sm4_set_decrypt_key(const __u8 *userKey, struct SM4_KEY *key)
+static void sm4_set_decrypt_key(const __u8 *userKey, struct SM4_KEY *key)
{
sm4_v8_set_decrypt_key(userKey, key);
}
@@ -276,6 +352,14 @@ static int isa_ce_cipher_send(struct wd_alg_driver *drv, handle_t ctx, void *wd_
else
sm4_cbc_decrypt(msg, &rkey);
break;
+ case WD_CIPHER_CBC_CS1:
+ case WD_CIPHER_CBC_CS2:
+ case WD_CIPHER_CBC_CS3:
+ if (msg->op_type == WD_CIPHER_ENCRYPTION)
+ sm4_cbc_cts_encrypt(msg, &rkey);
+ else
+ sm4_cbc_cts_decrypt(msg, &rkey);
+ break;
case WD_CIPHER_CTR:
sm4_ctr_encrypt(msg, &rkey);
break;
@@ -330,6 +414,9 @@ static int cipher_recv(struct wd_alg_driver *drv, handle_t ctx, void *msg)

static struct wd_alg_driver cipher_alg_driver[] = {
GEN_CE_ALG_DRIVER("cbc(sm4)", cipher),
+ GEN_CE_ALG_DRIVER("cbc-cs1(sm4)", cipher),
+ GEN_CE_ALG_DRIVER("cbc-cs2(sm4)", cipher),
+ GEN_CE_ALG_DRIVER("cbc-cs3(sm4)", cipher),
GEN_CE_ALG_DRIVER("ctr(sm4)", cipher),
GEN_CE_ALG_DRIVER("cfb(sm4)", cipher),
GEN_CE_ALG_DRIVER("xts(sm4)", cipher),
diff --git a/drv/isa_ce_sm4.h b/drv/isa_ce_sm4.h
index d10b0af..308619e 100644
--- a/drv/isa_ce_sm4.h
+++ b/drv/isa_ce_sm4.h
@@ -25,27 +25,35 @@ struct sm4_ce_drv_ctx {
void sm4_v8_set_encrypt_key(const unsigned char *userKey, struct SM4_KEY *key);
void sm4_v8_set_decrypt_key(const unsigned char *userKey, struct SM4_KEY *key);
+
void sm4_v8_cbc_encrypt(const unsigned char *in, unsigned char *out,
size_t length, const struct SM4_KEY *key,
unsigned char *ivec, const int enc);
+void sm4_v8_cbc_cts_encrypt(const unsigned char *in, unsigned char *out,
+ size_t len, const void *key, const unsigned char ivec[16]);
+void sm4_v8_cbc_cts_decrypt(const unsigned char *in, unsigned char *out,
+ size_t len, const void *key, const unsigned char ivec[16]);
+
void sm4_v8_ecb_encrypt(const unsigned char *in, unsigned char *out,
size_t length, const struct SM4_KEY *key, const int enc);
+
void sm4_v8_ctr32_encrypt_blocks(const unsigned char *in, unsigned char *out,
- size_t len, const void *key, const unsigned char ivec[16]);
+ size_t len, const void *key, const unsigned char ivec[16]);
void sm4_v8_cfb_encrypt_blocks(const unsigned char *in, unsigned char *out,
- size_t length, const struct SM4_KEY *key, unsigned char *ivec);
+ size_t length, const struct SM4_KEY *key, unsigned char *ivec);
void sm4_v8_cfb_decrypt_blocks(const unsigned char *in, unsigned char *out,
- size_t length, const struct SM4_KEY *key, unsigned char *ivec);
+ size_t length, const struct SM4_KEY *key, unsigned char *ivec);
+
void sm4_v8_crypt_block(const unsigned char *in, unsigned char *out,
- const struct SM4_KEY *key);
+ const struct SM4_KEY *key);
int sm4_v8_xts_encrypt(const unsigned char *in, unsigned char *out, size_t length,
- const struct SM4_KEY *key, unsigned char *ivec,
- const struct SM4_KEY *key2);
+ const struct SM4_KEY *key, unsigned char *ivec,
+ const struct SM4_KEY *key2);
int sm4_v8_xts_decrypt(const unsigned char *in, unsigned char *out, size_t length,
- const struct SM4_KEY *key, unsigned char *ivec,
- const struct SM4_KEY *key2);
+ const struct SM4_KEY *key, unsigned char *ivec,
+ const struct SM4_KEY *key2);
#ifdef __cplusplus
}
diff --git a/drv/isa_ce_sm4_armv8.S b/drv/isa_ce_sm4_armv8.S
index 7d84496..6ebf39b 100644
--- a/drv/isa_ce_sm4_armv8.S
+++ b/drv/isa_ce_sm4_armv8.S
@@ -506,6 +506,139 @@ sm4_v8_cbc_encrypt:
ldp d8,d9,[sp],#16
ret
.size sm4_v8_cbc_encrypt,.-sm4_v8_cbc_encrypt
+
+.globl sm4_v8_cbc_cts_encrypt
+.type sm4_v8_cbc_cts_encrypt,%function
+.align 5
+sm4_v8_cbc_cts_encrypt:
+ AARCH64_VALID_CALL_TARGET
+ ld1 {v0.4s,v1.4s,v2.4s,v3.4s}, [x3], #64
+ ld1 {v4.4s,v5.4s,v6.4s,v7.4s}, [x3]
+ sub x5, x2, #16
+
+ ld1 {v8.4s}, [x4]
+
+ ld1 {v10.4s}, [x0]
+ eor v8.16b, v8.16b, v10.16b
+ rev32 v8.16b, v8.16b;
+ sm4e v8.4s, v0.4s;
+ sm4e v8.4s, v1.4s;
+ sm4e v8.4s, v2.4s;
+ sm4e v8.4s, v3.4s;
+ sm4e v8.4s, v4.4s;
+ sm4e v8.4s, v5.4s;
+ sm4e v8.4s, v6.4s;
+ sm4e v8.4s, v7.4s;
+ rev64 v8.4s, v8.4s;
+ ext v8.16b, v8.16b, v8.16b, #8;
+ rev32 v8.16b, v8.16b;
+
+ /* load permute table */
+ adr x6, .cts_permute_table
+ add x7, x6, #32
+ add x6, x6, x5
+ sub x7, x7, x5
+ ld1 {v13.4s}, [x6]
+ ld1 {v14.4s}, [x7]
+
+ /* overlapping loads */
+ add x0, x0, x5
+ ld1 {v11.4s}, [x0]
+
+ /* create Cn from En-1 */
+ tbl v10.16b, {v8.16b}, v13.16b
+ /* padding Pn with zeros */
+ tbl v11.16b, {v11.16b}, v14.16b
+
+ eor v11.16b, v11.16b, v8.16b
+ rev32 v11.16b, v11.16b;
+ sm4e v11.4s, v0.4s;
+ sm4e v11.4s, v1.4s;
+ sm4e v11.4s, v2.4s;
+ sm4e v11.4s, v3.4s;
+ sm4e v11.4s, v4.4s;
+ sm4e v11.4s, v5.4s;
+ sm4e v11.4s, v6.4s;
+ sm4e v11.4s, v7.4s;
+ rev64 v11.4s, v11.4s;
+ ext v11.16b, v11.16b, v11.16b, #8;
+ rev32 v11.16b, v11.16b;
+
+ /* overlapping stores */
+ add x5, x1, x5
+ st1 {v10.16b}, [x5]
+ st1 {v11.16b}, [x1]
+
+ ret
+.size sm4_v8_cbc_cts_encrypt,.-sm4_v8_cbc_cts_encrypt
+
+.globl sm4_v8_cbc_cts_decrypt
+.type sm4_v8_cbc_cts_decrypt,%function
+.align 5
+sm4_v8_cbc_cts_decrypt:
+ AARCH64_VALID_CALL_TARGET
+ ld1 {v0.4s,v1.4s,v2.4s,v3.4s}, [x3], #64
+ ld1 {v4.4s,v5.4s,v6.4s,v7.4s}, [x3]
+
+ sub x5, x2, #16
+
+ ld1 {v8.4s}, [x4]
+
+ /* load permute table */
+ adr x6, .cts_permute_table
+ add x7, x6, #32
+ add x6, x6, x5
+ sub x7, x7, x5
+ ld1 {v13.4s}, [x6]
+ ld1 {v14.4s}, [x7]
+
+ /* overlapping loads */
+ ld1 {v10.16b}, [x0], x5
+ ld1 {v11.16b}, [x0]
+
+ rev32 v10.16b, v10.16b;
+ sm4e v10.4s, v0.4s;
+ sm4e v10.4s, v1.4s;
+ sm4e v10.4s, v2.4s;
+ sm4e v10.4s, v3.4s;
+ sm4e v10.4s, v4.4s;
+ sm4e v10.4s, v5.4s;
+ sm4e v10.4s, v6.4s;
+ sm4e v10.4s, v7.4s;
+ rev64 v10.4s, v10.4s;
+ ext v10.16b, v10.16b, v10.16b, #8;
+ rev32 v10.16b, v10.16b;
+
+ /* select the first Ln bytes of Xn to create Pn */
+ tbl v12.16b, {v10.16b}, v13.16b
+ eor v12.16b, v12.16b, v11.16b
+
+ /* overwrite the first Ln bytes with Cn to create En-1 */
+ tbx v10.16b, {v11.16b}, v14.16b
+
+ rev32 v10.16b, v10.16b;
+ sm4e v10.4s, v0.4s;
+ sm4e v10.4s, v1.4s;
+ sm4e v10.4s, v2.4s;
+ sm4e v10.4s, v3.4s;
+ sm4e v10.4s, v4.4s;
+ sm4e v10.4s, v5.4s;
+ sm4e v10.4s, v6.4s;
+ sm4e v10.4s, v7.4s;
+ rev64 v10.4s, v10.4s;
+ ext v10.16b, v10.16b, v10.16b, #8;
+ rev32 v10.16b, v10.16b;
+
+ eor v10.16b, v10.16b, v8.16b
+
+ /* overlapping stores */
+ add x5, x1, x5
+ st1 {v12.16b}, [x5]
+ st1 {v10.16b}, [x1]
+
+ ret
+.size sm4_v8_cbc_cts_decrypt,.-sm4_v8_cbc_cts_decrypt
+
.globl sm4_v8_ecb_encrypt
.type sm4_v8_ecb_encrypt,%function
.align 5
--
2.25.1