SSAGES  0.9.3
Software Suite for Advanced General Ensemble Simulations
ANN.cpp
1 
#include "ANN.h"

#include <memory>

#include "schema.h"
#include "Snapshot.h"
#include "mxx/bcast.hpp"
#include "CVs/CVManager.h"
#include "Drivers/DriverException.h"
#include "Validator/ObjectRequirement.h"

using namespace Eigen;
using namespace nnet;
using namespace Json;
31 
32 namespace SSAGES
33 {
34  ANN::ANN(const MPI_Comm& world,
35  const MPI_Comm& comm,
36  const VectorXi& topol,
37  Grid<VectorXd>* fgrid,
38  Grid<unsigned int>* hgrid,
39  Grid<double>* ugrid,
40  const std::vector<double>& lowerb,
41  const std::vector<double>& upperb,
42  const std::vector<double>& lowerk,
43  const std::vector<double>& upperk,
44  double temperature,
45  double weight,
46  unsigned int nsweep) :
47  Method(1, world, comm), topol_(topol), sweep_(0), nsweep_(nsweep),
48  citers_(0), net_(topol), pweight_(1.), weight_(weight), temp_(temperature),
49  kbt_(0), fgrid_(fgrid), hgrid_(hgrid), ugrid_(ugrid), hist_(), bias_(),
50  lowerb_(lowerb), upperb_(upperb), lowerk_(lowerk), upperk_(upperk),
51  outfile_("ann.out"), preloaded_(false), overwrite_(true)
52  {
53  // Create histogram grid matrix.
54  hist_.resize(hgrid_->size(), hgrid_->GetDimension());
55 
56  // Fill it up.
57  size_t i = 0;
58  for(auto it = hgrid_->begin(); it != hgrid_->end(); ++it)
59  {
60  auto coord = it.coordinates();
61  for(size_t j = 0; j < coord.size(); ++j)
62  hist_(i, j) = coord[j];
63  ++i;
64  }
65 
66  // Initialize restraint reweight vector.
67  rbias_.resize(hgrid_->size(), hgrid_->GetDimension());
68  rbias_.fill(0);
69 
70  // Initialize FES vector.
71  bias_.resize(hgrid_->size(), 1);
72  net_.forward_pass(hist_);
73  bias_.array() = net_.get_activation().col(0).array();
74  }
75 
76  void ANN::PreSimulation(Snapshot* snapshot, const CVManager& cvmanager)
77  {
78  auto ndim = hgrid_->GetDimension();
79  kbt_ = snapshot->GetKb()*temp_;
80 
81  // Zero out forces and histogram.
82  VectorXd vec = VectorXd::Zero(ndim);
83  std::fill(hgrid_->begin(), hgrid_->end(), 0);
84 
85  if(preloaded_)
86  {
87  net_.forward_pass(hist_);
88  bias_.array() = net_.get_activation().col(0).array();
89  TrainNetwork();
90  }
91  else
92  std::fill(ugrid_->begin(), ugrid_->end(), 1.0);
93 
94  // Fill in the reweighting matrix for restraints.
95  auto cvs = cvmanager.GetCVs(cvmask_);
96  auto ncvs = cvs.size();
97 
98  for(int i = 0; i < hist_.rows(); ++i)
99  {
100  auto cval = hist_.row(i);
101  for(size_t j = 0; j < ncvs; ++j)
102  {
103  if(cval[j] < lowerb_[j])
104  rbias_(i,j) = lowerk_[j]*(cval[j] - lowerb_[j])*(cval[j] - lowerb_[j]);
105  else if(cval[j] > upperb_[j])
106  rbias_(i,j) = upperk_[j]*(cval[j] - upperb_[j])*(cval[j] - upperb_[j]);
107  }
108  }
109 
110  std::fill(fgrid_->begin(), fgrid_->end(), vec);
111  }
112 
113  void ANN::PostIntegration(Snapshot* snapshot, const CVManager& cvmanager)
114  {
115  if(snapshot->GetIteration() && snapshot->GetIteration() % nsweep_ == 0)
116  {
117  // Switch to full blast.
118  if(citers_ && snapshot->GetIteration() > citers_)
119  pweight_ = 1.0;
120 
121  TrainNetwork();
122  if(IsMasterRank(world_))
123  WriteBias();
124  }
125 
126  // Get CV vals.
127  auto cvs = cvmanager.GetCVs(cvmask_);
128  auto n = cvs.size();
129 
130  // Determine if we are in bounds.
131  RowVectorXd vec(n);
132  std::vector<double> val(n);
133  bool inbounds = true;
134  for(size_t i = 0; i < n; ++i)
135  {
136  val[i] = cvs[i]->GetValue();
137  vec[i] = cvs[i]->GetValue();
138  if(val[i] < hgrid_->GetLower(i) || val[i] > hgrid_->GetUpper(i))
139  inbounds = false;
140  }
141 
142  // If in bounds, bias.
143  VectorXd derivatives = VectorXd::Zero(n);
144  if(inbounds)
145  {
146  // Record histogram hit and get gradient.
147  // Only record hits on master processes since we will
148  // reduce later.
149  if(IsMasterRank(comm_))
150  hgrid_->at(val) += 1;
151  //derivatives = (*fgrid_)[val];
152  net_.forward_pass(vec);
153  derivatives = net_.get_gradient(0);
154  }
155  else
156  {
157  if(IsMasterRank(comm_))
158  {
159  std::cerr << "ANN (" << snapshot->GetIteration() << "): out of bounds ( ";
160  for(auto& v : val)
161  std::cerr << v << " ";
162  std::cerr << ")" << std::endl;
163  }
164  }
165 
166  // Restraints.
167  for(size_t i = 0; i < n; ++i)
168  {
169  auto cval = cvs[i]->GetValue();
170  if(cval < lowerb_[i])
171  derivatives[i] += lowerk_[i]*cvs[i]->GetDifference(lowerb_[i]);
172  else if(cval > upperb_[i])
173  derivatives[i] += upperk_[i]*cvs[i]->GetDifference(upperb_[i]);
174  }
175 
176  // Apply bias to atoms.
177  auto& forces = snapshot->GetForces();
178  auto& virial = snapshot->GetVirial();
179 
180  for(size_t i = 0; i < cvs.size(); ++i)
181  {
182  auto& grad = cvs[i]->GetGradient();
183  auto& boxgrad = cvs[i]->GetBoxGradient();
184 
185  // Update the forces in snapshot by adding in the force bias from each
186  // CV to each atom based on the gradient of the CV.
187  for (size_t j = 0; j < forces.size(); ++j)
188  forces[j] -= derivatives[i]*grad[j];
189 
190  virial += derivatives[i]*boxgrad;
191  }
192  }
193 
195  {
196  }
197 
199  {
200  // Increment cycle counter.
201  ++sweep_;
202 
203  // Reduce histogram across procs.
204  mxx::allreduce(hgrid_->data(), hgrid_->size(), std::plus<unsigned int>(), world_);
205 
206  // Synchronize grid in case it's periodic.
207  hgrid_->syncGrid();
208 
209  // Update FES estimator. Synchronize unbiased histogram.
211  Map<Array<double, Dynamic, 1>> rbias(rbias_.data(), rbias_.size());
213  uhist.array() = pweight_*uhist.array() + hist.cast<double>()*(1./kbt_*bias_).array().exp()*weight_;
214  ugrid_->syncGrid();
215  hist.setZero();
216 
217  bias_.array() = kbt_*uhist.array().log();
218  bias_.array() -= bias_.minCoeff();
219 
220  // Train network.
221  net_.autoscale(hist_, bias_);
222  if(IsMasterRank(world_))
223  {
224  net_.train(hist_, bias_, true);
225  }
226 
227  // Send optimal nnet params to all procs.
228  vector_t wb = net_.get_wb();
229  mxx::bcast(wb.data(), wb.size(), 0, world_);
230  net_.set_wb(wb);
231 
232  // Evaluate and subtract off min value for applied bias.
233  net_.forward_pass(hist_);
234  bias_.array() = net_.get_activation().col(0).array();
235  bias_.array() -= bias_.minCoeff();
236 
237  // Calc new bias force.
238  for(size_t i = 0; i < fgrid_->size(); ++i)
239  {
240  MatrixXd forces = net_.get_gradient(i);
241  fgrid_->data()[i] = forces.row(i).transpose();
242  }
243  }
244 
246  {
247  net_.write("netstate.dat");
248 
249  std::string filename = overwrite_ ? outfile_ : outfile_ + std::to_string(sweep_);
250  std::ofstream file(filename);
251  file.precision(16);
252  net_.forward_pass(hist_);
253  matrix_t y = net_.get_activation();
254  for(int i = 0; i < y.rows(); ++i)
255  {
256  for(int j = 0; j < hist_.cols(); ++j)
257  file << std::fixed << hist_(i,j) << " ";
258  file << std::fixed << ugrid_->data()[i] << " " << std::fixed << y(i) << "\n";
259  }
260 
261  file.close();
262  }
263 
264  void ANN::ReadBias(const std::string& state_file, const std::string& bias_file)
265  {
266  std::ifstream file(bias_file, std::ios::in);
267  if(!file)
268  {
269  throw BuildException({"ANN::ReadBias() could not read file "+bias_file});
270  }
271 
272  net_ = nnet::neural_net(state_file.c_str());
273 
274  for(int i = 0; i < hist_.rows(); ++i)
275  {
276  double burn = 0.0;
277  for(int j = 0; j < hist_.cols(); ++j)
278  file >> burn;
279 
280  file >> ugrid_->data()[i];
281  file >> burn;
282  }
283 
284  preloaded_ = true;
285  }
286 
288  const Json::Value& json,
289  const MPI_Comm& world,
290  const MPI_Comm& comm,
291  const std::string& path)
292  {
293  ObjectRequirement validator;
294  Value schema;
295  CharReaderBuilder rbuilder;
296  CharReader* reader = rbuilder.newCharReader();
297 
298  reader->parse(JsonSchema::ANNMethod.c_str(),
299  JsonSchema::ANNMethod.c_str() + JsonSchema::ANNMethod.size(),
300  &schema, nullptr);
301  validator.Parse(schema, path);
302 
303  // Validate inputs.
304  validator.Validate(json, path);
305  if(validator.HasErrors())
306  throw BuildException(validator.GetErrors());
307 
308  // Grid.
309  auto* fgrid = Grid<VectorXd>::BuildGrid(json.get("grid", Json::Value()));
310  auto* hgrid = Grid<unsigned int>::BuildGrid(json.get("grid", Json::Value()));
311  auto* ugrid = Grid<double>::BuildGrid(json.get("grid", Json::Value()));
312 
313  // Topology.
314  auto nlayers = json["topology"].size() + 2;
315  VectorXi topol(nlayers);
316  topol[0] = fgrid->GetDimension();
317  topol[nlayers-1] = 1;
318  for(int i = 0; i < static_cast<int>(json["topology"].size()); ++i)
319  topol[i+1] = json["topology"][i].asInt();
320 
321  auto weight = json.get("weight", 1.).asDouble();
322  auto temp = json["temperature"].asDouble();
323  auto nsweep = json["nsweep"].asUInt();
324 
325  // Assume all vectors are the same size.
326  std::vector<double> lowerb, upperb, lowerk, upperk;
327  for(int i = 0; i < static_cast<int>(json["lower_bound_restraints"].size()); ++i)
328  {
329  lowerk.push_back(json["lower_bound_restraints"][i].asDouble());
330  upperk.push_back(json["upper_bound_restraints"][i].asDouble());
331  lowerb.push_back(json["lower_bounds"][i].asDouble());
332  upperb.push_back(json["upper_bounds"][i].asDouble());
333  }
334 
335  auto* m = new ANN(world, comm, topol, fgrid, hgrid, ugrid, lowerb, upperb, lowerk, upperk, temp, weight, nsweep);
336 
337  // Set optional params.
338  m->SetPrevWeight(json.get("prev_weight", 1).asDouble());
339  m->SetOutput(json.get("output_file", "ann.out").asString());
340  m->SetOutputOverwrite( json.get("overwrite_output", true).asBool());
341  m->SetConvergeIters(json.get("converge_iters", 0).asUInt());
342  m->SetMaxIters(json.get("max_iters", 1000).asUInt());
343  m->SetMinLoss(json.get("min_loss", 0).asDouble());
344 
345  if(json.isMember("net_state") && json.isMember("load_bias"))
346  m->ReadBias(json["net_state"].asString(), json["load_bias"].asString());
347 
348  return m;
349  }
350 }
Requirements on an object.
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Artificial Neural Network Method.
Definition: ANN.h:36
std::vector< double > lowerb_
Definition: ANN.h:78
void PostSimulation(Snapshot *snapshot, const class CVManager &cvmanager) override
Method call post simulation.
Definition: ANN.cpp:194
Grid< unsigned int > * hgrid_
Histogram grid.
Definition: ANN.h:66
double temp_
Definition: ANN.h:59
std::vector< double > lowerk_
Definition: ANN.h:83
static ANN * Build(const Json::Value &json, const MPI_Comm &world, const MPI_Comm &comm, const std::string &path)
Build a derived method from JSON node.
Definition: ANN.cpp:287
Eigen::MatrixXd hist_
Definition: ANN.h:73
bool preloaded_
Is the network preloaded?
Definition: ANN.h:90
void TrainNetwork()
Trains the neural network.
Definition: ANN.cpp:198
void PreSimulation(Snapshot *snapshot, const class CVManager &cvmanager) override
Method call prior to simulation initiation.
Definition: ANN.cpp:76
unsigned int sweep_
Definition: ANN.h:43
Grid< Eigen::VectorXd > * fgrid_
Force grid.
Definition: ANN.h:63
nnet::neural_net net_
Neural network.
Definition: ANN.h:50
ANN(const MPI_Comm &world, const MPI_Comm &comm, const Eigen::VectorXi &topol, Grid< Eigen::VectorXd > *fgrid, Grid< unsigned int > *hgrid, Grid< double > *ugrid, const std::vector< double > &lowerb, const std::vector< double > &upperb, const std::vector< double > &lowerk, const std::vector< double > &upperk, double temperature, double weight, unsigned int nsweep)
Constructor.
Definition: ANN.cpp:34
double pweight_
Definition: ANN.h:54
void ReadBias(const std::string &, const std::string &)
Load network state and bias from file.
Definition: ANN.cpp:264
unsigned int citers_
Number of iterations after which we turn on full weight.
Definition: ANN.h:47
void WriteBias()
Writes out the bias to file.
Definition: ANN.cpp:245
void PostIntegration(Snapshot *snapshot, const class CVManager &cvmanager) override
Method call post integration.
Definition: ANN.cpp:113
std::string outfile_
Output filename.
Definition: ANN.h:87
Grid< double > * ugrid_
Unbiased histogram grid.
Definition: ANN.h:69
bool overwrite_
Overwrite outputs?
Definition: ANN.h:93
Exception to be thrown when building the Driver fails.
Collective variable manager.
Definition: CVManager.h:43
CVList GetCVs(const std::vector< unsigned int > &mask=std::vector< unsigned int >()) const
Get CV iterator.
Definition: CVManager.h:81
static bool IsMasterRank(const MPI_Comm &comm)
Check if current processor is master.
const std::vector< double > GetLower() const
Return the lower edges of the Grid.
Definition: GridBase.h:231
size_t size() const
Get the size of the internal storage vector.
Definition: GridBase.h:326
T * data()
Get pointer to the internal data storage vector.
Definition: GridBase.h:338
size_t GetDimension() const
Get the dimension.
Definition: GridBase.h:195
const T & at(const std::vector< int > &indices) const
Access Grid element read-only.
Definition: GridBase.h:546
const std::vector< double > GetUpper() const
Return the upper edges of the Grid.
Definition: GridBase.h:262
void syncGrid()
Sync the grid.
Definition: GridBase.h:146
Basic Grid.
Definition: Grid.h:59
static Grid< T > * BuildGrid(const Json::Value &json)
Set up the grid.
Definition: Grid.h:127
iterator begin()
Return iterator at first grid point.
Definition: Grid.h:527
iterator end()
Return iterator after last valid grid point.
Definition: Grid.h:540
Interface for Method implementations.
Definition: Method.h:44
mxx::comm comm_
Local MPI communicator.
Definition: Method.h:47
std::vector< unsigned int > cvmask_
Mask which identifies which CVs to act on.
Definition: Method.h:50
mxx::comm world_
Global MPI communicator.
Definition: Method.h:46
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:48
double GetKb() const
Get system Kb.
Definition: Snapshot.h:165
const Matrix3 & GetVirial() const
Get box virial.
Definition: Snapshot.h:135
const std::vector< Vector3 > & GetForces() const
Access the per-particle forces.
Definition: Snapshot.h:351
size_t GetIteration() const
Get the current iteration.
Definition: Snapshot.h:105
Map for histogram and coefficients.
Definition: Basis.h:40