Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e4af42d
support host device buffer
gstnt Apr 20, 2026
447037c
move angle parameters to buffer
goodstudyqwq Apr 20, 2026
bcdd463
remove n_coords
goodstudyqwq Apr 20, 2026
b613eb5
move bond parameters to buffer
goodstudyqwq Apr 20, 2026
cee7552
move torsion parameters to buffer
goodstudyqwq Apr 20, 2026
f70880b
move improper2 parameters to buffer
goodstudyqwq Apr 20, 2026
50c4d15
move atypes parameters to buffer
goodstudyqwq Apr 20, 2026
ccf7e62
move charges parameters to buffer
goodstudyqwq Apr 20, 2026
bfa6c2e
move coords velocities dvelocities parameters to buffer
goodstudyqwq Apr 20, 2026
a9d042a
remove q_atoms in gpu side
goodstudyqwq Apr 20, 2026
2e3c864
move LJ parameters to buffer
goodstudyqwq Apr 20, 2026
b71d0eb
move restrpos parameters to buffer
goodstudyqwq Apr 20, 2026
78318b2
move restrangs parameters to buffer
goodstudyqwq Apr 20, 2026
4f4b904
move restrdists parameters to buffer
goodstudyqwq Apr 20, 2026
fff8136
move restrseqs and restrwalls to buffer
goodstudyqwq Apr 20, 2026
b59ab7e
move heavy to buffer
goodstudyqwq Apr 20, 2026
95f9a40
remove d_p_atoms
goodstudyqwq Apr 20, 2026
7e8a5ca
move excluded to buffer
goodstudyqwq Apr 20, 2026
5606506
move winv to buffer
goodstudyqwq Apr 20, 2026
175be86
move shake to buffer
goodstudyqwq Apr 20, 2026
aa72183
move xcoords to buffer
goodstudyqwq Apr 20, 2026
c2acf60
remove d_q_atypes
goodstudyqwq Apr 20, 2026
a4810e5
move q_elscales to buffer
goodstudyqwq Apr 20, 2026
e4e60e6
move shells to buffer
goodstudyqwq Apr 21, 2026
c5620ef
move wshells to buffer
goodstudyqwq Apr 21, 2026
726ba35
move lambda to buffer
goodstudyqwq Apr 21, 2026
3d362f6
remove d_EQ_nonbond_qq
goodstudyqwq Apr 21, 2026
18f6e84
move EQ_restraint to buffer
goodstudyqwq Apr 21, 2026
fe20260
remove CudaContext
goodstudyqwq Apr 21, 2026
9f38674
remove sync to host function
goodstudyqwq Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 138 additions & 77 deletions src/core/common/include/context.h
Original file line number Diff line number Diff line change
@@ -1,101 +1,179 @@
#pragma once

#include <vector_types.h>

#include <array>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <array>
#include <vector_types.h>

#include "common/include/md_types.h"
#include "common/include/nonbonded_14_mode.h"
#include "common/include/vdw_rules.h"
#include "host_device_buffer.h"

