Overview
ProtPTDE (Protein Pre-Training Model-Assisted Protein Directed Evolution) is a computational strategy designed to assist protein directed evolution by integrating multiple deep learning models. A key design highlight of this framework lies in its streamlined parameter management: we have centralized the majority of parameters and hyperparameters involved in the entire architectural workflow and fitness prediction framework into a single configuration file, config/config.json.
This centralized structure enables unified parameter governance: when users need to adjust a parameter, they only need to modify the corresponding entry in config.json, and every associated script picks up the new value automatically. This eliminates the cumbersome, error-prone process of searching through multiple files to modify each parameter individually, keeping control of the entire workflow in one place and simplifying parameter tuning.
Furthermore, we have prioritized framework extensibility to accommodate diverse research needs:
Adding custom protein language models: Users can integrate a new protein language model by writing their own function.py file, following the template provided in generate_features/{model}_embedding/function.py. Once created, the new model is automatically incorporated into the framework's model search scope, with no extensive modifications to the core codebase required.
Supporting multi-model embedding concatenation: The framework can concatenate embeddings from multiple models instead of being limited to a single model. This design not only gives users greater flexibility in model selection but also makes richer, multi-source embedding information available, enhancing the potential for predicting the structure and fitness of complex proteins.
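As an illustration of multi-model embedding concatenation, the sketch below projects each model's embedding to a shared size (the role played by single_model_embedding_output_dim) and concatenates the results in a fixed model order. It is a minimal, standard-library sketch under stated assumptions: the model names, dimensions, and randomly initialized weights are hypothetical, and in the actual framework the projection weights would presumably be learned during training rather than fixed like this.

```python
import random


def linear(x, weight, bias):
    """Compute y = W x + b, where weight is a list of out_dim rows of length in_dim."""
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weight, bias)]


def init_projection(in_dim, out_dim, rng):
    """Randomly initialise a projection; in a trained model these weights are learned."""
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return weight, [0.0] * out_dim


def concat_embeddings(embeddings, projections):
    """Project every model's embedding to the shared size, then concatenate in sorted model order."""
    fused = []
    for name in sorted(embeddings):
        weight, bias = projections[name]
        fused.extend(linear(embeddings[name], weight, bias))
    return fused


if __name__ == "__main__":
    rng = random.Random(0)
    out_dim = 4  # stands in for single_model_embedding_output_dim
    # Hypothetical models whose native embedding sizes differ (8 and 6).
    embeddings = {
        "model_a": [rng.random() for _ in range(8)],
        "model_b": [rng.random() for _ in range(6)],
    }
    projections = {name: init_projection(len(v), out_dim, rng) for name, v in embeddings.items()}
    fused = concat_embeddings(embeddings, projections)
    print(len(fused))  # 2 models x 4 dims = 8
```

Because every model is projected to the same size, the fused vector length is the number of selected models times single_model_embedding_output_dim, regardless of each model's native embedding dimension.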
Prerequisites
install
Anaconda/Miniconda software
install Python packages
conda create -n Prot_PTDE python=3.13
conda activate Prot_PTDE
pip install torch==2.7.1
pip install tqdm
pip install click
pip install biopython
pip install "pandas[excel]"
conda install numba
pip install scikit-learn
pip install more-itertools
pip install iterative-stratification
pip install optuna
pip install transformers
pip install einops
pip install seaborn
pip install plotly
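After installing, a quick way to confirm that the environment is complete is to check that Python can locate each package. The sketch below maps each pip package name above to its import name. The import names (for example Bio for biopython, sklearn for scikit-learn, iterstrat for iterative-stratification) and the assumption that pandas[excel] provides openpyxl for reading .xlsx files are stated here for illustration, so adjust the list if your setup differs.

```python
import importlib.util

# pip package name -> module name used in `import ...`
REQUIRED_MODULES = {
    "torch": "torch",
    "tqdm": "tqdm",
    "click": "click",
    "biopython": "Bio",
    "pandas": "pandas",
    "openpyxl": "openpyxl",  # assumption: the .xlsx reader pulled in by pandas[excel]
    "numba": "numba",
    "scikit-learn": "sklearn",
    "more-itertools": "more_itertools",
    "iterative-stratification": "iterstrat",
    "optuna": "optuna",
    "transformers": "transformers",
    "einops": "einops",
    "seaborn": "seaborn",
    "plotly": "plotly",
}


def missing_packages(modules=REQUIRED_MODULES):
    """Return the pip names whose import name cannot be found in the active environment."""
    return [pip_name for pip_name, module in modules.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("All packages found." if not missing else "Missing: " + ", ".join(missing))
```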
File Preparation
The model requires three essential input files:
An xlsx file (referred to as mutation_data_file) stored in the data/ directory, containing the mutations and their corresponding fitness values.
A fasta file of the wild-type sequence, named result.fasta, located in the features/wt/ directory.
A config.json file in the config/ directory, which manages all parameters. You can use this project's config.json file directly and adjust it to your needs.
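Before running anything, it can help to confirm that these three files are in place. The following is a minimal sketch assuming the mutation_data_file is named {basic_data_name}.xlsx under data/ (consistent with the basic_data_name key described below); missing_inputs and the example name my_mutations are hypothetical and not part of the project.

```python
from pathlib import Path


def missing_inputs(project_root, basic_data_name):
    """Return the required input files (relative to project_root) that do not exist yet."""
    root = Path(project_root)
    required = [
        Path("data") / f"{basic_data_name}.xlsx",  # mutation_data_file
        Path("features") / "wt" / "result.fasta",  # wild-type sequence
        Path("config") / "config.json",            # central configuration
    ]
    return [rel for rel in required if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_inputs(".", "my_mutations")  # hypothetical data file name
    if missing:
        print("Missing input files:", ", ".join(str(p) for p in missing))
    else:
        print("All input files found.")
```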
Basic Usage
Processing data and configuring parameters
The mutation data must be an xlsx file with two columns: the first column holds the mutation information in the form F282L or R11H (wild-type residue, position, mutant residue), and the second column holds the numeric fitness label. Set the value of the key basic_data_name in config/config.json to the name of the mutation_data_file without its extension.
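If you prefer to edit config.json from a script rather than by hand, a small helper like the one below can set basic_data_name. This is a sketch that assumes basic_data_name is a top-level key (as in the parameter explanation below); set_config_value is a hypothetical helper, and rewriting the file with json.dumps re-indents it.

```python
import json
from pathlib import Path


def set_config_value(config_path, key, value):
    """Set a top-level key in a JSON config file, keeping all other entries unchanged."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    config[key] = value
    path.write_text(json.dumps(config, indent=4, ensure_ascii=False) + "\n", encoding="utf-8")
    return config


if __name__ == "__main__":
    config_file = Path("config/config.json")
    data_file = Path("data/my_mutations.xlsx")  # hypothetical mutation_data_file
    if config_file.is_file():
        # basic_data_name is the file name without its extension: "my_mutations".
        set_config_value(config_file, "basic_data_name", data_file.stem)
    else:
        print(f"{config_file} not found; run this from the project root.")
```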
Then generate the fasta file of the mutant sequences from the mutation information in the xlsx file and the wild-type fasta file, and convert the xlsx file into a csv file:
cd data
python convert_xlsx_to_csv_and_generate_fasta_file.py
cd ../
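Conceptually, generating a mutant sequence means substituting residues in the wild-type sequence at the positions named in each mutation. The sketch below is not the project's convert_xlsx_to_csv_and_generate_fasta_file.py. It assumes 1-based positions that match result.fasta and takes a multi-site variant as a list of substitutions (how the project separates several sites within one cell is not specified here). It also checks that the stated wild-type residue matches the sequence, which catches numbering offsets early; apply_mutations and to_fasta are hypothetical names.

```python
import re

# One substitution such as F282L: wild-type residue, 1-based position, mutant residue.
MUTATION_RE = re.compile(r"^([ACDEFGHIKLMNPQRSTVWY])(\d+)([ACDEFGHIKLMNPQRSTVWY])$")


def apply_mutations(wt_seq, mutations):
    """Return the mutant sequence after applying substitutions like ["F282L", "R11H"]."""
    seq = list(wt_seq)
    for mut in mutations:
        match = MUTATION_RE.match(mut.strip())
        if match is None:
            raise ValueError(f"Unrecognised mutation: {mut!r}")
        wt_aa, pos, mut_aa = match.group(1), int(match.group(2)), match.group(3)
        if not 1 <= pos <= len(seq):
            raise ValueError(f"Position {pos} is outside the sequence (length {len(seq)})")
        if seq[pos - 1] != wt_aa:
            raise ValueError(f"{mut}: expected {wt_aa} at position {pos}, found {seq[pos - 1]}")
        seq[pos - 1] = mut_aa
    return "".join(seq)


def to_fasta(name, seq, width=60):
    """Format one FASTA record, wrapping the sequence every `width` characters."""
    lines = [f">{name}"] + [seq[i:i + width] for i in range(0, len(seq), width)]
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    wt = "MKFAR"  # toy wild-type sequence
    mutant = apply_mutations(wt, ["K2R", "F3L"])
    print(to_fasta("K2R_F3L", mutant), end="")
```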
The following is an explanation of the parameters in config.json:
basic_data_name: the name of the mutation_data_file (without the extension)
single_model_embedding_output_dim: the output dimension of the linear transformation applied to the embeddings of each single model, so that they can be concatenated with embeddings from other models
all_model: after the embedding generation step, the models whose embeddings were generated, together with their embedding dimensions, are recorded here automatically.
cross_validation: contains three groups of parameters: hyperparameter_search, model_number, and training_parameter. hyperparameter_search defines the search space explored with the optuna library, including the number of model layers and the maximum learning rate. model_number is the number of models drawn each time from the available models whose embeddings are concatenated for joint inference and prediction. training_parameter holds the settings fixed before training: device, minimum learning rate, initial learning rate, total epochs, warmup epoch ratio, batch size, test size, number of folds for k-fold cross-validation, and whether to shuffle the training and validation data.
best_hyperparameters: The optimal hyperparameters selected from the hyperparameter search range based on training performance, including the selected model combination, number of model layers, maximum learning rate, and random seed.
ensemble_size: Due to the randomness in model initialization, multiple training and inference runs are required to evaluate prediction performance. The number of training runs can be adjusted here.
final_model: contains train_parameter and finetune_parameter, the settings fixed before the training phase and the fine-tuning phase, respectively. They cover the same kinds of parameters as training_parameter in cross_validation, but with different values.
inference: contains only the max_mutations parameter, the maximum number of mutant proteins generated and inferred for each group of variants with the same number of mutation sites. Variants beyond this number are not generated.
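To make the learning-rate entries in training_parameter and hyperparameter_search concrete, here is one common way an initial, maximum, and minimum learning rate combine with total epochs and a warmup ratio: linear warmup from the initial rate to the maximum rate, followed by cosine decay to the minimum. This is an assumed schedule for illustration only; the scheduler actually used by the training scripts may differ, and learning_rate is a hypothetical function name.

```python
import math


def learning_rate(epoch, total_epochs, warmup_ratio, initial_lr, max_lr, min_lr):
    """Learning rate for a 0-based epoch: linear warmup to max_lr, then cosine decay to min_lr."""
    warmup_epochs = max(1, int(total_epochs * warmup_ratio))
    if epoch < warmup_epochs:
        # Linear ramp from initial_lr (epoch 0) towards max_lr.
        return initial_lr + (max_lr - initial_lr) * epoch / warmup_epochs
    # Cosine decay so that the last epoch (total_epochs - 1) lands on min_lr.
    decay_epochs = max(1, total_epochs - 1 - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    settings = dict(total_epochs=100, warmup_ratio=0.1, initial_lr=1e-5, max_lr=1e-3, min_lr=1e-6)
    for epoch in (0, 5, 10, 50, 99):
        print(epoch, f"{learning_rate(epoch, **settings):.2e}")
```

With total_epochs=100 and warmup_ratio=0.1, the rate climbs during the first 10 epochs, peaks at max_lr at epoch 10, and decays to min_lr at the final epoch (99).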
generating embeddings
If you don’t need to add other models, you can run the following code:
cd generate_features
python generate_all_embeddings.py
cd ../
If you need to add other models, create a {model}_embedding folder inside generate_features/, write the corresponding function.py script for your model by following the scripts provided in the project, and then run the commands above.
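The folder naming convention is what lets new models be picked up automatically. The sketch below shows one way such discovery could work, by collecting every {model}_embedding/ folder under generate_features/ that contains a function.py; discover_models is hypothetical and is not the code in generate_all_embeddings.py.

```python
from pathlib import Path

SUFFIX = "_embedding"


def discover_models(features_dir="generate_features"):
    """Return model names for every {model}_embedding/ folder that contains a function.py."""
    return sorted(
        script.parent.name[: -len(SUFFIX)]
        for script in Path(features_dir).glob(f"*{SUFFIX}/function.py")
    )


if __name__ == "__main__":
    print(discover_models())
```

A folder such as draft_embedding/ without a function.py, or a helper folder not ending in _embedding, is simply ignored.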
cross validation
Before concatenating embeddings from multiple models, choose the output dimension of the linear transformation applied to each single model's embedding and write it to the single_model_embedding_output_dim key in config/config.json. Also adjust the configuration parameters under the cross_validation key in config/config.json as needed.
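For a sense of how model_number and single_model_embedding_output_dim interact, the sketch below enumerates every set of model_number models that could be drawn from all_model and computes the length of the resulting concatenated embedding. It assumes every combination is a candidate; the actual search may sample combinations through optuna instead of enumerating them all, and the model names and dimensions are hypothetical.

```python
from itertools import combinations
from math import comb


def candidate_combinations(all_models, model_number):
    """Every unordered set of `model_number` models whose embeddings could be concatenated."""
    return list(combinations(sorted(all_models), model_number))


def fused_dim(model_number, single_model_embedding_output_dim):
    """Each selected model contributes one projected block of the shared output size."""
    return model_number * single_model_embedding_output_dim


if __name__ == "__main__":
    # Hypothetical all_model entries: model name -> native embedding dimension.
    all_model = {"model_a": 1280, "model_b": 1024, "model_c": 768, "model_d": 480}
    combos = candidate_combinations(all_model, 2)
    print(len(combos), "==", comb(len(all_model), 2))
    print(combos[0], fused_dim(2, 256))
```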
Note: The bash scripts named 2000.sh, 2001.sh, etc. in the 01_cross_validation/ and 02_final_model/ folders must be renamed according to your server configuration. The last digit of each such filename indicates the GPU card used on the server (e.g., 2000.sh corresponds to GPU card 0, 2013.sh to GPU card 3, and so forth). Also update the script names (e.g., 2000.sh, 2001.sh) referenced in 01_train.sh in both 01_cross_validation/ and 02_final_model/ to match.
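The naming rule can be expressed as a tiny helper, which may be handy when checking that the launcher names referenced in 01_train.sh match the GPUs you actually have. gpu_index and launcher_names are hypothetical, the "200" prefix is only an example, and the rule as written can only encode GPU cards 0 to 9.

```python
from pathlib import Path


def gpu_index(script_name):
    """GPU card encoded by the last digit of a launcher such as 2013.sh (-> 3)."""
    stem = Path(script_name).stem
    if not stem.isdigit():
        raise ValueError(f"expected a numeric name like 2000.sh, got {script_name!r}")
    return int(stem[-1])


def launcher_names(gpu_ids, prefix="200"):
    """Build launcher file names for single-digit GPU ids; the prefix is only an example."""
    names = []
    for gpu in gpu_ids:
        if not 0 <= gpu <= 9:
            raise ValueError(f"GPU id {gpu} cannot be encoded in a single digit")
        names.append(f"{prefix}{gpu}.sh")
    return names


if __name__ == "__main__":
    for name in launcher_names([0, 1, 3]):
        print(name, "-> GPU", gpu_index(name))
```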
After completing the above operations, you can run the following code:
cd 01_cross_validation
bash 01_train.sh
python 02_Dis_cross_validation.py
cd ../
Select the best hyperparameters from the distribution plot (Dis_cross_validation.pdf) and write them to the best_hyperparameters key in config/config.json: selected_models, num_layer, and max_lr.
train and finetune
Set the number of model training runs in the ensemble_size key in config/config.json. Every run uses the same training parameters; only the model initialization and the batch order of the training data provided by the DataLoader differ. The results of each run are saved separately and then merged and analyzed to evaluate the stability of the final predictions.
Modify the basic parameters for model training and finetuning under the final_model key in config/config.json.
After completing the above operations, you can run the following code:
cd 02_final_model
bash 01_train.sh
python 02_plot_random_seed_train.py
Select a good random seed based on the scatter plot (Scatter_best_train_test_epoch_ratio.html) and write it to the best_hyperparameters key in config/config.json.
Then you can run the following code:
bash 03_train_ensemble.sh
bash 04_finetune.sh
bash 05_finetune_ensemble.sh
cd ../
inference
Choose the maximum number of mutants to generate for each number of mutation sites and write it to the max_mutations key in the inference section of config/config.json. Then run the following code:
cd 03_inference
bash 01_generate_unpredicted_muts_csv.sh
bash 02_inference.sh
bash 03_inference_ensemble.sh
cd ../
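The effect of max_mutations can be pictured as a cap on each group of variants that share the same number of mutation sites. The following sketch is an illustration only, not the logic of 01_generate_unpredicted_muts_csv.sh: it assumes a given list of candidate single substitutions, allows at most one substitution per site, and simply keeps the first max_mutations combinations in enumeration order. generate_candidates, position, and their arguments are hypothetical.

```python
from itertools import combinations, islice


def position(mutation):
    """Site of a substitution such as F282L -> 282."""
    return int(mutation[1:-1])


def generate_candidates(single_mutations, max_sites, max_mutations):
    """Map each number of sites (1..max_sites) to at most `max_mutations` combinations."""
    by_size = {}
    for n_sites in range(1, max_sites + 1):
        valid = (
            combo
            for combo in combinations(single_mutations, n_sites)
            if len({position(m) for m in combo}) == n_sites  # no site mutated twice
        )
        by_size[n_sites] = list(islice(valid, max_mutations))  # stop lazily at the cap
    return by_size


if __name__ == "__main__":
    singles = ["F282L", "F282Y", "R11H", "A5G"]  # hypothetical candidate substitutions
    for n_sites, combos in generate_candidates(singles, max_sites=2, max_mutations=3).items():
        print(n_sites, combos)
```

Using islice on a generator means that, even when the number of possible combinations is huge, only the first max_mutations valid ones per group are ever materialized.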
Finally, for each combination of mutation sites, only the variant with the highest mean predicted fitness is retained. All retained site combinations are then sorted in ascending order of the standard deviation of their predicted fitness, which indicates how reliable each prediction is. The code is as follows:
cd 03_inference
bash 04_get_cluster_csv.sh
cd ../
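The selection and sorting step can be sketched as follows, assuming you already have, for every variant, one predicted fitness per ensemble model (at least two, so that a standard deviation exists). rank_site_combinations and the input layout (a dict keyed by a tuple of substitutions) are hypothetical and do not reflect the exact CSV format produced by 04_get_cluster_csv.sh.

```python
from statistics import mean, stdev


def rank_site_combinations(predictions):
    """Keep the best variant per site combination, then sort by prediction spread.

    predictions maps a variant, e.g. ("R11H", "F282L"), to its fitness predictions
    from the ensemble models. Returns (variant, mean, std) rows, lowest std first.
    """
    best = {}
    for variant, scores in predictions.items():
        sites = tuple(sorted(int(m[1:-1]) for m in variant))  # ("R11H", "F282L") -> (11, 282)
        row = (variant, mean(scores), stdev(scores))
        if sites not in best or row[1] > best[sites][1]:
            best[sites] = row  # highest mean prediction wins for this site combination
    return sorted(best.values(), key=lambda r: r[2])


if __name__ == "__main__":
    demo = {
        ("F282L",): [1.0, 1.2],
        ("F282Y",): [1.4, 1.6],
        ("R11H", "F282L"): [0.5, 0.9],
        ("R11Y", "F282Y"): [0.7, 1.1],
    }  # toy numbers, not real predictions
    for variant, avg, sd in rank_site_combinations(demo):
        print(variant, round(avg, 3), round(sd, 3))
```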