crypto: x86/chacha20 - Support partial lengths in 8-block AVX2 variant
authorMartin Willi <martin@strongswan.org>
Sun, 11 Nov 2018 09:36:27 +0000 (10:36 +0100)
committerHerbert Xu <herbert@gondor.apana.org.au>
Fri, 16 Nov 2018 06:11:04 +0000 (14:11 +0800)
Add a length argument to the eight block function for AVX2, so the
block function may XOR only a partial length of eight blocks.

To avoid unnecessary operations, we integrate XORing of the first four
blocks in the final lane interleaving; this also avoids some work in
the partial lengths path.

Signed-off-by: Martin Willi <martin@strongswan.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/x86/crypto/chacha20-avx2-x86_64.S
arch/x86/crypto/chacha20_glue.c

index f3cd26f4833254cb0966f31a1afae8c5c9fc95a4..7b62d55bee3d40ed2d264b9dea391c4a267b92f7 100644 (file)
@@ -30,8 +30,9 @@ CTRINC:       .octa 0x00000003000000020000000100000000
 
 ENTRY(chacha20_8block_xor_avx2)
        # %rdi: Input state matrix, s
-       # %rsi: 8 data blocks output, o
-       # %rdx: 8 data blocks input, i
+       # %rsi: up to 8 data blocks output, o
+       # %rdx: up to 8 data blocks input, i
+       # %rcx: input/output length in bytes
 
        # This function encrypts eight consecutive ChaCha20 blocks by loading
        # the state matrix in AVX registers eight times. As we need some
@@ -48,6 +49,7 @@ ENTRY(chacha20_8block_xor_avx2)
        lea             8(%rsp),%r10
        and             $~31, %rsp
        sub             $0x80, %rsp
+       mov             %rcx,%rax
 
        # x0..15[0-7] = s[0..15]
        vpbroadcastd    0x00(%rdi),%ymm0
@@ -375,74 +377,143 @@ ENTRY(chacha20_8block_xor_avx2)
        vpunpckhqdq     %ymm15,%ymm0,%ymm15
 
        # interleave 128-bit words in state n, n+4
