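"""Protein function prediction with negative sampling.

Trains a pair-scoring model on ProtBERT protein embeddings: each protein is
projected into a hidden space and matched against learned GO-term embeddings.
For every annotated (positive) GO term, one randomly sampled negative term is
used during training.  After training, the model scores all GO terms for the
CAFA3 target proteins and writes the predictions in CAFA submission format.
"""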
import os
import sys
import pickle
import random
import npy2npz as n2n  # local helper: loads an .npz embedding archive into a dict keyed by protein name
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
# Use the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ProtBERT embeddings for the CAFA3 target proteins (test set) and for the
# GOA 2017 training proteins.
test_embeddings = n2n.get_dataset(Path('data/protbert_cafa3_targets.npz'), False)
embeddings = n2n.get_dataset(Path('data/protbert_goa_2017.npz'), False)

# GO-term vocabulary and propagated GO annotations for the training proteins.
with open("data/all_annotated_go_terms.pkl", "rb") as f:
    goterms = pickle.load(f)
with open("data/propagated_go_annotations.pkl", "rb") as f:
    go_annotations = pickle.load(f)
class PFPDataset(Dataset):
    """Yields a protein embedding, a list of GO-term indices (as many random
    negatives as the protein has positive annotations), and the 0/1 labels."""

    def __init__(self, embeddings, annotations):
        embedded_prots = sorted(embeddings.keys())
        self.prot_embeddings = [torch.from_numpy(embeddings[prot]) for prot in embedded_prots]
        # Map every GO term to a fixed integer index (sorted for reproducibility).
        go2idx = dict(zip(sorted(goterms), range(len(goterms))))
        self.prot_annotations = [[go2idx[go] for go in annotations[prot]] for prot in embedded_prots]
        self.all_go = set(range(len(goterms)))

    def __len__(self):
        return len(self.prot_embeddings)

    def __getitem__(self, idx):
        prot_embedding = self.prot_embeddings[idx]
        negative_goterms = self.all_go - set(self.prot_annotations[idx])
        n_terms = len(self.prot_annotations[idx])
        # random.sample needs a sequence, not a set; sort so sampling is reproducible given a seed.
        golabels = random.sample(sorted(negative_goterms), n_terms)
        golabels.extend(self.prot_annotations[idx])
        predlabels = [0] * n_terms + [1] * n_terms  # 0 = sampled negative, 1 = annotated positive
        return prot_embedding, golabels, predlabels
dataset = PFPDataset(embeddings, go_annotations)
def collate_fn(batch):
    """Pads every sample's GO-term list to the longest one in the batch.
    Padding uses index -1, which selects the model's frozen all-zero row, and
    label 0, so padded positions add only a constant to the loss and no gradient."""
    proteins = torch.cat([data[0] for data in batch], dim=0)
    maxlen = max(len(data[1]) for data in batch)
    for data in batch:
        data[1].extend([-1] * (maxlen - len(data[1])))
        data[2].extend([0] * (maxlen - len(data[2])))
    goterms = torch.tensor([data[1] for data in batch])
    labels = torch.tensor([data[2] for data in batch], dtype=torch.float)
    return proteins, goterms, labels
data_iterator = DataLoader(dataset, batch_size=64, collate_fn=collate_fn, shuffle=True)
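# The model scores (protein, GO term) pairs: the ProtBERT embedding is passed
# through a ReLU-activated linear layer, and the resulting hidden vector is
# dotted with a learned embedding (plus bias) for each requested GO term.
# Row -1 of the GO-term table is an all-zero padding row whose gradient is
# zeroed during training so it stays fixed.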
class PFPmodel(nn.Module):
    def __init__(self, num_input, num_hidden, num_output):
        super().__init__()
        self.num_output = num_output
        self.linear = nn.Linear(num_input, num_hidden)
        self.activation = nn.ReLU()
        # One embedding row per GO term plus a zero padding row at index -1.
        init_tensor = (num_hidden ** -0.5) * torch.randn(num_output + 1, num_hidden)
        init_tensor[-1] = 0
        bias_tensor = torch.zeros(num_output + 1)
        # Assigning nn.Parameter attributes registers them with the module.
        self.go_embeddings = nn.Parameter(init_tensor)
        self.go_biases = nn.Parameter(bias_tensor)

    def forward(self, protein, labels=None):
        # With no explicit labels, score every GO term for every protein.
        if labels is None:
            labels = torch.arange(self.num_output, device=protein.device).expand(protein.shape[0], -1)
        hidden = self.activation(self.linear(protein))  # (batch, hidden)
        hidden = torch.unsqueeze(hidden, dim=1)         # (batch, 1, hidden)
        logits = torch.matmul(hidden, torch.transpose(self.go_embeddings[labels], 1, 2))
        logits = torch.squeeze(logits, dim=1) + self.go_biases[labels]
        return logits  # (batch, n_terms)
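# Model setup: 1024-dim ProtBERT input, 256-dim hidden layer, one output per
# GO term, trained with binary cross-entropy (with logits) and Adam.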
model = PFPmodel(1024, 256, len(goterms))
model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
writer = SummaryWriter("logs/")
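# Training loop: binary cross-entropy over the sampled positive/negative
# GO-term pairs.  The gradient of the padding row (index -1) is zeroed on
# every step so padded positions never change the parameters.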
running_loss = 0.
for epoch in range(100):
    for i, (data, target_embedding, target) in enumerate(data_iterator):
        data = data.to(device)
        target_embedding = target_embedding.to(device)
        target = target.to(device)
        logit = model(data, target_embedding)
        loss = criterion(logit, target)
        running_loss += loss.item()
        if i % 100 == 99:
            writer.add_scalar('Training loss', running_loss / 100, epoch * len(data_iterator) + i)
            running_loss = 0.
        optimizer.zero_grad()
        loss.backward()
        # Keep the padding row (index -1) frozen at zero.
        model.go_embeddings.grad[-1] = 0
        model.go_biases.grad[-1] = 0
        optimizer.step()
model.to(torch.device("cpu"))
torch.save(model.state_dict(), "saves/neg_sampling.save")
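# Inference: score every GO term for each CAFA3 target protein (on CPU) and
# keep only predictions with probability above 0.01.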
sorted_goterms = sorted(goterms)
test_predictions = {}
model.eval()
with torch.no_grad():
    for name in test_embeddings:
        test_predictions[name] = []
        scores = torch.sigmoid(model(torch.from_numpy(test_embeddings[name]))).flatten()
        for term, score in zip(sorted_goterms, scores.tolist()):
            if score <= 0.01:
                continue
            test_predictions[name].append([term, round(score, 2)])
# Collect the CAFA3 target protein IDs for each taxon from the target FASTA files.
taxon_proteins = {}
target_dir = Path("/net/kihara/home/nguye330/CAFA_evaluation/data/CAFA3_targets/Target files")
for filepath in target_dir.glob("target.*.fasta"):
    filename = os.path.basename(filepath)
    taxon_id = filename.split(".")[1]
    taxon_proteins[taxon_id] = []
    with open(filepath, "r") as fp:
        for line in fp:
            if line.startswith('>'):
                # FASTA header ">PROTEIN_ID ..." -> keep the ID only.
                taxon_proteins[taxon_id].append(line.split()[0][1:])
# Write one CAFA-format prediction file per taxon.
predict_dir = Path("linear_predictions/")
for taxon_id in taxon_proteins:
    with open(predict_dir.joinpath("LinearRegression_1_{}.txt".format(taxon_id)), "w") as fwrite:
        fwrite.write("AUTHOR:\tMinh\nMODEL\t1\nKEYWORDS\tlinear regression\n")
        for protein_id in taxon_proteins[taxon_id]:
            if protein_id not in test_predictions:
                continue
            for goterm, prob in test_predictions[protein_id]:
                fwrite.write("{}\t{}\t{:.2f}\n".format(protein_id, goterm, prob))
        fwrite.write("END\n")