SSAGES 0.9.3
Software Suite for Advanced General Ensemble Simulations
Umbrella.cpp
1 
#include "Umbrella.h"
#include "Snapshot.h"
#include "CVs/CVManager.h"
#include "Validator/ObjectRequirement.h"
#include "Drivers/DriverException.h"
#include "schema.h"
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
28 
29 using namespace Json;
30 
31 namespace SSAGES
32 {
33  void Umbrella::PreSimulation(Snapshot* /* snapshot */, const CVManager& cvmanager)
34  {
35  if(IsMasterRank(comm_))
36  {
37  if(append_)
38  umbrella_.open(filename_.c_str(), std::ofstream::out | std::ofstream::app);
39  else
40  {
41  // Write out header.
42  umbrella_.open(filename_.c_str(), std::ofstream::out);
43  umbrella_ << "#";
44  umbrella_ << "Iteration ";
45 
46  auto cvs = cvmanager.GetCVs(cvmask_);
47  for(size_t i = 0; i < cvs.size(); ++i)
48  umbrella_ << "cv_" + std::to_string(i) << " ";
49 
50  for(size_t i = 0; i < cvs.size() - 1; ++i)
51  umbrella_ << "center_" + std::to_string(i) << " ";
52  umbrella_ << "center_" + std::to_string(cvs.size() - 1) << std::endl;
53  }
54  }
55  }
56 
57  void Umbrella::PostIntegration(Snapshot* snapshot, const CVManager& cvmanager)
58  {
59  // Get necessary info.
60  auto cvs = cvmanager.GetCVs(cvmask_);
61  auto& forces = snapshot->GetForces();
62  auto& virial = snapshot->GetVirial();
63 
64  for(size_t i = 0; i < cvs.size(); ++i)
65  {
66  // Get current CV and gradient.
67  auto& cv = cvs[i];
68  auto& grad = cv->GetGradient();
69  auto& boxgrad = cv->GetBoxGradient();
70  // Compute dV/dCV.
71  auto center = GetCurrentCenter(snapshot->GetIteration(), i);
72  auto D = kspring_[i]*cv->GetDifference(center);
73 
74  // Update forces.
75  for(size_t j = 0; j < forces.size(); ++j)
76  forces[j] -= D*grad[j];
77 
78  // Update virial.
79  virial += D*boxgrad;
80  }
81 
82  if(snapshot->GetIteration() % outfreq_ == 0)
83  PrintUmbrella(cvs, snapshot->GetIteration());
84  }
85 
86  void Umbrella::PostSimulation(Snapshot*, const CVManager&)
87  {
88  if(IsMasterRank(comm_))
89  umbrella_.close();
90  }
91 
92  void Umbrella::PrintUmbrella(const CVList& cvs, size_t iteration)
93  {
94  if(IsMasterRank(comm_))
95  {
96  umbrella_.precision(8);
97  umbrella_ << iteration << " ";
98 
99  // Print out CV values first.
100  for(auto& cv : cvs)
101  umbrella_ << cv->GetValue() << " ";
102 
103  // Print out target (center) of each CV.
104  for(size_t i = 0; i < cvs.size() - 1; ++i)
105  umbrella_ << GetCurrentCenter(iteration, i) << " ";
106  umbrella_ << GetCurrentCenter(iteration, cvs.size() - 1);
107 
108  umbrella_ << std::endl;
109  }
110  }
111 
112  Umbrella* Umbrella::Build(const Json::Value& json,
113  const MPI_Comm& world,
114  const MPI_Comm& comm,
115  const std::string& path)
116  {
117  ObjectRequirement validator;
118  Value schema;
119  CharReaderBuilder rbuilder;
120  CharReader* reader = rbuilder.newCharReader();
121 
122  reader->parse(JsonSchema::UmbrellaMethod.c_str(),
123  JsonSchema::UmbrellaMethod.c_str() + JsonSchema::UmbrellaMethod.size(),
124  &schema, nullptr);
125  validator.Parse(schema, path);
126 
127  // Validate inputs.
128  validator.Validate(json, path);
129  if(validator.HasErrors())
130  throw BuildException(validator.GetErrors());
131 
132  unsigned int wid = GetWalkerID(world, comm);
133  unsigned int wcount = GetNumWalkers(world, comm);
134  bool ismulti = wcount > 1;
135 
136  std::vector<std::vector<double>> ksprings;
137  for(auto& s : json["ksprings"])
138  {
139  std::vector<double> kspring;
140  if(s.isArray())
141  for(auto& k : s)
142  kspring.push_back(k.asDouble());
143  else
144  kspring.push_back(s.asDouble());
145 
146  ksprings.push_back(kspring);
147  }
148 
149  std::vector<std::vector<double>> centers0, centers1;
150  if(json.isMember("centers"))
151  {
152  for(auto& s : json["centers"])
153  {
154  std::vector<double> center;
155  if(s.isArray())
156  for(auto& k : s)
157  center.push_back(k.asDouble());
158  else
159  center.push_back(s.asDouble());
160 
161  centers0.push_back(center);
162  }
163  }
164  else if(json.isMember("centers0") && json.isMember("centers1") && json.isMember("timesteps"))
165  {
166  for(auto& s : json["centers0"])
167  {
168  std::vector<double> center;
169  if(s.isArray())
170  for(auto& k : s)
171  center.push_back(k.asDouble());
172  else
173  center.push_back(s.asDouble());
174 
175  centers0.push_back(center);
176  }
177 
178  for(auto& s : json["centers1"])
179  {
180  std::vector<double> center;
181  if(s.isArray())
182  for(auto& k : s)
183  center.push_back(k.asDouble());
184  else
185  center.push_back(s.asDouble());
186 
187  centers1.push_back(center);
188  }
189  }
190  else
191  throw BuildException({"Either \"centers\" or \"timesteps\", \"centers0\" and \"centers1\" must be defined for umbrella."});
192 
193  if(ksprings[0].size() != centers0[0].size())
194  throw BuildException({"Need to define a spring for every center or a center for every spring!"});
195 
196  // If only one set of center/ksprings are specified. Fill it up for multi.
197  if(ismulti)
198  {
199  if(ksprings.size() == 1)
200  for(size_t i = 1; i < wcount; ++i)
201  ksprings.push_back(ksprings[0]);
202  else if(ksprings.size() != wcount)
203  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"ksprings\" match the number of walkers.");
204  if(centers0.size() == 1)
205  for(size_t i = 1; i < wcount; ++i)
206  centers0.push_back(centers0[0]);
207  else if(centers0.size() != wcount)
208  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"centers\"/\"centers0\" match the number of walkers.");
209  if(centers1.size() == 1)
210  for(size_t i = 1; i < wcount; ++i)
211  centers1.push_back(centers1[0]);
212  else if(centers1.size()) // centers1 is optional.
213  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"centers1\" match the number of walkers.");
214  }
215 
216  auto freq = json.get("frequency", 1).asInt();
217 
218  size_t timesteps = 0;
219  if(json.isMember("timesteps"))
220  {
221  if(json["timesteps"].isArray())
222  timesteps = json["timesteps"][wid].asUInt();
223  else
224  timesteps = json["timesteps"].asUInt();
225  }
226 
227  std::string name = "umbrella.dat";
228  if(json["output_file"].isArray())
229  name = json["output_file"][wid].asString();
230  else if(ismulti)
231  throw std::invalid_argument(path + ": Multi-walker simulations require a separate output file for each.");
232  else
233  name = json["output_file"].asString();
234 
235  Umbrella* m = nullptr;
236  if(timesteps == 0)
237  m = new Umbrella(world, comm, ksprings[wid], centers0[wid], name, freq);
238  else
239  m = new Umbrella(world, comm, ksprings[wid], centers0[wid], centers1[wid], timesteps, name, freq);
240 
241  m->SetOutputFrequency(json.get("output_frequency",0).asInt());
242  m->SetAppend(json.get("append", false).asBool());
243 
244  return m;
245  }
246 }
Requirements on an object.
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Exception to be thrown when building the Driver fails.
Collective variable manager.
Definition: CVManager.h:43
CVList GetCVs(const std::vector< unsigned int > &mask=std::vector< unsigned int >()) const
Get the list of CVs, optionally restricted by a mask.
Definition: CVManager.h:81
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:48
const Matrix3 & GetVirial() const
Get box virial.
Definition: Snapshot.h:135
const std::vector< Vector3 > & GetForces() const
Access the per-particle forces.
Definition: Snapshot.h:351
size_t GetIteration() const
Get the current iteration.
Definition: Snapshot.h:105
Umbrella sampling method.
Definition: Umbrella.h:36
void SetOutputFrequency(int outfreq)
Set output frequency.
Definition: Umbrella.h:150
void SetAppend(bool append)
Set append mode.
Definition: Umbrella.h:159
std::vector< CollectiveVariable * > CVList
List of Collective Variables.
Definition: types.h:51