This repository has been archived by the owner on Apr 4, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_functional.mojo
71 lines (57 loc) · 2.33 KB
/
test_functional.mojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import math
from tensor import TensorShape
from time.time import now

from voodoo.core import Tensor, HeNormal, RandomUniform, SGD
from voodoo.utils import (
    info,
    clear,
)
fn nanoseconds_to_seconds(t: Int) -> Float64:
    """Convert a duration given in nanoseconds to seconds.

    Args:
        t: Duration in nanoseconds (the callers pass differences of `now()`).

    Returns:
        The same duration expressed in seconds as a Float64.
    """
    alias NS_PER_SECOND: Float64 = 1_000_000_000.0
    var nanoseconds = Float64(t)
    return nanoseconds / NS_PER_SECOND
fn _progress_bar(epoch: Int, num_epochs: Int, width: Int) -> String:
    """Render a fixed-width text progress bar for `epoch` out of `num_epochs`.

    Args:
        epoch: Current (1-based) epoch.
        num_epochs: Total number of epochs.
        width: Number of cells in the bar.

    Returns:
        A string of `width` cells: filled cells ("█") followed by empty ones ("░").
    """
    # Number of filled cells; loop-invariant, so computed once up front.
    var filled = ((epoch * width) / num_epochs).to_int()
    var bar = String("")
    for i in range(width):
        if i < filled:
            bar += "█"
        else:
            bar += "░"
    return bar


fn main() raises:
    """Train a 1-32-32-1 ReLU MLP with SGD so that it fits y = sin(15 * x).

    Every `every` epochs, it clears the screen and prints a progress bar, the
    average loss over those epochs, and the wall time they took. At the end it
    prints the total training time.
    """
    # Layer weights and biases, initialised via HeNormal.
    var W1 = Tensor[TensorShape(1, 32), HeNormal[1]]()
    var W2 = Tensor[TensorShape(32, 32), HeNormal[32]]()
    var W3 = Tensor[TensorShape(32, 1), HeNormal[32]]()
    var b1 = Tensor[TensorShape(32), HeNormal[32]]()
    var b2 = Tensor[TensorShape(32), HeNormal[32]]()
    var b3 = Tensor[TensorShape(1), HeNormal[1]]()

    var avg_loss: Float32 = 0.0
    var every = 1000
    var num_epochs = 200000

    # A batch of 32 scalar inputs and a same-shaped target tensor.
    # NOTE(review): RandomUniform[0, 1] presumably samples in [0, 1); confirm.
    var input = Tensor[TensorShape(32, 1), RandomUniform[0, 1]]()
    var true_vals = Tensor[TensorShape(32, 1), RandomUniform[0, 1]]()

    # The graph is built once here; the loop below re-evaluates it every
    # epoch through forward_static/backward/optimize.
    var x = (input @ W1 + b1).compute_activation["relu"]()
    x = (x @ W2 + b2).compute_activation["relu"]()
    x = x @ W3 + b3
    var loss = x.compute_loss["mse"](true_vals)

    var initial_start = now()
    var epoch_start = now()
    var bar_accuracy = 20

    for epoch in range(1, num_epochs + 1):
        # Fresh inputs each epoch (presumably refresh() resamples them); the
        # targets are recomputed in place so the graph sees the new values.
        input.refresh()
        for i in range(input.shape.num_elements()):
            true_vals[i] = math.sin(15.0 * input[i])

        var computed_loss = loss.forward_static()
        avg_loss += computed_loss[0]
        loss.backward()
        loss.optimize[SGD[0.01]]()

        if epoch % every == 0:
            clear()
            print_no_newline("\nEpoch: " + String(epoch) + " ")
            info(_progress_bar(epoch, num_epochs, bar_accuracy) + " ")
            print_no_newline(
                String(((epoch * 100) / num_epochs).to_int()) + "%\n"
            )
            print("----------------------------------------\n")
            print_no_newline("Average Loss: ")
            info(String(avg_loss / every) + "\n")
            # Wall time for the last `every` epochs, not for one epoch.
            print_no_newline("Time: ")
            info(String(nanoseconds_to_seconds(now() - epoch_start)) + "s\n")
            epoch_start = now()
            print("\n----------------------------------------\n")
            # Reset the running sum for the next reporting window.
            avg_loss = 0.0

    print_no_newline("Total Time: ")
    info(String(nanoseconds_to_seconds(now() - initial_start)) + "s\n\n")