forked from pmkravets/strassen-winograd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwinograd.c
71 lines (60 loc) · 1.74 KB
/
winograd.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
//
// winograd.c
// strassen-winograd
//
// Created by Pavel Kravets on 04.10.13.
// Copyright (c) 2013 Pavel Kravets. All rights reserved.
//
#include <stdio.h>
#include <stdlib.h>
#include "winograd.h"
/*
 * Precompute Winograd's per-row and per-column pairwise-product sums.
 *
 * For every row r of m1:
 *     row[r] = sum over p < m1->colNum/2 of m1(r, 2p) * m1(r, 2p+1)
 * For every column c of m2:
 *     col[c] = sum over p < m1->colNum/2 of m2(2p, c) * m2(2p+1, c)
 *
 * Entries are read through element(m, r, c), presumably the (r, c) entry of m
 * (defined in winograd.h; not visible here).
 *
 * The number of pairs is taken from m1->colNum for both passes. This is only
 * meaningful when m1->colNum == m2->rowNum (winograd_mult checks this before
 * calling). With an odd shared dimension, the last index is not included here.
 *
 * row must hold at least m1->rowNum doubles and col at least m2->colNum
 * doubles; every one of those entries is overwritten.
 */
void winograd_preprocess(matrix* m1, matrix* m2, double* row, double* col)
{
    const int pairs = m1->colNum / 2;

    /* Pass 1: adjacent-pair products along each row of m1. */
    for (int r = 0; r < m1->rowNum; ++r)
    {
        double* out = &row[r];
        *out = 0;
        for (int p = 0; p < pairs; ++p)
        {
            const int even = 2 * p;
            *out += element(m1, r, even) * element(m1, r, even + 1);
        }
    }

    /* Pass 2: adjacent-pair products down each column of m2. */
    for (int c = 0; c < m2->colNum; ++c)
    {
        double* out = &col[c];
        *out = 0;
        for (int p = 0; p < pairs; ++p)
        {
            const int even = 2 * p;
            *out += element(m2, even, c) * element(m2, even + 1, c);
        }
    }
}
/*
 * Multiply m1 by m2 with Winograd's pairwise-product algorithm and store the
 * product in res.
 *
 * For each output entry (i, j), with d = m1->colNum / 2:
 *     res(i, j) = -row[i] - col[j]
 *               + sum over k < d of (m1(i,2k) + m2(2k+1,j)) * (m1(i,2k+1) + m2(2k,j))
 * Expanding each product yields m1(i,2k)*m2(2k,j) + m1(i,2k+1)*m2(2k+1,j)
 * plus the cross terms m1(i,2k)*m1(i,2k+1) and m2(2k,j)*m2(2k+1,j). Those
 * cross terms are cancelled by the row[] / col[] sums from
 * winograd_preprocess. If the shared dimension is odd, the final index is
 * added as an ordinary product.
 *
 * Requires m1->colNum == m2->rowNum. Otherwise an error is printed to stderr
 * and res is left unmodified.
 *
 * On success, res->rowNum = m1->rowNum, res->colNum = m2->colNum, and
 * res->matrix points to a newly malloc'd buffer that the caller owns. Any
 * buffer res->matrix pointed to before the call is NOT freed here.
 *
 * On allocation failure, an error is printed to stderr, res->matrix is set to
 * NULL, and res->rowNum / res->colNum are set to 0.
 */
void winograd_mult(matrix* m1, matrix* m2, matrix* res)
{
    if (m1->colNum != m2->rowNum)
    {
        fprintf(stderr, "Error. Matrices can not be multiplicated\n");
        return;
    }

    /* Do the size arithmetic in size_t so rowNum * colNum cannot overflow int. */
    size_t rowBytes = (size_t)m1->rowNum * sizeof(double);
    size_t colBytes = (size_t)m2->colNum * sizeof(double);
    size_t resBytes = (size_t)m1->rowNum * (size_t)m2->colNum * sizeof(double);

    /* Scratch buffers for winograd_preprocess; released before returning. */
    double* row = malloc(rowBytes);
    double* col = malloc(colBytes);
    double* out = malloc(resBytes);

    /* malloc(0) may legitimately return NULL, so only a NULL returned for a
       non-zero request counts as a failure. */
    if ((rowBytes != 0 && row == NULL) ||
        (colBytes != 0 && col == NULL) ||
        (resBytes != 0 && out == NULL))
    {
        fprintf(stderr, "Error. Not enough memory to multiply matrices\n");
        free(row);
        free(col);
        free(out);
        res->matrix = NULL;
        res->rowNum = 0;
        res->colNum = 0;
        return;
    }

    res->matrix = out;
    res->rowNum = m1->rowNum;
    res->colNum = m2->colNum;

    const int d = m1->colNum / 2;
    const int last = m1->colNum - 1;
    const int hasOddColumn = (m1->colNum % 2 != 0);

    winograd_preprocess(m1, m2, row, col);

    for (int i = 0; i < res->rowNum; i++)
    {
        for (int j = 0; j < res->colNum; j++)
        {
            set_element(res, i, j, -row[i] - col[j]);
            for (int k = 0; k < d; k++)
            {
                add_to_element(res, i, j,
                               (element(m1, i, 2*k) + element(m2, 2*k+1, j)) *
                               (element(m1, i, 2*k+1) + element(m2, 2*k, j)));
            }
            /* The pairwise scheme skips the last index of an odd-length
               shared dimension; add that term directly. */
            if (hasOddColumn)
            {
                add_to_element(res, i, j, element(m1, i, last) * element(m2, last, j));
            }
        }
    }

    free(row);
    free(col);
}