Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
ARS_GAN/A2C_MNIST.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
170 lines (156 sloc)
7.17 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Copyright (C) 2018 Xiao Wang | |
# Email:xiaowang20140001@gmail.com | |
# | |
## Dependency: keras 2.0.2 (no higher). Install with: ~/anaconda3/bin/pip install --ignore-installed keras==2.0.2
import warnings | |
warnings.filterwarnings('ignore') | |
# Suppress the update/deprecation notices Keras prints at runtime so they do not clutter the output.
from data_processing.read_mnist import read_mnist | |
import parser | |
import time | |
import os | |
import numpy as np | |
#import gym | |
#import ray | |
from GAN.LargeGan import LargeGan | |
from Evaluator.Evaluator import Evaluator | |
from data_processing.face_process import get_image | |
#set_gpu() | |
#from GAN.Generator import * | |
#from GAN.Discriminator import * | |
from GAN.MNIST_G import arsgan_generator | |
from GAN.MNIST_D import arsgan_discriminator | |
from GAN.ImprovedWGAN import * | |
from ops.argparser import argparser | |
from Evaluator.Evaluator import * | |
import tensorflow as tf | |
from A2C.Learner import Learner | |
def set_gpu(gpu_id):
    """Pin TensorFlow to a single GPU and install a grow-on-demand Keras session.

    This reflects the author's own server layout (CUDA 8.0 installed under
    /usr/local/cuda-8.0); other users can skip this call or replace it with
    their own settings.

    Args:
        gpu_id: Value written to CUDA_VISIBLE_DEVICES (converted with str()).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # The previous `os.system("export ...")` calls only modified the
    # environment of a short-lived child shell and had no effect on this
    # process. Setting os.environ updates this process and anything it spawns.
    # NOTE: the dynamic loader reads LD_LIBRARY_PATH at process start, so this
    # cannot affect libraries TensorFlow has already loaded; export it in the
    # launching shell if the CUDA libraries are not found.
    cuda_home = "/usr/local/cuda-8.0/"
    os.environ["CUDA_HOME"] = cuda_home
    cuda_lib = os.path.join(cuda_home, "lib64")
    previous = os.environ.get("LD_LIBRARY_PATH", "")
    # Avoid a trailing empty entry (which would mean "current directory").
    os.environ["LD_LIBRARY_PATH"] = (
        cuda_lib + os.pathsep + previous if previous else cuda_lib
    )

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True  # allocate GPU memory as needed
    session = tf.Session(config=config)
    K.set_session(session)
def preparemodel(Evaluator, params):
    """Build one ImprovedWGAN per worker, each starting from the shared snapshot.

    Construction happens in phases:
      1. all generators;
      2. all discriminators;
      3. loading 'beginG.h5' / 'beginD.h5' into every pair (these files are
         written by preparerealmodel(), which must run first);
      4. wrapping each pair in an ImprovedWGAN.

    Args:
        Evaluator: Trained evaluator handed to every ImprovedWGAN. (This name
            shadows the imported Evaluator class inside this function.)
        params: Configuration dict. Reads 'noise_dim', 'channel_size',
            'image_height', 'image_width', 'n_workers', 'batch_size',
            'base_lr' and 'gan_logdir'.

    Returns:
        A list with one ImprovedWGAN per worker, in worker order.
    """
    noise_dim = params['noise_dim']
    channels = params['channel_size']
    rows = params['image_height']
    cols = params['image_width']

    print('Initializing Generator!')
    generators = [
        arsgan_generator(noise_dim, params['batch_size'], rows, cols, channels,
                         architecture_id='000', output_act='tanh')
        for _ in range(params['n_workers'])
    ]

    # NOTE(review): built as (width, height, channels); Keras usually expects
    # (rows, cols, channels). Identical for square images — confirm otherwise.
    disc_shape = (cols, rows, channels)
    print('Initializing Discriminator!')
    discriminators = [
        arsgan_discriminator(disc_shape, 1, architecture_id='001',
                             output_act='sigmoid')
        for _ in range(params['n_workers'])
    ]
    print('Finish Discriminator!')

    # Every worker starts from the same weights saved by preparerealmodel().
    for gen, disc in zip(generators, discriminators):
        disc.load_weights('beginD.h5')
        gen.load_weights('beginG.h5')

    print('Initializing WGAN!')
    # NOTE(review): every worker receives idx=0; confirm idx=<worker index>
    # was not intended.
    return [
        ImprovedWGAN(gen,
                     disc,
                     1,
                     params,
                     min_grad=1e-3,
                     max_patient=100,
                     len_loss_his=100,
                     idx=0,
                     base_lr=params['base_lr'],
                     max_inception_score=10,
                     logdir=params['gan_logdir'],
                     evaluator=Evaluator,
                     is_largeGan=True,
                     is_face=False)
        for gen, disc in zip(generators, discriminators)
    ]
def preparerealmodel(params):
    """Create the reference generator/discriminator pair and snapshot its weights.

    When params['resume_id'] is non-zero, the weights are first restored from
    <cwd>/<params['a2cmodel_path']>/realG_step<resume_id>.h5 and
    realD_step<resume_id>.h5. In every case the resulting weights are then
    saved to 'beginG.h5' / 'beginD.h5' in the current directory, which
    preparemodel() loads into each worker.

    Args:
        params: Configuration dict. Reads 'noise_dim', 'channel_size',
            'image_height', 'image_width', 'batch_size', 'resume_id' and,
            when resuming, 'a2cmodel_path'.

    Returns:
        Tuple (realgenerator, realdiscriminator).
    """
    noise_dim = params['noise_dim']
    channels = params['channel_size']
    rows = params['image_height']
    cols = params['image_width']

    print('Initializing Generator!')
    realgenerator = arsgan_generator(noise_dim, params['batch_size'], rows, cols,
                                     channels, architecture_id='000',
                                     output_act='tanh')

    resuming = params['resume_id'] != 0
    if resuming:
        checkpoint_dir = os.path.join(os.getcwd(), params['a2cmodel_path'])
        step = str(params['resume_id'])
        g_checkpoint = os.path.join(checkpoint_dir, 'realG_step' + step + '.h5')
        d_checkpoint = os.path.join(checkpoint_dir, 'realD_step' + step + '.h5')
        realgenerator.load_weights(g_checkpoint)
    realgenerator.save_weights('beginG.h5')

    # NOTE(review): (width, height, channels) ordering — see preparemodel().
    disc_shape = (cols, rows, channels)
    realdiscriminator = arsgan_discriminator(disc_shape, 1, architecture_id='001',
                                             output_act='sigmoid')
    if resuming:
        realdiscriminator.load_weights(d_checkpoint)
    realdiscriminator.save_weights('beginD.h5')

    return realgenerator, realdiscriminator
def run_a2c(evaluatemodel, largegan, num_class, params, xt):
    """Create the output directories, build an A2C Learner and train it.

    Args:
        evaluatemodel: Trained evaluator passed to the Learner as `evaluator`.
        largegan: Combined multi-worker GAN passed to the Learner as `cmodel`.
        num_class: Forwarded unchanged to the Learner.
        params: Configuration dict. Reads 'dir_path', 'a2clog_path', 'ob_dim',
            'ac_dim', 'n_workers', 'rollout_length' and 'n_iter'.
        xt: Training images, forwarded as the Learner's `trainset`.

    Raises:
        FileExistsError: If 'dir_path' or 'a2clog_path' exists but is not a
            directory. The old exists() check silently skipped this case.
    """
    # exist_ok=True replaces the racy exists()-then-makedirs() pattern.
    # 'dir_path' is only created here; this function does not otherwise use it.
    os.makedirs(params['dir_path'], exist_ok=True)
    logdir = params['a2clog_path']
    os.makedirs(logdir, exist_ok=True)

    learner = Learner(
        cmodel=largegan,
        evaluator=evaluatemodel,
        num_class=num_class,
        ob_dim=params['ob_dim'],
        ac_dim=params['ac_dim'],
        num_workers=params['n_workers'],
        logdir=logdir,
        rollout_length=params['rollout_length'],
        params=params,
        trainset=xt,
        aimset=None,
        is_face=False,
    )
    learner.train(params['n_iter'])
def prepare_evaluationmodel(discriminate_class, xt, yt, xv, yv, params):
    """Build an Evaluator classifier and train it on training + validation data.

    Args:
        discriminate_class: Number of output classes for the Evaluator.
        xt, yt: Training inputs and labels.
        xv, yv: Validation inputs and labels; concatenated onto xt/yt along
            axis 0 before training.
        params: Configuration dict. Reads 'image_width', 'image_height',
            'channel_size' and 'evaluator_path'.

    Returns:
        Whatever Evaluator.train() returns for the combined data.
    """
    # NOTE(review): (width, height, channels) ordering; Keras usually expects
    # (rows, cols, channels). Identical for square MNIST — confirm otherwise.
    input_shape = (params['image_width'], params['image_height'],
                   params['channel_size'])
    evaluator = Evaluator(input_shape, discriminate_class, architecture_id='001',
                          output_act='softmax')
    # Keras Model.summary() prints the table itself and returns None;
    # wrapping it in print() emitted a stray "None" line.
    evaluator.model.summary()

    trainset = np.concatenate([xt, xv], axis=0)
    aimset = np.concatenate([yt, yv], axis=0)
    print('combined trainset')
    print(trainset.shape)

    return evaluator.train(trainset, aimset, params['evaluator_path'])
if __name__ == '__main__':
    # Entry point: train the evaluator, prepare the per-worker GANs, combine
    # them into a LargeGan, then run A2C training over the combined model.
    params = argparser()  # configuration parsed from the command line
    set_gpu(params['gpu_id'])

    # read_mnist() replaces an earlier loader that crashed with a segmentation
    # fault (core dump).
    xt, yt, xv, yv = read_mnist()
    # Presumably one-hot labels, so the last axis is the class count — confirm.
    discriminate_class = yt.shape[-1]

    print('Initializing model!')
    trained_evaluator = prepare_evaluationmodel(discriminate_class, xt, yt, xv, yv, params)

    print('Initializing GAN model!')
    # Must run before preparemodel(): it writes beginG.h5 / beginD.h5, which
    # preparemodel() loads into every worker.
    realgenerator, realdiscriminator = preparerealmodel(params)
    gans = preparemodel(trained_evaluator, params)

    print('Building Large GAN')
    combine = LargeGan(gans, params['n_workers'], params, trained_evaluator,
                       is_face=False, logdir=params['a2ctmp_ganlog'])
    combine.form_model()  # return value was previously bound but never used
    print('Finishing Large GAN')

    run_a2c(trained_evaluator, combine, 1, params, xt)