-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodelLoss.m
31 lines (23 loc) · 1.04 KB
/
modelLoss.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
    modelLoss(netG,netD,X,T,Z,flipFactor)
% modelLoss  Losses, gradients, and scores for one conditional-GAN step.
%
%   [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
%       modelLoss(netG,netD,X,T,Z,flipFactor)
%
%   Runs the discriminator netD on the real batch X and on data produced by
%   the generator netG from Z, with both networks also given T. It returns
%   the generator/discriminator losses (from ganLoss) and their gradients
%   with respect to each network's Learnables.
%
%   Inputs:
%     netG, netD  - generator and discriminator networks (evaluated with forward).
%     X           - real data batch fed to netD.
%     T           - extra input passed to both networks (presumably the
%                   conditioning labels; confirm with caller).
%     Z           - generator input (presumably latent noise; confirm).
%     flipFactor  - fraction of real observations whose probability p is
%                   replaced by 1-p before computing the loss;
%                   floor(flipFactor*N) observations are flipped, N being
%                   the size of the batch dimension.
%
%   Outputs:
%     lossG, lossD           - losses returned by ganLoss.
%     gradientsG, gradientsD - dlgradient of lossG / lossD w.r.t.
%                              netG.Learnables / netD.Learnables.
%     stateG                 - generator state returned by forward.
%     scoreG                 - mean(sigmoid(netD output on generated data)).
%     scoreD                 - (mean(probReal) + mean(1-probGenerated))/2,
%                              computed before any label flipping.
%
%   Because dlgradient is called here, this function must be evaluated
%   through dlfeval.

% Discriminator predictions for the real data.
YReal = forward(netD,X,T);

% Generate data, then get the discriminator's predictions for it.
[XGenerated,stateG] = forward(netG,Z,T);
YGenerated = forward(netD,XGenerated,T);

% Convert discriminator outputs to probabilities.
probGenerated = sigmoid(YGenerated);
probReal = sigmoid(YReal);

% Scores are taken before label flipping so they reflect what the
% discriminator actually predicted.
scoreG = mean(probGenerated);
scoreD = (mean(probReal) + mean(1-probGenerated)) / 2;

% Flip labels for a random subset of real observations.
% Locate the batch dimension from the dlarray format rather than assuming
% dimension 4 ('SSCB'); otherwise e.g. a 'CB' output has size 1 along
% dimension 4 and no labels would ever be flipped. Unformatted outputs
% keep the original assumption of dimension 4.
batchDim = finddim(YReal,"B");
if isempty(batchDim)
    batchDim = 4;
end
numObservations = size(YReal,batchDim);
idx = randperm(numObservations,floor(flipFactor * numObservations));

% Build a subscript list of ':' with idx in the batch position, so the
% assignment is equivalent to probReal(:,:,:,idx) for 'SSCB' data.
subs = repmat({':'},1,max(batchDim,ndims(probReal)));
subs{batchDim} = idx;
probReal(subs{:}) = 1 - probReal(subs{:});

% Calculate the GAN loss.
[lossG, lossD] = ganLoss(probReal,probGenerated);

% Gradients of each loss w.r.t. its own network. RetainData keeps the
% trace alive after the first call so the second dlgradient can reuse it.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);
end