SSAGES  0.9.3
Software Suite for Advanced General Ensemble Simulations
SSAGES::ANN Class Reference

Artificial Neural Network Method.

#include <ANN.h>

Inheritance diagram for SSAGES::ANN:
[Inheritance graph: SSAGES::ANN inherits members from SSAGES::Method and SSAGES::EventListener.]

Public Member Functions

 ANN (const MPI_Comm &world, const MPI_Comm &comm, const Eigen::VectorXi &topol, Grid< Eigen::VectorXd > *fgrid, Grid< unsigned int > *hgrid, Grid< double > *ugrid, const std::vector< double > &lowerb, const std::vector< double > &upperb, const std::vector< double > &lowerk, const std::vector< double > &upperk, double temperature, double weight, unsigned int nsweep)
 Constructor.

void PreSimulation (Snapshot *snapshot, const class CVManager &cvmanager) override
 Method call prior to simulation initiation.

void PostIntegration (Snapshot *snapshot, const class CVManager &cvmanager) override
 Method call post integration.

void PostSimulation (Snapshot *snapshot, const class CVManager &cvmanager) override
 Method call post simulation.

void SetPrevWeight (double h)
 Set previous history weight.

void SetOutput (const std::string &outfile)
 Set name of output file.

void SetOutputOverwrite (bool overwrite)
 Set overwrite flag on output file.

void SetConvergeIters (unsigned int citers)
 Set number of iterations after which we turn on full weight.

void SetMaxIters (unsigned int iters)
 Set maximum number of training iterations per sweep.

void SetMinLoss (double loss)
 Set minimum loss function value (should be zero for production).
 
void ReadBias (const std::string &, const std::string &)
 Load network state and bias from file.
 
- Public Member Functions inherited from SSAGES::Method
 Method (unsigned int frequency, const MPI_Comm &world, const MPI_Comm &comm)
 Constructor.

void SetCVMask (const std::vector< unsigned int > &mask)
 Sets the collective variable mask.

virtual ~Method ()
 Destructor.

- Public Member Functions inherited from SSAGES::EventListener
 EventListener (unsigned int frequency)
 Constructor.

unsigned int GetFrequency () const
 Get frequency of event listener.
 
virtual ~EventListener ()
 Destructor.
 

Static Public Member Functions

static ANN * Build (const Json::Value &json, const MPI_Comm &world, const MPI_Comm &comm, const std::string &path)
 Build a derived method from JSON node.

- Static Public Member Functions inherited from SSAGES::Method
static Method * BuildMethod (const Json::Value &json, const MPI_Comm &world, const MPI_Comm &comm, const std::string &path)
 Build a derived method from JSON node.

- Static Public Member Functions inherited from SSAGES::EventListener
static unsigned int GetWalkerID (const MPI_Comm &world, const MPI_Comm &comm)
 Get walker ID number of specified communicator.

static unsigned int GetNumWalkers (const MPI_Comm &world, const MPI_Comm &comm)
 Get total number of walkers in the simulation.

static bool IsMasterRank (const MPI_Comm &comm)
 Check if current processor is master.
 

Private Member Functions

void TrainNetwork ()
 Trains the neural network.
 
void WriteBias ()
 Writes out the bias to file.
 

Private Attributes

Eigen::VectorXi topol_
 Neural network topology.
 
unsigned int citers_
 Number of iterations after which we turn on full weight.
 
nnet::neural_net net_
 Neural network.
 
Grid< Eigen::VectorXd > * fgrid_
 Force grid.
 
Grid< unsigned int > * hgrid_
 Histogram grid.
 
Grid< double > * ugrid_
 Unbiased histogram grid.
 
std::string outfile_
 Output filename.
 
bool preloaded_
 Is the network preloaded?
 
bool overwrite_
 Overwrite outputs?
 
unsigned int sweep_
 
unsigned int nsweep_
 
double pweight_
 
double weight_
 
double temp_
 
double kbt_
 
Eigen::MatrixXd hist_
 
Eigen::MatrixXd bias_
 
Eigen::MatrixXd rbias_
 
std::vector< double > lowerb_
 
std::vector< double > upperb_
 
std::vector< double > lowerk_
 
std::vector< double > upperk_
 

Additional Inherited Members

- Protected Attributes inherited from SSAGES::Method
mxx::comm world_
 Global MPI communicator.
 
mxx::comm comm_
 Local MPI communicator.
 
std::vector< unsigned int > cvmask_
 Mask which identifies which CVs to act on.
 

Detailed Description

Artificial Neural Network Method.

Implementation of the Artificial Neural Network Method based on [3].

Definition at line 35 of file ANN.h.

Constructor & Destructor Documentation

◆ ANN()

SSAGES::ANN::ANN ( const MPI_Comm &  world,
const MPI_Comm &  comm,
const Eigen::VectorXi &  topol,
Grid< Eigen::VectorXd > *  fgrid,
Grid< unsigned int > *  hgrid,
Grid< double > *  ugrid,
const std::vector< double > &  lowerb,
const std::vector< double > &  upperb,
const std::vector< double > &  lowerk,
const std::vector< double > &  upperk,
double  temperature,
double  weight,
unsigned int  nsweep 
)

Constructor.

Parameters
world        MPI global communicator.
comm         MPI local communicator.
topol        Topology of network.
fgrid        Grid containing biasing forces.
hgrid        Grid containing histogram.
ugrid        Grid containing unbiased histogram.
lowerb       Lower bounds for CVs.
upperb       Upper bounds for CVs.
lowerk       Lower bound restraints for CVs.
upperk       Upper bound restraints for CVs.
temperature  Temperature of the simulation.
weight       Relative weight of the statistics in sweep.
nsweep       Number of iterations in the sweep.

Constructs an instance of Artificial Neural Network method.

Definition at line 34 of file ANN.cpp.

    :
    Method(1, world, comm), topol_(topol), sweep_(0), nsweep_(nsweep),
    citers_(0), net_(topol), pweight_(1.), weight_(weight), temp_(temperature),
    kbt_(0), fgrid_(fgrid), hgrid_(hgrid), ugrid_(ugrid), hist_(), bias_(),
    lowerb_(lowerb), upperb_(upperb), lowerk_(lowerk), upperk_(upperk),
    outfile_("ann.out"), preloaded_(false), overwrite_(true)
    {
        // Create histogram grid matrix.
        hist_.resize(hgrid_->size(), hgrid_->GetDimension());

        // Fill it up.
        size_t i = 0;
        for(auto it = hgrid_->begin(); it != hgrid_->end(); ++it)
        {
            auto coord = it.coordinates();
            for(size_t j = 0; j < coord.size(); ++j)
                hist_(i, j) = coord[j];
            ++i;
        }

        // Initialize restraint reweight vector.
        rbias_.resize(hgrid_->size(), hgrid_->GetDimension());
        rbias_.fill(0);

        // Initialize FES vector.
        bias_.resize(hgrid_->size(), 1);
        net_.forward_pass(hist_);
        bias_.array() = net_.get_activation().col(0).array();
    }
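
The fill loop in the constructor stores one grid point per row of hist_ and one CV dimension per column. The following is a minimal sketch of that layout, not the SSAGES implementation: the Matrix alias and the grid_centers helper are hypothetical stand-ins for Eigen::MatrixXd and SSAGES::Grid, and both the use of bin centers as coordinates and the first-dimension-fastest ordering are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One row per grid point, one column per CV dimension (stand-in for hist_).
using Matrix = std::vector<std::vector<double>>;

// Hypothetical helper: returns the bin-center coordinates of every point of a
// regular grid. The first dimension varies fastest; this ordering is an
// assumption of the sketch, not something taken from SSAGES.
Matrix grid_centers(const std::vector<double>& lower,
                    const std::vector<double>& upper,
                    const std::vector<std::size_t>& nbins)
{
    assert(lower.size() == upper.size() && lower.size() == nbins.size());
    const std::size_t ndim = lower.size();

    // Total number of grid points is the product of the bin counts.
    std::size_t npoints = 1;
    for(std::size_t n : nbins)
        npoints *= n;

    Matrix hist(npoints, std::vector<double>(ndim));
    for(std::size_t i = 0; i < npoints; ++i)
    {
        // Decompose the flat index i into one bin index per dimension.
        std::size_t rem = i;
        for(std::size_t j = 0; j < ndim; ++j)
        {
            const std::size_t idx = rem % nbins[j];
            rem /= nbins[j];
            const double width = (upper[j] - lower[j]) / nbins[j];
            hist[i][j] = lower[j] + (idx + 0.5) * width;
        }
    }
    return hist;
}
```

With this layout, a single forward pass over the whole matrix evaluates the network at every grid point at once, which is how the constructor initializes bias_.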

References SSAGES::Grid< T >::begin(), SSAGES::Grid< T >::end(), SSAGES::GridBase< T >::GetDimension(), hgrid_, hist_, net_, and SSAGES::GridBase< T >::size().

Referenced by Build().


Member Function Documentation

◆ Build()

ANN * SSAGES::ANN::Build ( const Json::Value &  json,
const MPI_Comm &  world,
const MPI_Comm &  comm,
const std::string &  path 
)
static

Build a derived method from JSON node.

Parameters
json   JSON Value containing all input information.
world  MPI global communicator.
comm   MPI local communicator.
path   Path for JSON path specification.
Returns
Pointer to the Method built. nullptr if an unknown error occurred.

This function builds an ANN method from a JSON node. It validates the node against the ANN schema (throwing a BuildException on validation errors), builds the force, histogram, and unbiased-histogram grids, assembles the network topology, constructs the method, and then applies the optional settings.

Note
Object lifetime is the caller's responsibility.

Definition at line 287 of file ANN.cpp.

    {
        ObjectRequirement validator;
        Value schema;
        CharReaderBuilder rbuilder;
        CharReader* reader = rbuilder.newCharReader();

        reader->parse(JsonSchema::ANNMethod.c_str(),
                      JsonSchema::ANNMethod.c_str() + JsonSchema::ANNMethod.size(),
                      &schema, nullptr);
        validator.Parse(schema, path);

        // Validate inputs.
        validator.Validate(json, path);
        if(validator.HasErrors())
            throw BuildException(validator.GetErrors());

        // Grid.
        auto* fgrid = Grid<VectorXd>::BuildGrid(json.get("grid", Json::Value()));
        auto* hgrid = Grid<unsigned int>::BuildGrid(json.get("grid", Json::Value()));
        auto* ugrid = Grid<double>::BuildGrid(json.get("grid", Json::Value()));

        // Topology.
        auto nlayers = json["topology"].size() + 2;
        VectorXi topol(nlayers);
        topol[0] = fgrid->GetDimension();
        topol[nlayers-1] = 1;
        for(int i = 0; i < static_cast<int>(json["topology"].size()); ++i)
            topol[i+1] = json["topology"][i].asInt();

        auto weight = json.get("weight", 1.).asDouble();
        auto temp = json["temperature"].asDouble();
        auto nsweep = json["nsweep"].asUInt();

        // Assume all vectors are the same size.
        std::vector<double> lowerb, upperb, lowerk, upperk;
        for(int i = 0; i < static_cast<int>(json["lower_bound_restraints"].size()); ++i)
        {
            lowerk.push_back(json["lower_bound_restraints"][i].asDouble());
            upperk.push_back(json["upper_bound_restraints"][i].asDouble());
            lowerb.push_back(json["lower_bounds"][i].asDouble());
            upperb.push_back(json["upper_bounds"][i].asDouble());
        }

        auto* m = new ANN(world, comm, topol, fgrid, hgrid, ugrid, lowerb, upperb, lowerk, upperk, temp, weight, nsweep);

        // Set optional params.
        m->SetPrevWeight(json.get("prev_weight", 1).asDouble());
        m->SetOutput(json.get("output_file", "ann.out").asString());
        m->SetOutputOverwrite( json.get("overwrite_output", true).asBool());
        m->SetConvergeIters(json.get("converge_iters", 0).asUInt());
        m->SetMaxIters(json.get("max_iters", 1000).asUInt());
        m->SetMinLoss(json.get("min_loss", 0).asDouble());

        if(json.isMember("net_state") && json.isMember("load_bias"))
            m->ReadBias(json["net_state"].asString(), json["load_bias"].asString());

        return m;
    }
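
The topology step in Build() pads the user-supplied "topology" array with an input layer sized to the grid dimension and a single output neuron. A minimal sketch of that assembly follows; make_topology is a hypothetical helper, and std::vector<int> stands in for Eigen::VectorXi.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assemble layer sizes the way Build() does: the input layer has one neuron
// per CV dimension, the hidden layers come from the "topology" array, and the
// output layer has a single neuron.
std::vector<int> make_topology(int ndim, const std::vector<int>& hidden)
{
    assert(ndim > 0);
    std::vector<int> topol(hidden.size() + 2);
    topol.front() = ndim;
    topol.back() = 1;
    for(std::size_t i = 0; i < hidden.size(); ++i)
        topol[i + 1] = hidden[i];
    return topol;
}
```

For example, a two-dimensional grid with hidden layers of 10 and 6 neurons gives the layer sizes 2, 10, 6, 1.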

References ANN(), SSAGES::Grid< T >::BuildGrid(), Json::Requirement::GetErrors(), Json::Requirement::HasErrors(), Json::ObjectRequirement::Parse(), and Json::ObjectRequirement::Validate().


◆ PostIntegration()

void SSAGES::ANN::PostIntegration ( Snapshot *  snapshot,
const class CVManager &  cvmanager 
)
override virtual

Method call post integration.

Parameters
snapshot   Pointer to the simulation snapshot.
cvmanager  Collective variable manager.

This function will be called after each integration step.

Implements SSAGES::Method.

Definition at line 113 of file ANN.cpp.

    {
        if(snapshot->GetIteration() && snapshot->GetIteration() % nsweep_ == 0)
        {
            // Switch to full blast.
            if(citers_ && snapshot->GetIteration() > citers_)
                pweight_ = 1.0;

            TrainNetwork();
            if(IsMasterRank(world_))
                WriteBias();
        }

        // Get CV vals.
        auto cvs = cvmanager.GetCVs(cvmask_);
        auto n = cvs.size();

        // Determine if we are in bounds.
        RowVectorXd vec(n);
        std::vector<double> val(n);
        bool inbounds = true;
        for(size_t i = 0; i < n; ++i)
        {
            val[i] = cvs[i]->GetValue();
            vec[i] = cvs[i]->GetValue();
            if(val[i] < hgrid_->GetLower(i) || val[i] > hgrid_->GetUpper(i))
                inbounds = false;
        }

        // If in bounds, bias.
        VectorXd derivatives = VectorXd::Zero(n);
        if(inbounds)
        {
            // Record histogram hit and get gradient.
            // Only record hits on master processes since we will
            // reduce later.
            if(IsMasterRank(comm_))
                hgrid_->at(val) += 1;
            //derivatives = (*fgrid_)[val];
            net_.forward_pass(vec);
            derivatives = net_.get_gradient(0);
        }
        else
        {
            if(IsMasterRank(comm_))
            {
                std::cerr << "ANN (" << snapshot->GetIteration() << "): out of bounds ( ";
                for(auto& v : val)
                    std::cerr << v << " ";
                std::cerr << ")" << std::endl;
            }
        }

        // Restraints.
        for(size_t i = 0; i < n; ++i)
        {
            auto cval = cvs[i]->GetValue();
            if(cval < lowerb_[i])
                derivatives[i] += lowerk_[i]*cvs[i]->GetDifference(lowerb_[i]);
            else if(cval > upperb_[i])
                derivatives[i] += upperk_[i]*cvs[i]->GetDifference(upperb_[i]);
        }

        // Apply bias to atoms.
        auto& forces = snapshot->GetForces();
        auto& virial = snapshot->GetVirial();

        for(size_t i = 0; i < cvs.size(); ++i)
        {
            auto& grad = cvs[i]->GetGradient();
            auto& boxgrad = cvs[i]->GetBoxGradient();

            // Update the forces in snapshot by adding in the force bias from each
            // CV to each atom based on the gradient of the CV.
            for (size_t j = 0; j < forces.size(); ++j)
                forces[j] -= derivatives[i]*grad[j];

            virial += derivatives[i]*boxgrad;
        }
    }
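
The last two stages of PostIntegration() add a one-sided restraint term to the bias derivative and then push that derivative onto the atoms through the CV gradient. The sketch below illustrates both stages for a single CV under stated assumptions: restraint_derivative and apply_bias are hypothetical helpers, Vec3 stands in for the SSAGES vector type, and GetDifference(bound) is assumed to reduce to x - bound, which would hold for a non-periodic CV.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using Vec3 = std::array<double, 3>;

// Derivative contribution of the restraint walls for one CV value x.
// Assumption: the CV is non-periodic, so the difference is simply x - bound.
double restraint_derivative(double x, double lower, double upper,
                            double klower, double kupper)
{
    if(x < lower)
        return klower * (x - lower);
    if(x > upper)
        return kupper * (x - upper);
    return 0.0; // Inside the bounds: no restraint contribution.
}

// Chain rule: each atom j receives F_j -= (dU/dxi) * (dxi/dr_j).
void apply_bias(std::vector<Vec3>& forces, const std::vector<Vec3>& grad,
                double dUdxi)
{
    assert(forces.size() == grad.size());
    for(std::size_t j = 0; j < forces.size(); ++j)
        for(std::size_t d = 0; d < 3; ++d)
            forces[j][d] -= dUdxi * grad[j][d];
}
```

Because the restraint derivative is negative below the lower bound and positive above the upper bound, subtracting it along the CV gradient pushes the system back toward the allowed interval.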

References SSAGES::GridBase< T >::at(), citers_, SSAGES::Method::comm_, SSAGES::Method::cvmask_, SSAGES::CVManager::GetCVs(), SSAGES::Snapshot::GetForces(), SSAGES::Snapshot::GetIteration(), SSAGES::GridBase< T >::GetLower(), SSAGES::GridBase< T >::GetUpper(), SSAGES::Snapshot::GetVirial(), hgrid_, SSAGES::EventListener::IsMasterRank(), lowerb_, lowerk_, net_, pweight_, TrainNetwork(), SSAGES::Method::world_, and WriteBias().
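
The first branch of PostIntegration() decides when the network is retrained and when the previous-history weight is reset to 1. A small sketch of those two conditions follows; is_sweep_step and history_weight are hypothetical names, and the sketch isolates only the checks. In the source, the weight is updated only inside the sweep branch.

```cpp
#include <cassert>

// True when training should run: every nsweep-th iteration, never at iteration 0.
bool is_sweep_step(unsigned int iteration, unsigned int nsweep)
{
    assert(nsweep > 0);
    return iteration != 0 && iteration % nsweep == 0;
}

// Previous-history weight after the convergence check: once the iteration
// count exceeds a nonzero citers, the weight becomes 1; otherwise it is kept.
double history_weight(double pweight, unsigned int iteration, unsigned int citers)
{
    return (citers != 0 && iteration > citers) ? 1.0 : pweight;
}
```

A citers value of 0, the default passed by Build(), leaves the configured weight unchanged for the whole run.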


◆ PostSimulation()

void SSAGES::ANN::PostSimulation ( Snapshot *  snapshot,
const class CVManager &  cvmanager 
)
override virtual

Method call post simulation.

Parameters
snapshot   Pointer to the simulation snapshot.
cvmanager  Collective variable manager.

This function will be called after the end of the simulation run.

Implements SSAGES::Method.

Definition at line 194 of file ANN.cpp.

    {
    }

◆ PreSimulation()

void SSAGES::ANN::PreSimulation ( Snapshot *  snapshot,
const class CVManager &  cvmanager 
)
override virtual

Method call prior to simulation initiation.

Parameters
snapshot   Pointer to the simulation snapshot.
cvmanager  Collective variable manager.

This function will be called before the simulation is started.

Implements SSAGES::Method.

Definition at line 76 of file ANN.cpp.

    {
        auto ndim = hgrid_->GetDimension();
        kbt_ = snapshot->GetKb()*temp_;

        // Zero out forces and histogram.
        VectorXd vec = VectorXd::Zero(ndim);
        std::fill(hgrid_->begin(), hgrid_->end(), 0);

        if(preloaded_)
        {
            net_.forward_pass(hist_);
            bias_.array() = net_.get_activation().col(0).array();
            TrainNetwork();
        }
        else
            std::fill(ugrid_->begin(), ugrid_->end(), 1.0);

        // Fill in the reweighting matrix for restraints.
        auto cvs = cvmanager.GetCVs(cvmask_);
        auto ncvs = cvs.size();

        for(int i = 0; i < hist_.rows(); ++i)
        {
            auto cval = hist_.row(i);
            for(size_t j = 0; j < ncvs; ++j)
            {
                if(cval[j] < lowerb_[j])
                    rbias_(i,j) = lowerk_[j]*(cval[j] - lowerb_[j])*(cval[j] - lowerb_[j]);
                else if(cval[j] > upperb_[j])
                    rbias_(i,j) = upperk_[j]*(cval[j] - upperb_[j])*(cval[j] - upperb_[j]);
            }
        }

        std::fill(fgrid_->begin(), fgrid_->end(), vec);
    }
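
The reweighting loop in PreSimulation() stores a quadratic penalty k (x - b)^2 for every grid coordinate that lies outside the restraint bounds and leaves the remaining entries at zero. A minimal sketch of that table follows, with nested std::vector standing in for Eigen::MatrixXd and restraint_bias as a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// For every grid point (row) and CV dimension (column), store k*(x - b)^2 if
// the coordinate lies outside [lowerb, upperb], and 0 otherwise.
Matrix restraint_bias(const Matrix& hist,
                      const std::vector<double>& lowerb,
                      const std::vector<double>& upperb,
                      const std::vector<double>& lowerk,
                      const std::vector<double>& upperk)
{
    const std::size_t ncvs = lowerb.size();
    assert(upperb.size() == ncvs && lowerk.size() == ncvs && upperk.size() == ncvs);

    Matrix rbias(hist.size(), std::vector<double>(ncvs, 0.0));
    for(std::size_t i = 0; i < hist.size(); ++i)
    {
        assert(hist[i].size() == ncvs);
        for(std::size_t j = 0; j < ncvs; ++j)
        {
            const double x = hist[i][j];
            if(x < lowerb[j])
                rbias[i][j] = lowerk[j] * (x - lowerb[j]) * (x - lowerb[j]);
            else if(x > upperb[j])
                rbias[i][j] = upperk[j] * (x - upperb[j]) * (x - upperb[j]);
        }
    }
    return rbias;
}
```

Since the table depends only on the grid coordinates and the restraint settings, it can be computed once before the simulation starts, as PreSimulation() does.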

References SSAGES::Grid< T >::begin(), SSAGES::Method::cvmask_, SSAGES::Grid< T >::end(), fgrid_, SSAGES::CVManager::GetCVs(), SSAGES::GridBase< T >::GetDimension(), SSAGES::Snapshot::GetKb(), hgrid_, hist_, lowerb_, lowerk_, net_, preloaded_, temp_, TrainNetwork(), and ugrid_.


◆ SetConvergeIters()

void SSAGES::ANN::SetConvergeIters ( unsigned int  citers)
inline

Set number of iterations after which we turn on full weight.

Parameters
citers  Number of iterations before full weight.

Definition at line 175 of file ANN.h.

    {
        citers_ = citers;
    }

References citers_.

◆ SetMaxIters()

void SSAGES::ANN::SetMaxIters ( unsigned int  iters)
inline

Set maximum number of training iterations per sweep.

Parameters
iters  Maximum iterations per sweep.

Definition at line 184 of file ANN.h.

    {
        auto params = net_.get_train_params();
        params.max_iter = iters;
        net_.set_train_params(params);
    }

References net_.
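
SetMaxIters() and SetMinLoss() share a get, modify, set pattern: they copy the network's training parameters, change one field, and write the whole struct back. The sketch below reproduces the pattern with hypothetical stand-ins (TrainParams and Net are not the nnet types, and their default values are placeholders); the non-negativity check on the loss is an addition of this sketch and is not present in the source.

```cpp
#include <cassert>

// Hypothetical stand-in for the network's training-parameter struct.
struct TrainParams
{
    unsigned int max_iter = 1000;
    double min_loss = 0.0;
};

// Hypothetical stand-in for the network, exposing only the two accessors used.
class Net
{
public:
    TrainParams get_train_params() const { return params_; }
    void set_train_params(const TrainParams& p) { params_ = p; }

private:
    TrainParams params_;
};

void set_max_iters(Net& net, unsigned int iters)
{
    auto params = net.get_train_params(); // 'auto' makes a local copy.
    params.max_iter = iters;              // Modify only the field of interest.
    net.set_train_params(params);         // Write the full struct back.
}

void set_min_loss(Net& net, double loss)
{
    assert(loss >= 0.0); // Sketch-only sanity check, not in the source.
    auto params = net.get_train_params();
    params.min_loss = loss;
    net.set_train_params(params);
}
```

Because `auto` deduces a value type, the local params object is a copy even if the getter returned a reference, so the explicit set_train_params call is what commits the change.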

◆ SetMinLoss()

void SSAGES::ANN::SetMinLoss ( double  loss)
inline

Set minimum loss function value (should be zero for production).

Parameters
loss  Minimum loss function value.

Definition at line 195 of file ANN.h.

    {
        auto params = net_.get_train_params();
        params.min_loss = loss;
        net_.set_train_params(params);
    }

References net_.

◆ SetOutput()

void SSAGES::ANN::SetOutput ( const std::string &  outfile)
inline

Set name of output file.

Parameters
outfile  Output file.

Definition at line 157 of file ANN.h.

    {
        outfile_ = outfile;
    }

References outfile_.

◆ SetOutputOverwrite()

void SSAGES::ANN::SetOutputOverwrite ( bool  overwrite)
inline

Set overwrite flag on output file.

Parameters
overwrite  Whether the output file should be overwritten.

Definition at line 166 of file ANN.h.

    {
        overwrite_ = overwrite;
    }

References overwrite_.

◆ SetPrevWeight()

void SSAGES::ANN::SetPrevWeight ( double  h)
inline

Set previous history weight.

Parameters
h  History weight.

Definition at line 148 of file ANN.h.

    {
        pweight_ = h;
    }

References pweight_.

Member Data Documentation

◆ hist_

Eigen::MatrixXd SSAGES::ANN::hist_
private

Eigen matrices of grids.

Definition at line 73 of file ANN.h.

Referenced by ANN(), PreSimulation(), ReadBias(), TrainNetwork(), and WriteBias().

◆ lowerb_

std::vector<double> SSAGES::ANN::lowerb_
private

Bounds

Definition at line 78 of file ANN.h.

Referenced by PostIntegration(), and PreSimulation().

◆ lowerk_

std::vector<double> SSAGES::ANN::lowerk_
private

Bound restraints.

Definition at line 83 of file ANN.h.

Referenced by PostIntegration(), and PreSimulation().

◆ pweight_

double SSAGES::ANN::pweight_
private

Previous and current histogram weight.

Definition at line 54 of file ANN.h.

Referenced by PostIntegration(), SetPrevWeight(), and TrainNetwork().

◆ sweep_

unsigned int SSAGES::ANN::sweep_
private

Number of iterations per sweep.

Definition at line 43 of file ANN.h.

Referenced by TrainNetwork(), and WriteBias().

◆ temp_

double SSAGES::ANN::temp_
private

System temperature and energy units.

Definition at line 59 of file ANN.h.

Referenced by PreSimulation().


The documentation for this class was generated from the following files: