-
Notifications
You must be signed in to change notification settings - Fork 1
/
trmm.py
69 lines (55 loc) · 1.96 KB
/
trmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import json
import pytest
import allo
import numpy as np
from allo.ir.types import int32, float32
import allo.ir.types as T
def trmm_np(A, B, alpha):
    """Reference triangular matrix multiply, computed in place on ``B``.

    For every element ``B[i, j]`` the products ``A[k, i] * B[k, j]`` are
    accumulated for all ``k`` strictly greater than ``i``; the element is
    then multiplied by ``alpha``. Rows are visited in ascending order, so
    the rows read by the accumulation (``k > i``) have not been updated yet.

    Args:
        A: 2-D array indexed as ``A[k, i]``; its first dimension sets the
            number of rows processed.
        B: 2-D array updated in place; its second dimension sets the number
            of columns processed.
        alpha: Scalar applied to each element after accumulation.

    Returns:
        None. The result is left in ``B``.
    """
    order = A.shape[0]
    for row in range(order):
        for col in range(B.shape[1]):
            # Only entries below the diagonal of column `row` contribute.
            for below in range(row + 1, order):
                B[row, col] += A[below, row] * B[below, col]
            B[row, col] *= alpha
def top_trmm(concrete_type, m, n, alpha=1.5):
    """Build the composed Allo schedule for the triangular matrix multiply.

    The computation is split into two stages defined below: ``S0``
    accumulates ``A[k1, i1] * B[k1, j1]`` into ``B[i1, j1]`` for every
    ``k1 > i1``, and ``S1`` multiplies every element of ``B`` by ``alpha``.
    ``kernel_trmm`` calls ``S0`` and then ``S1``. Each stage is customized
    on its own and the two stage schedules are composed into the schedule
    of ``kernel_trmm``.

    Args:
        concrete_type: Element type used to instantiate ``T``; the kernels
            list ``float32`` and ``int32`` as the allowed choices.
        m: Value used to instantiate ``M`` (``A`` is declared ``T[M, M]``
            and ``B`` is declared ``T[M, N]``).
        n: Value used to instantiate ``N``.
        alpha: Scale factor referenced by ``S1`` from the enclosing scope.

    Returns:
        The schedule created by ``allo.customize(kernel_trmm, ...)`` after
        ``s0`` and ``s1`` have been composed into it.
    """

    # Stage 0: accumulate the contributions of rows k1 > i1 into B[i1, j1].
    def S0[T: (float32, int32), M, N](A: "T[M, M]", B: "T[M, N]"):
        for i1, j1 in allo.grid(M, N, name="update"):
            for k1 in allo.reduction(M):
                # The reduction runs over the full range M; the guard keeps
                # only the entries with k1 strictly greater than i1.
                if k1 > i1:
                    B[i1, j1] += A[k1, i1] * B[k1, j1]

    # Stage 1: scale every element of B by alpha.
    def S1[T: (float32, int32), M, N](B: "T[M, N]"):
        for i0, j0 in allo.grid(M, N, name="mul"):
            B[i0, j0] = B[i0, j0] * alpha

    # Top-level kernel: run the update stage, then the scaling stage.
    def kernel_trmm[T: (float32, int32), M, N](A: "T[M, M]", B: "T[M, N]"):
        S0[T, M, N](A, B)
        S1[T, M, N](B)

    # The same factor is used for the partition of B and for both unrolls.
    # NOTE(review): presumably n should be a multiple of this factor — confirm.
    factor = 20

    # Schedule for the update stage.
    s0 = allo.customize(S0, instantiate=[concrete_type, m, n])
    s0.partition(s0.B, dim=2, partition_type=2, factor=factor)  # cyclic
    s0.reorder("k1", "j1")
    s0.buffer_at(s0.B, "i1")
    s0.pipeline("j1")
    s0.unroll("j1", factor=factor)

    # Schedule for the scaling stage.
    s1 = allo.customize(S1, instantiate=[concrete_type, m, n])
    s1.pipeline("j0")
    s1.unroll("j0", factor=factor)

    # Compose both stage schedules into the top-level kernel's schedule.
    s = allo.customize(kernel_trmm, instantiate=[concrete_type, m, n])
    s.compose(s0)
    s.compose(s1)
    return s
def test_trmm():
    """Build the trmm schedule for the "medium" problem size with Vitis HLS.

    The problem dimensions are read from ``../psize.json`` (resolved
    relative to this file), the schedule is created with ``float32``
    elements and ``alpha = 1.5``, and ``build`` is invoked with the
    ``vitis_hls`` target in ``hw`` mode using the ``trmm.prj`` project.
    """
    # Load the problem size table that sits one directory above this file.
    this_dir = os.path.dirname(__file__)
    with open(os.path.join(this_dir, "../psize.json"), "r") as fp:
        psize = json.load(fp)

    test_psize = "medium"
    M = psize["trmm"][test_psize]["M"]
    N = psize["trmm"][test_psize]["N"]

    schedule = top_trmm(float32, M, N, 1.5)
    schedule.build(target="vitis_hls", mode="hw", project="trmm.prj")
# Allow running this file directly as a script, in addition to via pytest.
if __name__ == "__main__":
    test_trmm()