SSAGES  0.9.3
Software Suite for Advanced General Ensemble Simulations
ANNCV.h
1 
20 #pragma once
21 
#include "CollectiveVariable.h"
#include "Validator/ObjectRequirement.h"
#include "Drivers/DriverException.h"
#include "Snapshot.h"
#include "schema.h"

#include <assert.h>

#include <array>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
31 
32 namespace SSAGES
33 {
35 
42  class ANNCV : public CollectiveVariable
43  {
44  private:
45  Label atomids_; // indices of atoms
46  double scaling_factor_; // scaling factor to make input unitless, note that it is unit dependent (a model trained with A will have different scaling factor with one trained with nm)
47  std::vector<unsigned int> num_nodes_; // numbers of nodes for neural network
48  std::vector<Eigen::MatrixXd> weight_coeff_;
49  std::vector<Vector> bias_;
50  std::vector<std::string> activations_;
51  int out_index_; // index of output component
52 
53  public:
54 
56 
67  Label atomids,
68  double scaling_factor,
69  std::vector<unsigned int> num_nodes,
70  std::string coeff_file, // file storing weights and bias of ANN
71  std::vector<std::string> activations,
72  int out_index
73  ) :
74  atomids_(atomids), scaling_factor_(scaling_factor), num_nodes_(num_nodes),
75  activations_(activations), out_index_(out_index)
76  {
77  // read coefficients from file
78  std::ifstream my_f(coeff_file);
79  std::string temp_vec;
80  int layer_index = 0;
81  if (num_nodes_[0] != atomids_.size() * 3) {
82  throw BuildException({
83  "WARNING: input dim should be " + std::to_string(atomids_.size() * 3) + " found: "
84  + std::to_string(num_nodes_[0])
85  });
86  }
87  while (std::getline(my_f, temp_vec) )
88  {
89  std::istringstream ss(temp_vec);
90  std::string token;
91  std::vector<double> temp_weight, temp_bias;
92  while (std::getline(ss, token, ',')) // coefficients are separated by comma
93  {
94  temp_weight.push_back(stod(token));
95  }
96  if (temp_weight.size() != num_nodes_[layer_index] * num_nodes_[layer_index + 1]) {
97  throw BuildException({
98  "WARNING: layer weight size = " + std::to_string(temp_weight.size()) + " expected: "
99  + std::to_string(num_nodes_[layer_index] * num_nodes_[layer_index + 1])
100  });
101  }
102  Eigen::Map<Eigen::MatrixXd> temp_weight_v(
103  &temp_weight[0], num_nodes_[layer_index], num_nodes_[layer_index + 1]
104  );
105  std::getline(my_f, temp_vec);
106  std::istringstream ss2(temp_vec);
107  while (std::getline(ss2, token, ','))
108  {
109  temp_bias.push_back(stod(token));
110  }
111  if (temp_bias.size() != num_nodes_[layer_index + 1]) {
112  throw BuildException({
113  "WARNING: layer bias size = " + std::to_string(temp_bias.size())
114  + " expected: " + std::to_string(num_nodes_[layer_index + 1])
115  });
116  }
117  Vector temp_bias_v = Vector::Map(temp_bias.data(), temp_bias.size());
118  weight_coeff_.push_back(temp_weight_v);
119  bias_.push_back(temp_bias_v);
120  layer_index ++;
121  }
122  }
123 
125 
128  void Initialize(const Snapshot& snapshot) override
129  {
130  using std::to_string;
131 
132  std::vector<int> found;
133  snapshot.GetLocalIndices(atomids_, &found);
134  size_t nfound = found.size();
135  MPI_Allreduce(MPI_IN_PLACE, &nfound, 1, MPI_INT, MPI_SUM, snapshot.GetCommunicator());
136 
137  if(nfound != atomids_.size())
138  throw BuildException({
139  "ANNCV: Expected to find " +
140  to_string(atomids_.size()) +
141  " atoms, but only found " +
142  to_string(nfound) + "."
143  });
144  }
145 
146  std::vector<Vector> forward_prop(Vector& input_vec) {
147  std::vector<Vector> output_of_layers;
148  Vector temp_out = input_vec;
149  output_of_layers.push_back(temp_out);
150  for (size_t ii = 0; ii < weight_coeff_.size(); ii ++) {
151  temp_out = weight_coeff_[ii].transpose() * temp_out + bias_[ii];
152  if (activations_[ii] == "Tanh") {
153  for (int kk = 0; kk < temp_out.size(); kk ++) {
154  temp_out[kk] = std::tanh(temp_out[kk]);
155  }
156  }
157  else if (activations_[ii] == "ReLU") {
158  for (int kk = 0; kk < temp_out.size(); kk ++) {
159  temp_out[kk] = temp_out[kk] < 0 ? 0 : temp_out[kk];
160  }
161  }
162  output_of_layers.push_back(temp_out);
163  }
164  return output_of_layers;
165  }
166 
167  std::vector<Vector> back_prop(std::vector<Vector>& output_of_layers) {
168  auto deriv_back = output_of_layers;
169  int num = output_of_layers.size();
170  for (int ii = 0; ii < output_of_layers[num - 1].size(); ii ++ ) {
171  if (ii == out_index_) {
172  deriv_back[num - 1][ii] = 1;
173  }
174  else {
175  deriv_back[num - 1][ii] = 0;
176  }
177  }
178  for (int ii = num - 2; ii >= 0; ii --) {
179  if (activations_[ii] == "Tanh") {
180  for (int kk = 0; kk < deriv_back[ii + 1].size(); kk ++) {
181  deriv_back[ii + 1][kk] = deriv_back[ii + 1][kk] * (
182  1 - output_of_layers[ii + 1][kk] * output_of_layers[ii + 1][kk]);
183  }
184  }
185  deriv_back[ii] = weight_coeff_[ii] * deriv_back[ii + 1];
186  }
187  return deriv_back;
188  }
189 
191 
194  void Evaluate(const Snapshot& snapshot) override
195  {
196  // Get data from snapshot.
197  auto n = snapshot.GetNumAtoms();
198  const auto& pos = snapshot.GetPositions();
199  auto& comm = snapshot.GetCommunicator();
200 
201  // Initialize gradient.
202  std::fill(grad_.begin(), grad_.end(), Vector3{0,0,0});
203  grad_.resize(n, Vector3{0,0,0});
204 
205  // Vector3 xi{0, 0, 0}, xj{0, 0, 0}, xk{0, 0, 0};
206  std::vector<Vector3> positions(atomids_.size(), Vector3({0,0,0}));
207  Label local_idx;
208  snapshot.GetLocalIndices(atomids_, &local_idx);
209  for (size_t ii = 0; ii < atomids_.size(); ii ++) {
210  if (local_idx[ii] != -1) {
211  positions[ii] = pos[local_idx[ii]];
212  }
213  }
214  // By performing a reduce, we actually collect all. This can
215  // be converted to a more intelligent allgater on rank then bcast.
216  MPI_Allreduce(MPI_IN_PLACE, positions.data(), positions.size(), MPI_DOUBLE, MPI_SUM, comm);
217  auto com = snapshot.CenterOfMass(atomids_, false); // center of mass coordinates (not weighted with mass)
218  // remove translation degree of freedom
219  for (auto& temp_pos: positions) {
220  temp_pos = temp_pos - com;
221  }
222  Vector input_vec(positions.size() * 3); // flatten input vector
223  for (size_t ii = 0; ii < positions.size(); ii ++) {
224  for (size_t jj = 0; jj < 3; jj ++) {
225  input_vec[ii * 3 + jj] = positions[ii][jj];
226  }
227  }
228  input_vec = input_vec / scaling_factor_;
229  auto output_of_layers = forward_prop(input_vec);
230  val_ = output_of_layers[output_of_layers.size() - 1][out_index_];
231  auto deriv_back = back_prop(output_of_layers);
232  // subtract mean from deriv_back[0]
233  int num_atoms = deriv_back[0].size() / 3;
234  double average[3] = {0};
235  for (int kk = 0; kk < 3; kk ++) {
236  for (int ii = 0; ii < num_atoms; ii ++) {
237  average[kk] += deriv_back[0][ii * 3 + kk];
238  }
239  average[kk] /= num_atoms;
240  }
241  }
242 
244  static ANNCV* Build(const Json::Value& json, const std::string& path)
245  {
246  Json::ObjectRequirement validator;
247  Json::Value schema;
248  Json::CharReaderBuilder rbuilder;
249  Json::CharReader* reader = rbuilder.newCharReader();
250 
251  reader->parse(JsonSchema::ANNCV.c_str(),
252  JsonSchema::ANNCV.c_str() + JsonSchema::ANNCV.size(),
253  &schema, NULL);
254  validator.Parse(schema, path);
255 
256  // Validate inputs.
257  validator.Validate(json, path);
258  if(validator.HasErrors())
259  throw BuildException(validator.GetErrors());
260 
261  std::vector<int> atomids;
262  for(auto& s : json["atom_ids"])
263  atomids.push_back(s.asInt());
264  double scaling_factor = json["scaling_factor"].asDouble();
265  std::vector<unsigned int> num_nodes;
266  for (auto &s : json["num_nodes"])
267  num_nodes.push_back(s.asUInt());
268  std::vector<std::string> activations;
269  std::string coeff_file = json["coeff_file"].asString();
270  for (auto &s : json["activations"]) {
271  auto name = s.asString();
272  if (name == "Tanh" || name == "ReLU" || name == "Linear")
273  activations.push_back(name);
274  else
275  throw std::runtime_error("invalid activation function " + name + " provided");
276  }
277  int out_index = json["index"].asInt();
278  return new ANNCV(atomids, scaling_factor, num_nodes, coeff_file, activations, out_index);
279  }
280  };
281 }
Requirements on an object.
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
ANN (artificial neural network) collective variables.
Definition: ANNCV.h:43
void Initialize(const Snapshot &snapshot) override
Initialize necessary variables.
Definition: ANNCV.h:128
void Evaluate(const Snapshot &snapshot) override
Evaluate the CV.
Definition: ANNCV.h:194
static ANNCV * Build(const Json::Value &json, const std::string &path)
Set up collective variable.
Definition: ANNCV.h:244
ANNCV(Label atomids, double scaling_factor, std::vector< unsigned int > num_nodes, std::string coeff_file, std::vector< std::string > activations, int out_index)
Constructor.
Definition: ANNCV.h:66
Exception to be thrown when building the Driver fails.
Abstract class for a collective variable.
std::vector< Vector3 > grad_
Gradient vector dCv/dxi.
double val_
Current value of CV.
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:48
unsigned GetNumAtoms() const
Get number of atoms in this snapshot.
Definition: Snapshot.h:202
void GetLocalIndices(const Label &ids, Label *indices) const
Definition: Snapshot.h:537
const mxx::comm & GetCommunicator() const
Get communicator for walker.
Definition: Snapshot.h:186
const std::vector< Vector3 > & GetPositions() const
Access the particle positions.
Definition: Snapshot.h:325
Vector3 CenterOfMass(const Label &indices, bool mass_weight=true) const
Compute center of mass of a group of atoms based on index.
Definition: Snapshot.h:442
Eigen::VectorXd Vector
Arbitrary length vector.
Definition: types.h:30
Eigen::Vector3d Vector3
Three-dimensional vector.
Definition: types.h:33
std::vector< int > Label
List of integers.
Definition: types.h:48