class Context {
public:

static Context& instance() {
static Context ctx;
return ctx;
}

/* =============================================
* == GENERAL
* == CONFIG
* =============================================
*/

int n_atoms = 0;
int n_atoms_solute = 0;
std::string base_folder;
bool run_gpu = false;

int n_atoms = 0; // the total number of atoms
int n_atoms_solute = 0; // the total number of solute number, in our system [0, n_atoms_solute) are solute, [n_atoms_solute, n_atoms) are water atoms
int n_patoms = 0;
int n_qatoms = 0;
int n_waters = 0;
int n_molecules = 0;

std::string base_folder;
double dt = 0.0;
double tau_T = 0.0;
md_t md;

bool run_gpu = false;
/*

*/
std::unique_ptr<HostDeviceBuffer<coord_t>> coords;
std::unique_ptr<HostDeviceBuffer<vel_t>> velocities;
std::unique_ptr<HostDeviceBuffer<dvel_t>> dvelocities;

/* =============================================
* == FROM MD FILE
* =============================================
*/

md_t md;
bool separate_scaling = false;

/* =============================================
* == FROM TOPOLOGY FILE
* =============================================
*/
/*
*/
int n_angles = 0;
int n_angles_solute = 0;
int n_cangles = 0;
std::unique_ptr<HostDeviceBuffer<angle_t>> angles;
std::unique_ptr<HostDeviceBuffer<cangle_t>> cangles;

int n_coords = 0;
int n_bonds = 0;
int n_bonds_solute = 0;
int n_cbonds = 0;
int n_angles = 0;
int n_angles_solute = 0;
int n_cangles = 0;
std::unique_ptr<HostDeviceBuffer<bond_t>> bonds;
std::unique_ptr<HostDeviceBuffer<cbond_t>> cbonds;
int n_torsions = 0;
int n_torsions_solute = 0;
int n_ctorsions = 0;
std::unique_ptr<HostDeviceBuffer<torsion_t>> torsions;
std::unique_ptr<HostDeviceBuffer<ctorsion_t>> ctorsions;
int n_impropers = 0;
int n_impropers_solute = 0;
int n_cimpropers = 0;
std::unique_ptr<HostDeviceBuffer<improper_t>> impropers;
std::unique_ptr<HostDeviceBuffer<cimproper_t>> cimpropers;

int n_restrspos = 0;
std::unique_ptr<HostDeviceBuffer<restrpos_t>> restrspos;

int n_restrangs = 0;
std::unique_ptr<HostDeviceBuffer<restrang_t>> restrangs;


int n_restrdists = 0;
std::unique_ptr<HostDeviceBuffer<restrdis_t>> restrdists;

int n_restrseqs = 0;
std::unique_ptr<HostDeviceBuffer<restrseq_t>> restrseqs;
int n_restrwalls = 0;
std::unique_ptr<HostDeviceBuffer<restrwall_t>> restrwalls;

/*
Atom Info
*/
int n_charges = 0;
int n_ccharges = 0;
std::unique_ptr<HostDeviceBuffer<charge_t>> charges;
std::unique_ptr<HostDeviceBuffer<ccharge_t>> ccharges;
int n_atypes = 0;
int n_catypes = 0;
std::unique_ptr<HostDeviceBuffer<atype_t>> atypes;
std::unique_ptr<HostDeviceBuffer<catype_t>> catypes;

std::unique_ptr<HostDeviceBuffer<bool>> heavy;
std::unique_ptr<HostDeviceBuffer<coord_t>> coords_init;

std::unique_ptr<HostDeviceBuffer<bool>> excluded;

std::unique_ptr<HostDeviceBuffer<double>> winv;

std::unique_ptr<HostDeviceBuffer<bool>> shell;
/*
Pair
*/
std::unique_ptr<HostDeviceBuffer<int>> LJ_matrix;

/*
Shake
*/
int n_shake_constraints = 0;
std::unique_ptr<HostDeviceBuffer<int>> mol_n_shakes;
std::unique_ptr<HostDeviceBuffer<shake_bond_t>> shake_bonds;
std::unique_ptr<HostDeviceBuffer<coord_t>> xcoords; // todo: It's just a temporary variables...
/*
Water
*/

std::unique_ptr<HostDeviceBuffer<shell_t>> wshells;



/*
FEP
*/
std::unique_ptr<HostDeviceBuffer<double>> lambdas; // Actually length is only 2..

/*
Energy
*/

std::unique_ptr<HostDeviceBuffer<E_restraint_t>> EQ_restraint;

/*
*/

int n_ngbrs23 = 0;
int n_ngbrs14 = 0;
int n_excluded = 0;
int n_cgrps_solute = 0;
int n_cgrps_solvent = 0;
int iuse_switch_atom = 0;

std::vector<coord_t> coords_top;
std::vector<bond_t> bonds;
std::vector<cbond_t> cbonds;
std::vector<angle_t> angles;
std::vector<cangle_t> cangles;
std::vector<torsion_t> torsions;
std::vector<ctorsion_t> ctorsions;
std::vector<improper_t> impropers;
std::vector<cimproper_t> cimpropers;
std::vector<charge_t> charges;
std::vector<ccharge_t> ccharges;
std::vector<atype_t> atypes;
std::vector<catype_t> catypes;
std::vector<int> atom_to_qi;
std::vector<ccharge_t> unified_ccharges;
std::vector<catype_t> unified_catypes;
std::vector<int> LJ_matrix;
std::unique_ptr<HostDeviceBuffer<ccharge_t>> unified_ccharges;
std::unique_ptr<HostDeviceBuffer<catype_t>> unified_catypes;
std::unique_ptr<HostDeviceBuffer<int3>> ngbrs_14;
std::vector<int3> ngbrs_14_builder;

std::unique_ptr<HostDeviceBuffer<int>> p_atoms_list;
std::unique_ptr<HostDeviceBuffer<int>> w_atoms_list;
std::unique_ptr<HostDeviceBuffer<int>> q_atoms_list;

std::unique_ptr<HostDeviceBuffer<ccharge_t>> charge_table_all;
std::unique_ptr<HostDeviceBuffer<double>> charge_pair_products;
std::unique_ptr<HostDeviceBuffer<int>> p_charge_types;
std::unique_ptr<HostDeviceBuffer<int>> w_charge_types;
std::unique_ptr<HostDeviceBuffer<int>> q_charge_types;

std::unique_ptr<HostDeviceBuffer<catype_t>> catype_table_all;
std::unique_ptr<HostDeviceBuffer<vdw_pair_param_t>> catype_pair_params;
std::unique_ptr<HostDeviceBuffer<int>> p_catype_types;
std::unique_ptr<HostDeviceBuffer<int>> w_catype_types;
std::unique_ptr<HostDeviceBuffer<int>> q_catype_types;

std::map<std::array<double, 4>, int> catype_to_type_host;
int n_charge_types = 0;
int zero_charge_type = -1;
int n_catype_types = 0;
int zero_catype_type = -1;

std::vector<int3> ngbrs_14;

std::unique_ptr<bool[]> excluded;
std::unique_ptr<bool[]> heavy;
std::vector<int> molecules;
std::vector<double> winv;
std::vector<cgrp_t> charge_groups;

topo_t topo = {};
Expand All @@ -106,7 +184,6 @@ class Context {
*/

int n_lambdas = 0;
std::vector<double> lambdas;

int n_qangcouples = 0;
int n_qangles = 0;
Expand Down Expand Up @@ -143,7 +220,7 @@ class Context {
std::vector<atype_t> q_atypes;
std::vector<bond_t> q_bonds;
std::vector<ccharge_t> q_charges;
std::vector<q_elscale_t> q_elscales;
std::unique_ptr<HostDeviceBuffer<q_elscale_t>> q_elscales;
std::vector<q_exclpair_t> q_exclpairs;
std::vector<q_improper_t> q_impropers;
std::vector<q_shake_t> q_shakes;
Expand All @@ -155,20 +232,6 @@ class Context {
* =============================================
*/

int n_restrseqs = 0;
int n_restrspos = 0;
int n_restrdists = 0;
int n_restrangs = 0;
int n_restrwalls = 0;

std::vector<restrseq_t> restrseqs;
std::vector<restrpos_t> restrspos;
std::vector<restrdis_t> restrdists;
std::vector<restrang_t> restrangs;
std::vector<restrwall_t> restrwalls;

std::unique_ptr<bool[]> shell;

/* =============================================
* == SHELLS / SOLVENT
* =============================================
Expand All @@ -185,27 +248,19 @@ class Context {
int n_shells = 0;
std::vector<std::vector<int>> list_sh;
std::vector<std::vector<int>> nsort;
std::vector<shell_t> wshells;

/* =============================================
* == SHAKE
* =============================================
*/

int n_shake_constraints = 0;
std::vector<int> mol_n_shakes;
std::vector<shake_bond_t> shake_bonds;

/* =============================================
* == CALCULATED IN THE INTEGRATION
* =============================================
*/

std::vector<int> p_atoms;
std::vector<coord_t> coords;
std::vector<coord_t> xcoords;
std::vector<vel_t> velocities;
std::vector<dvel_t> dvelocities;

energy_t E_total = {};
std::vector<energy_t> EQ_total;
Expand All @@ -225,7 +280,6 @@ class Context {
std::vector<E_nonbonded_t> EQ_nonbond_qx;

E_restraint_t E_restraint = {};
std::vector<E_restraint_t> EQ_restraint;

double Temp = 0.0;
double Tfree = 0.0;
Expand All @@ -244,6 +298,7 @@ class Context {
* =============================================
*/

bool separate_scaling = false;
double Ndegf = 0.0;
double Ndegfree = 0.0;
double Ndegf_solvent = 0.0;
Expand All @@ -270,7 +325,7 @@ class Context {
}

const ccharge_t& unified_ccharge_by_code(int code) const {
return unified_ccharges[code - 1];
return unified_ccharges->cpu_data_p[code - 1];
}

const ccharge_t& unified_ccharge(int atom_idx, int state) const {
Expand All @@ -282,21 +337,27 @@ class Context {
}

const catype_t& unified_catype_by_code(int code) const {
return unified_catypes[code - 1];
return unified_catypes->cpu_data_p[code - 1];
}

const catype_t& unified_catype(int atom_idx, int state) const {
return unified_catype_by_code(unified_atype_code(atom_idx, state));
}

void cuda_initialize_helpers();
void cuda_free_helpers();
void cuda_reset_energies();
void cuda_sync_all_to_device();

private:
Context() = default;
~Context() {}

void cuda_initialize_atom_lists_host();
void cuda_initialize_ngbrs14_host();
void cuda_initialize_charge_tables_host();
void cuda_initialize_catype_tables_host();

Context(const Context&) = delete;
Context& operator=(const Context&) = delete;




};
17 changes: 17 additions & 0 deletions src/core/common/include/cuda_runtime_utility.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

inline void check_cuda(cudaError_t status) {
if (status != cudaSuccess) {
std::printf(">>> FATAL: CUDA call failed with error code %d: %s\n", status, cudaGetErrorString(status));
std::exit(EXIT_FAILURE);
}
}

inline void check_cudaMalloc(void** dev_ptr, size_t size) {
check_cuda(cudaMalloc(dev_ptr, size));
}
Loading