# Embedding-PFP/fully_connected.py
import pickle
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import npy2npz as n2n  # local helper for loading .npz embedding archives
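
# Pin training to the second CUDA device of a multi-GPU host.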
device = torch.device('cuda:1')

# Precomputed SeqVec embeddings: CAFA3 targets for testing, 2017 GOA proteins
# for training, plus the GO-term vocabulary and propagated annotations.
test_embeddings = n2n.get_dataset(Path('data/seqvec_cafa3_targets.npz'), False)
embeddings = n2n.get_dataset(Path('data/seqvec_goa_2017.npz'), False)

with open("data/all_annotated_go_terms.pkl", "rb") as f:
    goterms = pickle.load(f)
with open("data/propagated_go_annotations.pkl", "rb") as f:
    go_annotations = pickle.load(f)
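
# Dataset pairing each protein's SeqVec embedding with a multi-hot vector of
# its GO annotations; `goterms` is read from the enclosing module scope.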
class PFPDataset(Dataset):
    def __init__(self, embeddings, annotations):
        embedded_prots = sorted(embeddings.keys())
        self.prot_embeddings = [torch.from_numpy(embeddings[prot]) for prot in embedded_prots]
        # Build one multi-hot label row per protein over the full GO vocabulary.
        go2idx = dict(zip(sorted(goterms), range(len(goterms))))
        self.prot_annotations = []
        for prot in embedded_prots:
            row = [0] * len(goterms)
            for go in annotations[prot]:
                row[go2idx[go]] = 1
            self.prot_annotations.append(row)

    def __len__(self):
        return len(self.prot_embeddings)

    def __getitem__(self, idx):
        return self.prot_embeddings[idx], self.prot_annotations[idx]
dataset = PFPDataset(embeddings, go_annotations)
print("Finished loading dataset")
def collate_fn(batch):
    inputs = torch.cat([data[0] for data in batch], dim=0)
    outputs = torch.tensor([data[1] for data in batch], dtype=torch.float)
    return inputs, outputs

data_iterator = DataLoader(dataset, batch_size=64, collate_fn=collate_fn, shuffle=True)
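
# A fully connected network: 1024-d embedding -> 256 ReLU hidden units ->
# one logit per GO term (dropout is present but commented out).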
class PFPmodel(nn.Module):
    def __init__(self, num_input, num_hidden, num_output, drop_prob=0.5):
        super(PFPmodel, self).__init__()
        self.num_output = num_output
        self.linear = nn.Linear(num_input, num_hidden)
        self.activation = nn.ReLU()
        # self.dropout = nn.Dropout(drop_prob)
        self.output = nn.Linear(num_hidden, num_output)

    def forward(self, protein, labels=None):
        # hidden = self.dropout(self.activation(self.linear(protein)))
        hidden = self.activation(self.linear(protein))
        logits = self.output(hidden)
        return torch.squeeze(logits)
model = PFPmodel(1024, 256, len(goterms))
model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())
writer = SummaryWriter("logs/")

running_loss = 0.
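
# Train for 100 epochs; BCEWithLogitsLoss treats each GO term as an
# independent binary label. The running loss is averaged and logged to
# TensorBoard every 100 mini-batches.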
for epoch in range(100):
    for i, (data, target) in enumerate(data_iterator):
        data = data.to(device)
        target = target.to(device)
        logit = model(data)
        loss = criterion(logit, target)
        running_loss += loss.item()
        if i % 100 == 99:
            writer.add_scalar('Training loss', running_loss / 100, epoch * len(data_iterator) + i)
            running_loss = 0.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
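
# Training done: move the model to the CPU for checkpointing and inference.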
model.to(torch.device("cpu"))
torch.save(model.state_dict(), "saves/fully_connected.save")
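
# Score every CAFA3 target protein; keep only GO terms whose sigmoid
# probability exceeds 0.01, rounded to two decimals for the CAFA
# submission format.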
model.eval()
sorted_goterms = sorted(goterms)
test_predictions = {}
with torch.no_grad():
    for name in test_embeddings:
        test_predictions[name] = []
        scores = torch.sigmoid(model(torch.from_numpy(test_embeddings[name])))
        for term, score in zip(sorted_goterms, scores.tolist()):
            if score <= 0.01:
                continue
            test_predictions[name].append([term, round(score, 2)])
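
# Collect the CAFA3 target accessions per taxon from the FASTA headers
# (files are named target.<taxon_id>.fasta).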
taxon_proteins = {}
target_dir = Path("/net/kihara/home/nguye330/CAFA_evaluation/data/CAFA3_targets/Target files")
for filepath in target_dir.glob("target.*.fasta"):
    taxon_id = filepath.name.split(".")[1]
    taxon_proteins[taxon_id] = []
    with open(filepath, "r") as fp:
        for line in fp:
            if line.startswith('>'):
                # FASTA header: ">accession description"; keep the accession.
                taxon_proteins[taxon_id].append(line.split()[0][1:])
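
# Emit one CAFA-format submission file per taxon: a header block, then one
# tab-separated (protein, GO term, score) line per prediction, then END.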
predict_dir = Path("linear_predictions/")
for taxon_id in taxon_proteins:
    with open(predict_dir.joinpath("LinearRegression_1_{}.txt".format(taxon_id)), "w") as fwrite:
        fwrite.write("AUTHOR\tMinh\nMODEL\t1\nKEYWORDS\tlinear regression\n")
        for protein_id in taxon_proteins[taxon_id]:
            if protein_id not in test_predictions:
                continue
            for goterm, prob in test_predictions[protein_id]:
                fwrite.write("{}\t{}\t{:.2f}\n".format(protein_id, goterm, prob))
        fwrite.write("END\n")