This code implements a single-precision general matrix-vector multiplication (SGEMV) operation, optimized for performance using SIMD (Single Instruction, Multiple Data) intrinsics available in Intel's AVX (Advanced Vector Extensions) and SSE instruction sets.
Here’s what the code does:
Purpose:
The function computes the matrix-vector product

\[ \text{out} = \text{weights} \cdot x \]

where:

- `weights` is a matrix of dimensions `rows` x `cols`.
- `x` is a vector of size `cols`.
- `out` is the resulting vector of size `rows`.
Details:

Input Parameters:

- `float *out`: The output vector where the result of the matrix-vector multiplication is stored.
- `const float *weights`: The input matrix (`rows` x `cols`). Since the kernel broadcasts `x[j]` and loads contiguous runs of `weights` along the row index, the matrix is effectively stored column by column (column-major), with `col_stride` giving the spacing between columns.
- `int rows`: The number of rows in the `weights` matrix (and the size of the `out` vector).
- `int cols`: The number of columns in the `weights` matrix (and the size of the `x` vector).
- `int col_stride`: The stride (distance in memory, in elements) between the starts of consecutive columns of `weights`.
- `const float *x`: The input vector of size `cols`.
Output:

- Computes `out[i]` for each row `i` as the dot product of row `i` of the matrix `weights` with the input vector `x`.
Implementation Overview:

- SIMD Vectorization: The function uses AVX (256-bit registers) and SSE (128-bit registers) intrinsics for efficient computation. SIMD lets the function process multiple data elements in parallel, significantly improving performance for large matrices.

- Loop Structure: The computation is divided into several loops that handle different numbers of rows per iteration:

Step 1 (Main Loop, 16-wide):

- Processes 16 rows at a time using AVX instructions (`__m256`).
- Two `__m256` accumulators (`vy0` and `vy8`) hold results for the first and second halves of these 16 rows.
- Each element `x[j]` is broadcast to an AVX register (`vxj`) and multiplied with the corresponding 16 elements of `weights` in two batches of 8.
- Results are stored back into the `out` vector after all columns have been processed.
Step 2 (8-wide):

- Processes 8 rows at a time when fewer than 16 but at least 8 rows remain.
- Similar to Step 1, but with a single accumulator instead of the `vy0`/`vy8` pair.
Step 3 (4-wide):

- When fewer than 8 but at least 4 rows remain, SSE instructions are used instead of AVX.
- Uses `__m128` registers for 4-wide processing.
Step 4 (Scalar Fallback):
- For any remaining rows (less than 4), falls back to a simple scalar loop that computes the dot product for each remaining row using standard arithmetic.
- SIMD Intrinsics Used:

  - `_mm256_setzero_ps()`: Zero-initializes a 256-bit AVX register.
  - `_mm_setzero_ps()`: Zero-initializes a 128-bit SSE register.
  - `_mm256_broadcast_ss(&x[j])`: Broadcasts a single scalar value (`x[j]`) to all lanes of a 256-bit register.
  - `_mm256_loadu_ps()`: Loads 8 `float` values from memory (the weights) into a 256-bit register.
  - `_mm_loadu_ps()`: Loads 4 `float` values from memory into a 128-bit register.
  - `_mm256_fmadd_ps()` / `_mm_fmadd_ps()`: Fused multiply-add (FMA), computing `vy = a * b + vy` for improved performance and accuracy.
  - `_mm256_storeu_ps()` / `_mm_storeu_ps()`: Stores 8 or 4 `float` values from a 256-bit or 128-bit register back to memory.
- Final Result: After the computation, the resulting vector `out` contains the dot product of each row of the `weights` matrix with the `x` vector.
Key Features:

- Fused Multiply-Add (FMA): The function uses FMA instructions (`_mm256_fmadd_ps()` and `_mm_fmadd_ps()`) to optimize the dot-product computation. FMA reduces latency and improves accuracy compared to separate multiply and add instructions, since the intermediate product is not rounded.

- SIMD Optimization:
  - Highly optimized for modern CPUs with AVX or SSE support.
  - Processes multiple rows in parallel via vectorized operations.

- Handling Edge Cases:
  - The function falls through to narrower (8-wide, 4-wide) and finally scalar loops when the number of rows is not a multiple of 16, 8, or 4.

- Column Stride:
  - Accounts for a non-contiguous memory layout of the `weights` matrix using `col_stride`.
Usage Example:

If `weights` is a 3x3 matrix and `x` is a vector of ones:

    weights = [1 2 3
               4 5 6
               7 8 9]
    x = [1 1 1]

then the function computes:

\[ \text{out} = \text{weights} \cdot x = [6, 15, 24] \]
Performance:
This function is designed for scenarios where performance is critical, such as deep learning inference, numerical computing, or processing large matrices. However, it requires a CPU with AVX, FMA, and SSE support to realize its full potential.