Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
#
# Copyright (C) 2018 Xiao Wang
# Email:xiaowang20140001@gmail.com
#
## Dependency: keras 2.0.2 (no higher). Install with: ~/anaconda3/bin/pip install --ignore-installed keras==2.0.2
import warnings
warnings.filterwarnings('ignore')
# Warnings are silenced because Keras prints update notices while running, and we do not want to waste time on them.
from data_processing.read_mnist import read_mnist
import parser
import time
import os
import numpy as np
#import gym
#import ray
from GAN.LargeGan import LargeGan
from Evaluator.Evaluator import Evaluator
from data_processing.face_process import get_image
#set_gpu()
#from GAN.Generator import *
#from GAN.Discriminator import *
from GAN.MNIST_G import arsgan_generator
from GAN.MNIST_D import arsgan_discriminator
from GAN.ImprovedWGAN import *
from ops.argparser import argparser
from Evaluator.Evaluator import *
import tensorflow as tf
from A2C.Learner import Learner
def set_gpu(gpu_id):
    """Pin this process to one GPU and install a grow-on-demand TF/Keras session.

    This matches the author's server setup (CUDA toolkit under
    /usr/local/cuda-8.0/). On other machines, comment this call out or adapt it.

    Args:
        gpu_id: GPU index written (via str()) to CUDA_VISIBLE_DEVICES.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # The previous `os.system("export ...")` calls only modified the environment
    # of a short-lived child shell, so they had no effect on this process.
    # Writing os.environ updates this process and anything it spawns.
    cuda_home = "/usr/local/cuda-8.0/"
    os.environ["CUDA_HOME"] = cuda_home
    cuda_lib = os.path.join(cuda_home, "lib64")
    previous = os.environ.get("LD_LIBRARY_PATH")
    os.environ["LD_LIBRARY_PATH"] = (cuda_lib + os.pathsep + previous) if previous else cuda_lib
    # NOTE(review): the dynamic loader reads LD_LIBRARY_PATH at process start, so
    # this cannot change libraries already loaded by `import tensorflow`; export it
    # in the launching shell if the CUDA libraries are not found.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True  # allocate GPU memory as needed
    session = tf.Session(config=config)
    # NOTE(review): `K` has no import in this file's visible imports; presumably
    # keras.backend arriving through one of the wildcard imports. Confirm.
    K.set_session(session)
def preparemodel(Evaluator, params):
    """Build one ImprovedWGAN per worker, each starting from the shared initial weights.

    For every worker, a fresh generator/discriminator pair is created and
    loaded from ``beginD.h5`` / ``beginG.h5`` (relative to the working
    directory; ``preparerealmodel`` writes these files). The pair is then
    wrapped in an ImprovedWGAN.

    Args:
        Evaluator: trained evaluator handed to every ImprovedWGAN as ``evaluator``.
            Inside this function the name shadows the imported Evaluator class.
        params: option mapping. Reads 'noise_dim', 'channel_size', 'image_height',
            'image_width', 'n_workers', 'batch_size', 'base_lr' and 'gan_logdir'.

    Returns:
        list of ImprovedWGAN objects, one per worker, in worker order.
    """
    noise_dim = params['noise_dim']
    channels = params['channel_size']
    rows = params['image_height']
    cols = params['image_width']

    print('Initializing Generator!')
    generators = [
        arsgan_generator(noise_dim, params['batch_size'], rows, cols, channels,
                         architecture_id='000', output_act='tanh')
        for _ in range(params['n_workers'])
    ]

    # NOTE(review): the discriminator shape is (width, height, channels) while the
    # generators get (height, width). This is identical for square images such as
    # MNIST; confirm for non-square data.
    disc_shape = (params['image_width'], params['image_height'], params['channel_size'])
    print('Initializing Discriminator!')
    discriminators = [
        arsgan_discriminator(disc_shape, 1, architecture_id='001', output_act='sigmoid')
        for _ in range(params['n_workers'])
    ]
    print('Finish Discriminator!')

    # All workers start from the same snapshot. Discriminator weights are
    # loaded before generator weights.
    for disc, gen in zip(discriminators, generators):
        disc.load_weights('beginD.h5')
        gen.load_weights('beginG.h5')

    print('Initializing WGAN!')
    return [
        ImprovedWGAN(gen, disc, 1, params,
                     min_grad=1e-3,
                     max_patient=100,
                     len_loss_his=100,
                     idx=0,  # NOTE(review): every worker gets idx=0. Confirm this is intended.
                     base_lr=params['base_lr'],
                     max_inception_score=10,
                     logdir=params['gan_logdir'],
                     evaluator=Evaluator,
                     is_largeGan=True,
                     is_face=False)
        for gen, disc in zip(generators, discriminators)
    ]
def preparerealmodel(params):
    """Create the reference generator/discriminator and snapshot their starting weights.

    When ``params['resume_id']`` is non-zero, both networks are first restored
    from ``<cwd>/<a2cmodel_path>/realG_step<resume_id>.h5`` and
    ``realD_step<resume_id>.h5``. In every case, the current weights are then
    written to ``beginG.h5`` / ``beginD.h5`` in the working directory, and
    ``preparemodel`` loads them into each worker.

    Args:
        params: option mapping. Reads 'noise_dim', 'channel_size', 'image_height',
            'image_width', 'batch_size', 'resume_id', and 'a2cmodel_path' when resuming.

    Returns:
        tuple ``(realgenerator, realdiscriminator)``.
    """
    noise_dim = params['noise_dim']
    channels = params['channel_size']
    rows = params['image_height']
    cols = params['image_width']

    print('Initializing Generator!')
    realgenerator = arsgan_generator(noise_dim, params['batch_size'], rows, cols, channels,
                                     architecture_id='000', output_act='tanh')

    resuming = params['resume_id'] != 0
    if resuming:
        # Both checkpoint paths are derived here, so the discriminator branch below
        # does not depend on a variable defined in a separate conditional.
        ckpt_dir = os.path.join(os.getcwd(), params['a2cmodel_path'])
        step_tag = str(params['resume_id'])
        gen_ckpt = os.path.join(ckpt_dir, 'realG_step' + step_tag + '.h5')
        disc_ckpt = os.path.join(ckpt_dir, 'realD_step' + step_tag + '.h5')
        realgenerator.load_weights(gen_ckpt)
    realgenerator.save_weights('beginG.h5')

    disc_shape = (params['image_width'], params['image_height'], params['channel_size'])
    realdiscriminator = arsgan_discriminator(disc_shape, 1, architecture_id='001',
                                             output_act='sigmoid')
    if resuming:
        realdiscriminator.load_weights(disc_ckpt)
    realdiscriminator.save_weights('beginD.h5')

    return realgenerator, realdiscriminator
def run_a2c(evaluatemodel, largegan, num_class, params, xt):
    """Ensure the A2C output/log directories exist, then train an A2C Learner.

    Args:
        evaluatemodel: trained evaluator, passed to the Learner as ``evaluator``.
        largegan: combined multi-worker GAN, passed to the Learner as ``cmodel``.
        num_class: forwarded to the Learner unchanged.
        params: option mapping. Reads 'dir_path', 'a2clog_path', 'ob_dim', 'ac_dim',
            'n_workers', 'rollout_length' and 'n_iter', and is also passed whole
            to the Learner.
        xt: training data, passed to the Learner as ``trainset``.
    """
    # exist_ok=True replaces the racy "if not os.path.exists(...)" check-then-create.
    # dir_path is not used again here; presumably the Learner reads it from params. Confirm.
    os.makedirs(params['dir_path'], exist_ok=True)
    logdir = params['a2clog_path']
    os.makedirs(logdir, exist_ok=True)

    learner = Learner(
        cmodel=largegan,
        evaluator=evaluatemodel,
        num_class=num_class,
        ob_dim=params['ob_dim'],
        ac_dim=params['ac_dim'],
        num_workers=params['n_workers'],
        logdir=logdir,
        rollout_length=params['rollout_length'],
        params=params,
        trainset=xt,
        aimset=None,
        is_face=False,
    )
    learner.train(params['n_iter'])
def prepare_evaluationmodel(discriminate_class, xt, yt, xv, yv, params):
    """Build a softmax Evaluator and train it on training plus validation data.

    Args:
        discriminate_class: number of output classes given to the Evaluator.
        xt, yt: training inputs and labels.
        xv, yv: validation inputs and labels, concatenated after xt / yt along axis 0.
        params: option mapping. Reads 'image_width', 'image_height', 'channel_size'
            and 'evaluator_path'.

    Returns:
        the value returned by ``Evaluator.train``; callers use it as the trained evaluator.
    """
    # NOTE(review): ordered (width, height, channels), while the generators elsewhere
    # receive (height, width). This is identical for square images such as MNIST;
    # confirm for other data.
    input_shape = (params['image_width'], params['image_height'], params['channel_size'])
    evaluator = Evaluator(input_shape, discriminate_class, architecture_id='001',
                          output_act='softmax')
    # evaluator.model is presumably a Keras Model. Its summary() prints the table
    # itself and returns None, so wrapping it in print() only added a stray "None".
    evaluator.model.summary()
    # The validation split is merged into training, so no held-out set remains
    # for the evaluator.
    trainset = np.concatenate([xt, xv], axis=0)
    aimset = np.concatenate([yt, yv], axis=0)
    print('combined trainset')
    print(trainset.shape)
    return evaluator.train(trainset, aimset, params['evaluator_path'])
if __name__ == '__main__':
    # Everything below is driven by the command-line options.
    params = argparser()
    # Select the GPU and install the TF/Keras session before any model is built.
    set_gpu(params['gpu_id'])

    # Per the original author, this loader replaces an earlier one that crashed
    # with a segmentation fault. Presumably returns
    # (train x, train y, validation x, validation y).
    xt, yt, xv, yv = read_mnist()
    # Class count from the label width (presumably one-hot labels; confirm).
    discriminate_class = yt.shape[-1]

    print('Initializing model!')
    trained_evaluator = prepare_evaluationmodel(discriminate_class, xt, yt, xv, yv, params)

    print('Initializing GAN model!')
    # preparerealmodel also writes beginG.h5 / beginD.h5, which preparemodel loads next.
    realgenerator, realdiscriminator = preparerealmodel(params)
    gans = preparemodel(trained_evaluator, params)

    print('Building Large GAN')
    combine = LargeGan(gans, params['n_workers'], params, trained_evaluator,
                       is_face=False, logdir=params['a2ctmp_ganlog'])
    combinemodel = combine.form_model()
    print('Finishing Large GAN')

    run_a2c(trained_evaluator, combine, num_class=1, params=params, xt=xt)