-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathutil.py
27 lines (22 loc) · 975 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import math
import tensorflow as tf
def min_batch(batch_size, n):
    """Shuffle the indices ``0 .. n-1`` and split them into mini-batches.

    The indices are permuted with ``np.random.permutation`` (so the result
    depends on NumPy's global random state) and then cut into consecutive
    chunks of ``batch_size``. The final chunk holds the remainder and may be
    shorter than ``batch_size``.

    Args:
        batch_size: Number of indices per batch; must be a positive integer.
        n: Number of elements to index (indices ``0 .. n-1``).

    Returns:
        A 1-D ``np.ndarray`` with ``dtype=object`` and ``ceil(n / batch_size)``
        entries, each an integer ``np.ndarray`` of indices. Concatenating the
        entries in order reproduces the underlying permutation.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    ix = np.random.permutation(n)
    # Exact integer ceil(n / batch_size), no float rounding for large n.
    n_batches = -(-n // batch_size)
    # Pre-allocate an object array and fill it element-wise: building it via
    # np.array(list_of_arrays) would collapse equal-length batches into a
    # 2-D integer array instead of an array of per-batch arrays.
    k = np.empty(n_batches, dtype=object)
    for y in range(n_batches):
        start = y * batch_size
        # One slice per batch instead of appending index-by-index (each
        # np.append re-copies the whole batch). astype() returns a copy, so
        # each batch owns its data rather than being a view into ``ix``.
        k[y] = ix[start:start + batch_size].astype(int)
    return k
def weight_variable(shape, std=0.1):
    """Create a ``tf.Variable`` initialised from a truncated normal distribution.

    Args:
        shape: Shape of the variable, passed straight to TensorFlow's
            truncated-normal sampler.
        std: Standard deviation of the truncated normal (the mean is left at
            TensorFlow's default).

    Returns:
        A ``tf.Variable`` holding the sampled initial values.
    """
    # tf.truncated_normal was removed from the top-level namespace in
    # TensorFlow 2.x; tf.random.truncated_normal is its replacement. Fall back
    # to the old name on TF 1.x releases that predate the tf.random module.
    truncated_normal = getattr(getattr(tf, "random", None), "truncated_normal", None)
    if truncated_normal is None:
        truncated_normal = tf.truncated_normal
    initial = truncated_normal(shape, stddev=std)
    return tf.Variable(initial)
def header(newLine=True):
    """Print the two-row column header of the training-progress table.

    The second row is written without a trailing newline and flushed
    immediately, so a caller can continue the same line. When ``newLine`` is
    true, the line is terminated afterwards.

    Args:
        newLine: Whether to end the second header row with a newline.
    """
    metric_row = '\t mse rmse std '
    column_row = '\t training validation training validation training validation reference runtime '
    print(metric_row)
    print(column_row, end="", flush=True)
    if not newLine:
        return
    print()