feat(train): make cost calculation/return optional #75

Merged 2 commits on Aug 30, 2023
app/train-cloud-microphysics.f90 (2 additions, 3 deletions)

@@ -70,7 +70,6 @@ program train_cloud_microphysics
   double precision, allocatable, dimension(:) :: time_in, time_out
   double precision, parameter :: tolerance = 1.E-07
   integer, allocatable :: lbounds(:)
-  integer, parameter :: initial = 1
   integer t_end, t

   print *,"Starting to read network inputs from " // network_input
@@ -79,7 +78,6 @@ program train_cloud_microphysics
   ! Skipping the following unnecessary inputs that are in the current file format as of 14 Aug 2023:
   ! precipitation, snowfall
   call network_input_file%input("pressure", pressure_in)
-
   call network_input_file%input("potential_temperature", potential_temperature_in)
   call network_input_file%input("temperature", temperature_in)
   call network_input_file%input("qv", qv_in)
@@ -157,6 +155,7 @@ program train_cloud_microphysics
     integer, parameter :: mini_batch_size=1
     integer batch, lon, lat, level, time
     type(file_t) json_file
+    real(rkind), allocatable :: cost(:)

     trainable_engine = random_hidden_layers(num_inputs=8, num_outputs=6)

@@ -189,7 +188,7 @@ program train_cloud_microphysics
       mini_batches = [(mini_batch_t(input_output_pair_t(inputs(:,batch), outputs(:,batch))), batch = 1, num_mini_batches)]

       print *,"Training network"
-      call trainable_engine%train(mini_batches)
+      call trainable_engine%train(mini_batches, cost)

     end do
   end associate
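Because the new argument is optional, existing call sites that omit cost continue to compile and run unchanged; only callers that pass it pay for the cost calculation. Each element of the returned array is the quadratic cost sum((y - a)**2)/(2*mini_batch_size), accumulated over the input-output pairs of one mini-batch. A minimal caller-side sketch (variable names follow the driver above; the print line is illustrative, not part of the PR):

    real(rkind), allocatable :: cost(:)

    call trainable_engine%train(mini_batches)        ! old style: cost is never computed
    call trainable_engine%train(mini_batches, cost)  ! new style: cost(i) holds the cost of mini-batch i
    print *, "cost of last mini-batch: ", cost(size(cost))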
src/inference_engine/trainable_engine_m.f90 (2 additions, 1 deletion)

@@ -54,10 +54,11 @@ pure module subroutine assert_consistent(self)
       class(trainable_engine_t), intent(in) :: self
     end subroutine

-    pure module subroutine train(self, mini_batches)
+    pure module subroutine train(self, mini_batches, cost)
       implicit none
       class(trainable_engine_t), intent(inout) :: self
       type(mini_batch_t), intent(in) :: mini_batches(:)
+      real(rkind), intent(out), allocatable, optional :: cost(:)
     end subroutine

     elemental module function infer(self, inputs) result(outputs)
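The dummy argument combines intent(out), allocatable, and optional: when the caller omits it, nothing is allocated and every cost-related statement in the implementation sits behind a present(cost) guard; when the caller supplies it, intent(out) plus allocatable means any prior allocation is released automatically on entry before the subroutine allocates it to size(mini_batches). A self-contained sketch of the same pattern (toy names, not library code):

    module optional_result_m
      implicit none
    contains
      pure subroutine work(n, result)
        ! Compute a result array only if the caller asks for it.
        integer, intent(in) :: n
        real, intent(out), allocatable, optional :: result(:)
        integer i
        if (present(result)) then
          allocate(result(n))
          result = [(real(i)**2, i = 1, n)]
        end if
      end subroutine
    end module

    program demo
      use optional_result_m
      implicit none
      real, allocatable :: r(:)
      call work(3)     ! computation skipped entirely
      call work(3, r)  ! r is allocated and filled
      print *, r       ! prints 1. 4. 9.
    end program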
src/inference_engine/trainable_engine_s.f90 (29 additions, 21 deletions)

@@ -64,12 +64,10 @@
   end procedure

   module procedure train
-    integer i, j, k, l, batch, iter, mini_batch_size, pair
+    integer l, batch, mini_batch_size, pair
     real(rkind), parameter :: eta = 1.5e0 ! Learning parameter
-    real(rkind), allocatable :: z(:,:), a(:,:), y(:), delta(:,:), dcdw(:,:,:), dcdb(:,:)
-    real(rkind) cost
-    type(tensor_t), allocatable :: inputs(:)
-    type(tensor_t), allocatable :: expected_outputs(:)
+    real(rkind), allocatable :: z(:,:), a(:,:), delta(:,:), dcdw(:,:,:), dcdb(:,:)
+    type(tensor_t), allocatable :: inputs(:), expected_outputs(:)

     call self%assert_consistent

@@ -81,14 +79,17 @@
     allocate(delta, mold=self%b)
     allocate(dcdb, mold=self%b) ! Gradient of cost function with respect with biases

-    associate(w => self%w, b => self%b, n => self%n)
+    associate(w => self%w, b => self%b, n => self%n, num_mini_batches => size(mini_batches))

+      if (present(cost)) allocate(cost(num_mini_batches))
+
       iterate_across_batches: &
-      do iter = 1, size(mini_batches)
+      do batch = 1, num_mini_batches

-        cost = 0.; dcdw = 0.; dcdb = 0.
+        if (present(cost)) cost(batch) = 0.
+        dcdw = 0.; dcdb = 0.

-        associate(input_output_pairs => mini_batches(iter)%input_output_pairs())
+        associate(input_output_pairs => mini_batches(batch)%input_output_pairs())
           inputs = input_output_pairs%inputs()
           expected_outputs = input_output_pairs%expected_outputs()
           mini_batch_size = size(input_output_pairs)
@@ -98,19 +99,21 @@
           do pair = 1, mini_batch_size

             a(1:self%num_inputs(), input_layer) = inputs(pair)%values()
-            y = expected_outputs(pair)%values()

             feed_forward: &
             do l = 1,output_layer
               z(1:n(l),l) = matmul(w(1:n(l),1:n(l-1),l), a(1:n(l-1),l-1)) + b(1:n(l),l)
               a(1:n(l),l) = self%differentiable_activation_strategy_%activation(z(1:n(l),l))
             end do feed_forward

-            cost = cost + sum((y(1:n(output_layer))-a(1:n(output_layer),output_layer))**2)/(2.e0*mini_batch_size)
+            associate(y => expected_outputs(pair)%values())
+              if (present(cost)) &
+                cost(batch) = cost(batch) + sum((y(1:n(output_layer))-a(1:n(output_layer),output_layer))**2)/(2.e0*mini_batch_size)

-            delta(1:n(output_layer),output_layer) = &
-              (a(1:n(output_layer),output_layer) - y(1:n(output_layer))) &
-              * self%differentiable_activation_strategy_%activation_derivative(z(1:n(output_layer),output_layer))
+              delta(1:n(output_layer),output_layer) = &
+                (a(1:n(output_layer),output_layer) - y(1:n(output_layer))) &
+                * self%differentiable_activation_strategy_%activation_derivative(z(1:n(output_layer),output_layer))
+            end associate

             associate(n_hidden => self%num_layers()-2)
               back_propagate_error: &
@@ -120,13 +123,17 @@
               end do back_propagate_error
             end associate

-            sum_gradients: &
-            do l = 1,output_layer
-              dcdb(1:n(l),l) = dcdb(1:n(l),l) + delta(1:n(l),l)
-              do concurrent(j = 1:n(l))
-                dcdw(j,1:n(l-1),l) = dcdw(j,1:n(l-1),l) + a(1:n(l-1),l-1)*delta(j,l)
-              end do
-            end do sum_gradients
+            block
+              integer j
+
+              sum_gradients: &
+              do l = 1,output_layer
+                dcdb(1:n(l),l) = dcdb(1:n(l),l) + delta(1:n(l),l)
+                do concurrent(j = 1:n(l))
+                  dcdw(j,1:n(l-1),l) = dcdw(j,1:n(l-1),l) + a(1:n(l-1),l-1)*delta(j,l)
+                end do
+              end do sum_gradients
+            end block

           end do iterate_through_batch

@@ -137,6 +144,7 @@
             dcdw(1:n(l),1:n(l-1),l) = dcdw(1:n(l),1:n(l-1),l)/mini_batch_size
             w(1:n(l),1:n(l-1),l) = w(1:n(l),1:n(l-1),l) - eta*dcdw(1:n(l),1:n(l-1),l) ! Adjust weights
           end do adjust_weights_and_biases
+
       end do iterate_across_batches

     end associate
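A side effect of removing i, j, and k from the procedure-level declarations is that the do concurrent index j now needs a declaration somewhere, so the PR wraps the gradient summation in a block construct that scopes j to just those lines. A standalone toy example of the same idiom (not library code):

    program block_scope_demo
      implicit none
      real :: grad(4), contrib(4)
      grad = 0.
      contrib = [1., 2., 3., 4.]

      block
        integer j  ! visible only inside this block
        do concurrent (j = 1:size(grad))
          grad(j) = grad(j) + contrib(j)
        end do
      end block

      print *, grad
    end program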