-       vmovdqa         0x00(%rsp),%ymm0
-       vperm2i128      $0x20,%ymm4,%ymm0,%ymm1
-       vperm2i128      $0x31,%ymm4,%ymm0,%ymm4
-       vmovdqa         %ymm1,0x00(%rsp)
-       vmovdqa         0x20(%rsp),%ymm0
-       vperm2i128      $0x20,%ymm5,%ymm0,%ymm1
-       vperm2i128      $0x31,%ymm5,%ymm0,%ymm5
-       vmovdqa         %ymm1,0x20(%rsp)
-       vmovdqa         0x40(%rsp),%ymm0
-       vperm2i128      $0x20,%ymm6,%ymm0,%ymm1
-       vperm2i128      $0x31,%ymm6,%ymm0,%ymm6
-       vmovdqa         %ymm1,0x40(%rsp)
-       vmovdqa         0x60(%rsp),%ymm0
-       vperm2i128      $0x20,%ymm7,%ymm0,%ymm1
-       vperm2i128      $0x31,%ymm7,%ymm0,%ymm7
-       vmovdqa         %ymm1,0x60(%rsp)
+       # xor/write first four blocks
+       vmovdqa         0x00(%rsp),%ymm1
+       vperm2i128      $0x20,%ymm4,%ymm1,%ymm0
+       cmp             $0x0020,%rax
+       jl              .Lxorpart8
+       vpxor           0x0000(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0000(%rsi)
+       vperm2i128      $0x31,%ymm4,%ymm1,%ymm4
+
        vperm2i128      $0x20,%ymm12,%ymm8,%ymm0
+       cmp             $0x0040,%rax
+       jl              .Lxorpart8
+       vpxor           0x0020(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0020(%rsi)
        vperm2i128      $0x31,%ymm12,%ymm8,%ymm12
-       vmovdqa         %ymm0,%ymm8
-       vperm2i128      $0x20,%ymm13,%ymm9,%ymm0
-       vperm2i128      $0x31,%ymm13,%ymm9,%ymm13
-       vmovdqa         %ymm0,%ymm9
+
+       vmovdqa         0x40(%rsp),%ymm1
+       vperm2i128      $0x20,%ymm6,%ymm1,%ymm0
+       cmp             $0x0060,%rax
+       jl              .Lxorpart8
+       vpxor           0x0040(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0040(%rsi)
+       vperm2i128      $0x31,%ymm6,%ymm1,%ymm6
+
        vperm2i128      $0x20,%ymm14,%ymm10,%ymm0
+       cmp             $0x0080,%rax
+       jl              .Lxorpart8
+       vpxor           0x0060(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0060(%rsi)
        vperm2i128      $0x31,%ymm14,%ymm10,%ymm14
-       vmovdqa         %ymm0,%ymm10
-       vperm2i128      $0x20,%ymm15,%ymm11,%ymm0
-       vperm2i128      $0x31,%ymm15,%ymm11,%ymm15
-       vmovdqa         %ymm0,%ymm11
 
-       # xor with corresponding input, write to output
-       vmovdqa         0x00(%rsp),%ymm0
-       vpxor           0x0000(%rdx),%ymm0,%ymm0
-       vmovdqu         %ymm0,0x0000(%rsi)
-       vmovdqa         0x20(%rsp),%ymm0
+       vmovdqa         0x20(%rsp),%ymm1
+       vperm2i128      $0x20,%ymm5,%ymm1,%ymm0
+       cmp             $0x00a0,%rax
+       jl              .Lxorpart8
        vpxor           0x0080(%rdx),%ymm0,%ymm0
        vmovdqu         %ymm0,0x0080(%rsi)
-       vmovdqa         0x40(%rsp),%ymm0
-       vpxor           0x0040(%rdx),%ymm0,%ymm0
-       vmovdqu         %ymm0,0x0040(%rsi)
-       vmovdqa         0x60(%rsp),%ymm0
+       vperm2i128      $0x31,%ymm5,%ymm1,%ymm5
+
+       vperm2i128      $0x20,%ymm13,%ymm9,%ymm0
+       cmp             $0x00c0,%rax
+       jl              .Lxorpart8
+       vpxor           0x00a0(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x00a0(%rsi)
+       vperm2i128      $0x31,%ymm13,%ymm9,%ymm13
+
+       vmovdqa         0x60(%rsp),%ymm1
+       vperm2i128      $0x20,%ymm7,%ymm1,%ymm0
+       cmp             $0x00e0,%rax
+       jl              .Lxorpart8
        vpxor           0x00c0(%rdx),%ymm0,%ymm0
        vmovdqu         %ymm0,0x00c0(%rsi)
-       vpxor           0x0100(%rdx),%ymm4,%ymm4
-       vmovdqu         %ymm4,0x0100(%rsi)
-       vpxor           0x0180(%rdx),%ymm5,%ymm5
-       vmovdqu         %ymm5,0x00180(%rsi)
-       vpxor           0x0140(%rdx),%ymm6,%ymm6
-       vmovdqu         %ymm6,0x0140(%rsi)
-       vpxor           0x01c0(%rdx),%ymm7,%ymm7
-       vmovdqu         %ymm7,0x01c0(%rsi)
-       vpxor           0x0020(%rdx),%ymm8,%ymm8
-       vmovdqu         %ymm8,0x0020(%rsi)
-       vpxor           0x00a0(%rdx),%ymm9,%ymm9
-       vmovdqu         %ymm9,0x00a0(%rsi)
-       vpxor           0x0060(%rdx),%ymm10,%ymm10
-       vmovdqu         %ymm10,0x0060(%rsi)
-       vpxor           0x00e0(%rdx),%ymm11,%ymm11
-       vmovdqu         %ymm11,0x00e0(%rsi)
-       vpxor           0x0120(%rdx),%ymm12,%ymm12
-       vmovdqu         %ymm12,0x0120(%rsi)
-       vpxor           0x01a0(%rdx),%ymm13,%ymm13
-       vmovdqu         %ymm13,0x01a0(%rsi)
-       vpxor           0x0160(%rdx),%ymm14,%ymm14
-       vmovdqu         %ymm14,0x0160(%rsi)
-       vpxor           0x01e0(%rdx),%ymm15,%ymm15
-       vmovdqu         %ymm15,0x01e0(%rsi)
+       vperm2i128      $0x31,%ymm7,%ymm1,%ymm7
+
+       vperm2i128      $0x20,%ymm15,%ymm11,%ymm0
+       cmp             $0x0100,%rax
+       jl              .Lxorpart8
+       vpxor           0x00e0(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x00e0(%rsi)
+       vperm2i128      $0x31,%ymm15,%ymm11,%ymm15
+
+       # xor remaining blocks, write to output
+       vmovdqa         %ymm4,%ymm0
+       cmp             $0x0120,%rax
+       jl              .Lxorpart8
+       vpxor           0x0100(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0100(%rsi)
 
+       vmovdqa         %ymm12,%ymm0
+       cmp             $0x0140,%rax
+       jl              .Lxorpart8
+       vpxor           0x0120(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0120(%rsi)
+
+       vmovdqa         %ymm6,%ymm0
+       cmp             $0x0160,%rax
+       jl              .Lxorpart8
+       vpxor           0x0140(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0140(%rsi)
+
+       vmovdqa         %ymm14,%ymm0
+       cmp             $0x0180,%rax
+       jl              .Lxorpart8
+       vpxor           0x0160(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0160(%rsi)
+
+       vmovdqa         %ymm5,%ymm0
+       cmp             $0x01a0,%rax
+       jl              .Lxorpart8
+       vpxor           0x0180(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x0180(%rsi)
+
+       vmovdqa         %ymm13,%ymm0
+       cmp             $0x01c0,%rax
+       jl              .Lxorpart8
+       vpxor           0x01a0(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x01a0(%rsi)
+
+       vmovdqa         %ymm7,%ymm0
+       cmp             $0x01e0,%rax
+       jl              .Lxorpart8
+       vpxor           0x01c0(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x01c0(%rsi)
+
+       vmovdqa         %ymm15,%ymm0
+       cmp             $0x0200,%rax
+       jl              .Lxorpart8
+       vpxor           0x01e0(%rdx),%ymm0,%ymm0
+       vmovdqu         %ymm0,0x01e0(%rsi)
+
+.Ldone8:
        vzeroupper
        lea             -8(%r10),%rsp
        ret
+
+.Lxorpart8:
+       # xor remaining bytes from partial register into output
+       mov             %rax,%r9
+       and             $0x1f,%r9
+       jz              .Ldone8
+       and             $~0x1f,%rax
+
+       mov             %rsi,%r11
+
+       lea             (%rdx,%rax),%rsi
+       mov             %rsp,%rdi
+       mov             %r9,%rcx
+       rep movsb
+
+       vpxor           0x00(%rsp),%ymm0,%ymm0
+       vmovdqa         %ymm0,0x00(%rsp)
+
+       mov             %rsp,%rsi
+       lea             (%r11,%rax),%rdi
+       mov             %r9,%rcx
+       rep movsb
+
+       jmp             .Ldone8
+
 ENDPROC(chacha20_8block_xor_avx2)
index 8f1ef1a9ce5cfa2e2c6d58f8c36a88f66480c0f6..882e8bf5965a780c0c00ea838136de71c06228ba 100644 (file)
@@ -24,7 +24,8 @@ asmlinkage void chacha20_block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
 asmlinkage void chacha20_4block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
                                          unsigned int len);
 #ifdef CONFIG_AS_AVX2
-asmlinkage void chacha20_8block_xor_avx2(u32 *state, u8 *dst, const u8 *src);
+asmlinkage void chacha20_8block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
+                                        unsigned int len);
 static bool chacha20_use_avx2;
 #endif
 
@@ -34,7 +35,7 @@ static void chacha20_dosimd(u32 *state, u8 *dst, const u8 *src,
 #ifdef CONFIG_AS_AVX2
        if (chacha20_use_avx2) {
                while (bytes >= CHACHA20_BLOCK_SIZE * 8) {
-                       chacha20_8block_xor_avx2(state, dst, src);
+                       chacha20_8block_xor_avx2(state, dst, src, bytes);
                        bytes -= CHACHA20_BLOCK_SIZE * 8;
                        src += CHACHA20_BLOCK_SIZE * 8;
                        dst += CHACHA20_BLOCK_SIZE * 8;