C语言实现矩阵乘法的高效方法

被风吹过灼思 2023-06-30 10:03:11 浏览数 (3380)
反馈

本文将介绍一种使用C语言实现矩阵乘法的高效方法,即分块算法。分块算法的基本思想是将两个大矩阵分成若干个小矩阵,然后对每对小矩阵进行乘法运算,最后将结果合并成一个大矩阵。这样可以减少缓存失效的次数,提高运算速度。下面给出具体的代码实现。

#include <stdio.h>
#include <stdlib.h>
#include <time.h>


#define N 1000 // 矩阵的大小
#define B 100 // 分块的大小


// 生成一个随机矩阵
void generate_matrix(double *A) {
    srand(time(NULL));
    for (int i = 0; i < N * N; i++) {
        A[i] = rand() % 10;
    }
}


// 打印一个矩阵
void print_matrix(double *A) {
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            printf("%.2f ", A[i * N + j]);
        }
        printf("\n");
    }
}


// 普通的矩阵乘法
void normal_multiply(double *A, double *B, double *C) {
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            double sum = 0;
            for (int k = 0; k < N; k++) {
                sum += A[i * N + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}


// 分块的矩阵乘法
void block_multiply(double *A, double *B, double *C) {
    for (int i = 0; i < N; i += B) {
        for (int j = 0; j < N; j += B) {
            for (int k = 0; k < N; k += B) {
                // 对每个小矩阵进行乘法运算
                for (int ii = i; ii < i + B && ii < N; ii++) {
                    for (int jj = j; jj < j + B && jj < N; jj++) {
                        double sum = 0;
                        for (int kk = k; kk < k + B && kk < N; kk++) {
                            sum += A[ii * N + kk] * B[kk * N + jj];
                        }
                        C[ii * N + jj] += sum;
                    }
                }
            }
        }
    }
}


// 测试两种方法的运行时间
void test_time() {
    double *A = malloc(sizeof(double) * N * N);
    double *B = malloc(sizeof(double) * N * N);
    double *C1 = malloc(sizeof(double) * N * N);
    double *C2 = malloc(sizeof(double) * N * N);


    generate_matrix(A);
    generate_matrix(B);


    clock_t start, end;


    start = clock();
    normal_multiply(A, B, C1);
    end = clock();
    printf("Normal multiply time: %.3f s\n", (double)(end - start) / CLOCKS_PER_SEC);


    start = clock();
    block_multiply(A, B, C2);
    end = clock();
    printf("Block multiply time: %.3f s\n", (double)(end - start) / CLOCKS_PER_SEC);


    free(A);
    free(B);
    free(C1);
    free(C2);
}


// 主函数
int main() {
    test_time();
    return 0;
}

C语言相关课程推荐:C语言相关课程

C

0 人点赞