% Training script: builds (noisy input, ground truth) datastores from
% MAT-files, trains the network returned by createParticleFieldNetwork and
% saves the result.
%
% This script does NOT define the following variables; they must already
% exist in the workspace it runs in:
%   volGridpoints    - struct with fields x, y, z; the lengths of these
%                      fields are used as the grid size of the network
%   network_savename - file name passed to save() for the trained network

% preferences for training data
pathToImages = 'training_images/tr_%04d.mat'; % sprintf-readable mat-file location
[pathOnly, ~] = fileparts(pathToImages);

% Fail fast, before building datastores or training, if the required
% workspace inputs are missing (otherwise the failure only shows up later).
if ~exist('volGridpoints', 'var')
    error('Workspace variable ''volGridpoints'' must be defined before running this script.');
end
if ~exist('network_savename', 'var')
    error('Workspace variable ''network_savename'' must be defined before running this script.');
end

% Turn the sprintf pattern into a wildcard (e.g. 'training_images/tr_*.mat')
% so only files following the naming scheme are used; other .mat files in
% the same folder would otherwise be read and break the ReadFcns below.
filePattern = regexprep(pathToImages, '%\d*d', '*');

% create Datastores for training data.
% The noisy-input datastore is built from the ground-truth datastore's file
% list, so both iterate exactly the same files in the same order; combine()
% pairs their reads one-to-one as (input, response).
voldsGT = fileDatastore(filePattern, ...
    'FileExtensions','.mat','ReadFcn',@(x) matReadGT(x));

voldsNoise = fileDatastore(voldsGT.Files, ...
    'FileExtensions','.mat','ReadFcn',@(x) matReadNoise(x));

com = combine(voldsNoise, voldsGT);

% % get Validation data
% voldsGT = fileDatastore('/home/michael/dustyNET_tmp/allStages_allcams/validation', ...
%    'FileExtensions','.mat','ReadFcn',@(x) matReadGT(x));
%
% % create Datastore for validation data
% voldsNoise = fileDatastore('/home/michael/dustyNET_tmp/allStages_allcams/validation', ...
%    'FileExtensions','.mat','ReadFcn',@(x) matReadNoise(x));
%
% com_validation = combine(voldsNoise, voldsGT);

% The combined datastore holds its own copies; the individual ones are no
% longer needed in the workspace.
clear voldsGT;
clear voldsNoise;

%% train network

% Learning rate: 1e-1, multiplied by 0.1 after every epoch (piecewise
% schedule with drop period 1). 'ValidationPatience' only has an effect
% once 'ValidationData' (commented out below) is supplied.
options = trainingOptions('sgdm', ...
    'ExecutionEnvironment', 'auto',...  % use gpu if possible
    'ValidationPatience', Inf, ...
    'MaxEpochs',3,...
    'InitialLearnRate',1e-1, ...
    'LearnRateSchedule', 'piecewise',...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',1, ...
    'Verbose',true, ...
    'Shuffle', 'never',...
    'MiniBatchSize', 1,...
    'Plots','none');
%    'CheckpointPath', './checkpoints',...
%   'ValidationData', com_validation, ...

gridSize = [length(volGridpoints.x),length(volGridpoints.y),length(volGridpoints.z) ];
lgraph = createParticleFieldNetwork(gridSize);

% NOTE(review): passing lgraph.Layers re-chains the layers sequentially and
% drops any graph connections. This is only correct if
% createParticleFieldNetwork builds a purely sequential network; if it
% returns a layerGraph with branches/skip connections, pass lgraph itself.
trainedNetwork = trainNetwork(com, lgraph.Layers, options);

% Persist only the trained network and the grid it was trained for.
save(network_savename, 'trainedNetwork', 'volGridpoints');



%% extra functions
function data = matReadGT(filename)
% matReadGT  ReadFcn for the ground-truth datastore.
%   DATA = matReadGT(FILENAME) loads only the variable 'field3d_gt' from
%   the MAT-file FILENAME, converts it to single precision and divides it
%   by 255.
contents = load(filename, 'field3d_gt');
data = single(contents.field3d_gt) / 255;
end

function data = matReadNoise(filename)
% matReadNoise  ReadFcn for the noisy-input datastore.
%   DATA = matReadNoise(FILENAME) loads only the variable 'field_initial'
%   from the MAT-file FILENAME, converts it to single precision, divides it
%   by 255, and adds element-wise noise 0.03*rand (uniform in [0, 0.03)).
%   The noise is drawn from the global random stream on every call, so
%   repeated reads of the same file return different values.
contents = load(filename, 'field_initial');
scaled = single(contents.field_initial) ./ 255;
noise = 0.03 .* rand(size(scaled));
data = scaled + noise;
end
