
    Predicting Character Sequences with an LSTM

    Posted by lufficc on 2017-10-31 03:26:00

    In the earlier post introducing Recurrent Neural Networks we covered what an RNN is. This post implements a very simple character-prediction model in TensorFlow and explains the code in detail, mostly so that I don't forget it later ( ╯□╰ ).
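
    The snippets below target TensorFlow 1.x and presumably share a few imports. A minimal set, assuming the cell classes come from tf.contrib.rnn (which matches the rnn.BasicLSTMCell / rnn.MultiRNNCell calls used later), would look like this:

    import random

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib import rnn  # BasicLSTMCell and MultiRNNCell in TF 1.x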

    First, define the RNN network:

    class RNN:
        def __init__(self,
                     in_size,
                     cell_size,
                     num_layers,
                     out_size,
                     sess,
                     lr=0.003):
            self.in_size = in_size  # input size, i.e. the number of distinct characters (uppercase letters are mapped to lowercase to shrink the vocabulary and speed up training)
            self.cell_size = cell_size  # number of units in each LSTM cell
            self.num_layers = num_layers  # number of stacked LSTM layers
            self.out_size = out_size  # output size, also the number of distinct characters
            self.sess = sess  # TensorFlow session
            self.lr = lr  # learning rate
    
            # stores the last state, used at prediction/test time
            self.last_state = np.zeros([num_layers * 2 * cell_size])
    
            # input data, shape (batch, time_step, in_size)
            self.inputs = tf.placeholder(tf.float32, shape=[None, None, in_size])
            self.lstm_cells = [
                rnn.BasicLSTMCell(cell_size, state_is_tuple=False)
                for _ in range(num_layers)
            ]
            self.lstm = rnn.MultiRNNCell(self.lstm_cells, state_is_tuple=False)
    
            self.init_state = tf.placeholder(tf.float32, [None, num_layers * 2 * cell_size])
    
            # build the recurrent neural network
            outputs, self.new_state = tf.nn.dynamic_rnn(
                self.lstm,
                self.inputs,
                initial_state=self.init_state,
                dtype=tf.float32)
    
            # finally, a fully connected layer maps the LSTM outputs to logits used for the loss; w and b are its parameters
            w = tf.Variable(tf.random_normal([cell_size, out_size], stddev=0.01))
            b = tf.Variable(tf.random_normal([out_size], stddev=0.01))
    
            reshaped_outputs = tf.matmul(tf.reshape(outputs, [-1, cell_size]), w) + b
            # turn the logits into probabilities
            fc = tf.nn.softmax(reshaped_outputs)
    
            shape = tf.shape(outputs)
            self.final_outputs = tf.reshape(fc, [shape[0], shape[1], out_size])
    
            # labels, shape (batch, time_step, out_size)
            self.targets = tf.placeholder(tf.float32, [None, None, out_size])
    
            # loss
            self.cost = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=reshaped_outputs,
                    labels=tf.reshape(self.targets, [-1, out_size])))
            self.optimizer = tf.train.RMSPropOptimizer(lr).minimize(self.cost)
    
        def train(self, inputs, targets):
            # init_state, shape (batch_size, num_layers * 2 * cell_size)
            init_state = np.zeros(
                [inputs.shape[0], self.num_layers * 2 * self.cell_size])
    
            _, loss = self.sess.run(
                [self.optimizer, self.cost],
                feed_dict={
                    self.inputs: inputs,
                    self.targets: targets,
                    self.init_state: init_state
                })
    
            return loss
    
        def get_next_char_pro(self, x, init=False):
            """根据输入字符x, 预测下一个字符并返回,x 的 shape 为(1, in_size), 即为单个字符的 one-hot 形式
            """
            if init:
                init_state = np.zeros([self.num_layers * 2 * self.cell_size])
            else:
                init_state = self.last_state
    
            out, next_state = self.sess.run(
                [self.final_outputs, self.new_state],
                feed_dict={self.inputs: [x], self.init_state: [init_state]})
            # save the current state so the next prediction can reuse it, preserving context
            self.last_state = next_state[0]
            return out[0][0]
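
    A note on the state shape: with state_is_tuple=False, each BasicLSTMCell packs its cell state c and hidden state h into a single vector of length 2 * cell_size, and MultiRNNCell concatenates the layers, which is why last_state and init_state are flat vectors of length num_layers * 2 * cell_size. A small sketch (the helper name is my own, for illustration only) of how such a flat state maps back to per-layer (c, h) pairs:

    import numpy as np

    def split_flat_state(flat_state, num_layers, cell_size):
        """Illustrative only: split a flat [num_layers * 2 * cell_size] state
        vector into one (c, h) pair per LSTM layer."""
        pairs = []
        for layer in range(num_layers):
            start = layer * 2 * cell_size
            c = flat_state[start:start + cell_size]  # cell state of this layer
            h = flat_state[start + cell_size:start + 2 * cell_size]  # hidden state of this layer
            pairs.append((c, h))
        return pairs

    # for the model below (num_layers=2, cell_size=128) the flat state has 512 entries
    print([(c.shape, h.shape) for c, h in split_flat_state(np.zeros([512]), 2, 128)])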
    

    Prepare the data:

    def generate_one_hot_data(data, vocabulary):
        data_ = np.zeros([len(data), len(vocabulary)])
    
        count = 0
        for char in data:
            i = vocabulary.index(char)
            data_[count, i] = 1.0
            count += 1
        return data_
    
    
    path = 'data/data.txt'
    
    with open(path, 'r') as f:
        data = f.read()
    
    data = data.lower()  # lowercase everything to reduce the number of distinct characters
    
    vocabulary = list(set(data))  # list of distinct characters
    
    one_hot_data = generate_one_hot_data(data, vocabulary)  # one_hot_data has shape (len(data), len(vocabulary))
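
    For example, with a toy vocabulary the encoding looks like this (illustrative values only, since the real vocabulary comes from set(data) and its order is not fixed):

    vocab = ['a', 'b', 'c']
    print(generate_one_hot_data('ab', vocab))
    # [[1. 0. 0.]
    #  [0. 1. 0.]]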
    

    Define the parameters and build the model:

    in_size = out_size = len(vocabulary)
    
    cell_size = 128
    num_layers = 2
    batch_size = 64
    time_steps = 128
    
    NUM_TRAIN_BATCHES = 5000
    
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    
    model = RNN(in_size=in_size,
                cell_size=cell_size,
                num_layers=num_layers,
                out_size=out_size,
                sess=sess)
    
    sess.run(tf.global_variables_initializer())
    
    inputs = np.zeros([batch_size, time_steps, in_size])
    targets = np.zeros([batch_size, time_steps, out_size])
    
    possible_batch_start_ids = range(len(data) - time_steps - 1)
    

    Generate the training data and train:

    for i in range(NUM_TRAIN_BATCHES):
        batch_start_ids = random.sample(possible_batch_start_ids, batch_size)
    
        for j in range(time_steps):
            inputs_ids = [k + j for k in batch_start_ids]
            targets_ids = [k + j + 1 for k in batch_start_ids]
    
            inputs[:, j, :] = one_hot_data[inputs_ids, :]
            targets[:, j, :] = one_hot_data[targets_ids, :]
    
        loss = model.train(inputs, targets)
    
        if i % 100 == 0:
            print('loss: {:.5f} of batch {}'.format(loss, i))
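
    The indices make each target sequence the corresponding input sequence shifted right by one character: for a window starting at position k, the inputs cover data[k : k + time_steps] and the targets cover data[k + 1 : k + time_steps + 1], which is why possible_batch_start_ids only allows windows whose shifted slice still fits inside the text. The same idea on a plain string:

    text = 'hello world'
    k, window = 0, 4
    x = text[k : k + window]          # 'hell'
    y = text[k + 1 : k + window + 1]  # 'ello'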
    
    

    Test:

    # generate a sentence starting with 'we '
    TEST_PREFIX = 'we '
    
    for i in range(len(TEST_PREFIX)):
        out = model.get_next_char_pro(
            generate_one_hot_data(TEST_PREFIX[i], vocabulary), i == 0)
    
    gen_str = TEST_PREFIX
    
    for i in range(200):
        index = np.random.choice(range(len(vocabulary)), p=out)
        pred = vocabulary[index]
        # print(index, len(out), len(vocabulary))
        gen_str += pred
    
        out = model.get_next_char_pro(generate_one_hot_data(pred, vocabulary))
    
    print(gen_str)
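
    The loop above samples the next character from the predicted distribution with np.random.choice, which keeps the generated text varied. A greedy alternative (my own variation, not part of the original post; it reuses the np import from above) would always pick the most likely character, which is deterministic but tends to be more repetitive:

    def greedy_next_char(out, vocabulary):
        # pick the single most likely character instead of sampling from the distribution
        return vocabulary[int(np.argmax(out))]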
    

    Training results:

    loss: 3.12114 of batch 100
    loss: 3.06464 of batch 200
    loss: 2.49030 of batch 300
    loss: 2.20478 of batch 400
    loss: 2.02495 of batch 500
    loss: 1.88005 of batch 600
    loss: 1.74067 of batch 700
    loss: 1.70199 of batch 800
    loss: 1.69426 of batch 900
    loss: 1.54315 of batch 1000
    loss: 1.55390 of batch 1100
    loss: 1.48157 of batch 1200
    loss: 1.49068 of batch 1300
    loss: 1.48002 of batch 1400
    loss: 1.44718 of batch 1500
    loss: 1.44463 of batch 1600
    loss: 1.44997 of batch 1700
    loss: 1.41875 of batch 1800
    loss: 1.39939 of batch 1900
    loss: 1.40289 of batch 2000
    loss: 1.36561 of batch 2100
    loss: 1.37862 of batch 2200
    loss: 1.39195 of batch 2300
    loss: 1.36577 of batch 2400
    loss: 1.33385 of batch 2500
    loss: 1.36705 of batch 2600
    loss: 1.28989 of batch 2700
    loss: 1.35604 of batch 2800
    loss: 1.30150 of batch 2900
    loss: 1.30576 of batch 3000
    loss: 1.30832 of batch 3100
    loss: 1.28968 of batch 3200
    loss: 1.28770 of batch 3300
    loss: 1.27296 of batch 3400
    loss: 1.31154 of batch 3500
    loss: 1.30212 of batch 3600
    loss: 1.29709 of batch 3700
    loss: 1.28448 of batch 3800
    loss: 1.28766 of batch 3900
    loss: 1.25970 of batch 4000
    loss: 1.27008 of batch 4100
    loss: 1.29763 of batch 4200
    loss: 1.25666 of batch 4300
    loss: 1.29813 of batch 4400
    loss: 1.26807 of batch 4500
    loss: 1.23903 of batch 4600
    loss: 1.21010 of batch 4700
    loss: 1.27084 of batch 4800
    loss: 1.27161 of batch 4900
    loss: 1.25789 of batch 5000
    loss: 1.22986 of batch 5100
    loss: 1.24404 of batch 5200
    loss: 1.27089 of batch 5300
    loss: 1.23036 of batch 5400
    loss: 1.25348 of batch 5500
    loss: 1.23626 of batch 5600
    loss: 1.21493 of batch 5700
    loss: 1.20419 of batch 5800
    loss: 1.23771 of batch 5900
    loss: 1.20754 of batch 6000
    loss: 1.23489 of batch 6100
    loss: 1.20233 of batch 6200
    loss: 1.20366 of batch 6300
    loss: 1.23586 of batch 6400
    loss: 1.21687 of batch 6500
    loss: 1.19479 of batch 6600
    loss: 1.21297 of batch 6700
    loss: 1.23598 of batch 6800
    loss: 1.19476 of batch 6900
    loss: 1.21584 of batch 7000
    loss: 1.22816 of batch 7100
    loss: 1.19449 of batch 7200
    loss: 1.19346 of batch 7300
    loss: 1.23466 of batch 7400
    loss: 1.18541 of batch 7500
    loss: 1.19469 of batch 7600
    loss: 1.21069 of batch 7700
    loss: 1.19641 of batch 7800
    loss: 1.15550 of batch 7900
    loss: 1.19861 of batch 8000
    loss: 1.22582 of batch 8100
    loss: 1.19766 of batch 8200
    loss: 1.19041 of batch 8300
    loss: 1.15410 of batch 8400
    loss: 1.13109 of batch 8500
    loss: 1.16434 of batch 8600
    loss: 1.19457 of batch 8700
    loss: 1.18558 of batch 8800
    loss: 1.18043 of batch 8900
    loss: 1.17171 of batch 9000
    loss: 1.17663 of batch 9100
    loss: 1.14107 of batch 9200
    loss: 1.20001 of batch 9300
    loss: 1.16926 of batch 9400
    loss: 1.14761 of batch 9500
    loss: 1.15305 of batch 9600
    loss: 1.20601 of batch 9700
    loss: 1.16141 of batch 9800
    loss: 1.15704 of batch 9900
    

    Some sample outputs:

    prefix:she
    sheated here do in the weary blood
    unless he single thou mayst break your ancients dren
    lads, then i in to thy fair age.
    so i say the nursed with mine eyes from my sleep.
    
    clarence:
    how is it pale and mo
    
    
    prefix:the
    the tears of mine.
    hark! they may providy to the extremement,
    make an office of day, like to be so run out.
    dost than the heaven and your pardon did forght
    the crown: all shall be too close hather change


    prefix:i
    in heavy part:
    in your comforts of our state exposed
    for in the kin a phalteres times disprison.
    
    isabella:
    dispatch and great to do a plance:
    and now they do repent us to be pitteth caser
    which in my 
    
    
    prefix:i
    it confined him to his necessities,
    but to an e this young belonged,'
    as it susminine of france of vices of folly,
    when though 'i trust up me not, our reward,
    sir, you shall be absent disposed.
    
    warwic
    
    
    prefix:me
    me;
    for will is angelo, take it,--and thou knows he did:
    one catesby!
    
    pronirerd:
    uppecured for me hencefully, you make one flight,
    making down thee thought in who thy by the world,
    but name, though hat
    

