init
This commit is contained in:
parent
94bd5c0bf7
commit
24cb406af9
6
.gitattributes
vendored
Normal file
6
.gitattributes
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
# convert to OS line endings on checkout, back to LF on commit
|
||||
* text=auto
|
||||
|
||||
# ensure anything copied to the container has unix style line endings
|
||||
*.sh text eol=lf
|
||||
requirements.txt text eol=lf
|
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
__pycache__
|
||||
.mypy_cache/
|
||||
models/
|
17
CONTRIBUTORS.md
Normal file
17
CONTRIBUTORS.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Contributors (alphabetically)
|
||||
|
||||
* **[madisonmay](https://github.com/madisonmay)**
|
||||
|
||||
Added Dockerfiles
|
||||
|
||||
* **[Margaret Mitchell et al](https://arxiv.org/abs/1810.03993)**
|
||||
|
||||
Our [usage](./README.md#usage) writeup was loosely inspired by the paper
|
||||
[Model Cards for Model Reporting](https://arxiv.org/abs/1810.03993)
|
||||
and related conversations with some of the authors.
|
||||
|
||||
* **[webproduktion01](https://github.com/webproduktion01)**
|
||||
|
||||
Ported download script to python.
|
||||
|
||||
**[Full code contributors list](https://github.com/openai/gpt-2/contributors).**
|
88
DEVELOPERS.md
Normal file
88
DEVELOPERS.md
Normal file
@ -0,0 +1,88 @@
|
||||
# Installation
|
||||
|
||||
Git clone this repository, and `cd` into directory for remaining commands
|
||||
```
|
||||
git clone https://github.com/openai/gpt-2.git && cd gpt-2
|
||||
```
|
||||
|
||||
Then, follow instructions for either native or Docker installation.
|
||||
|
||||
## Native Installation
|
||||
|
||||
All steps can optionally be done in a virtual environment using tools such as `virtualenv` or `conda`.
|
||||
|
||||
Install tensorflow 1.12 (with GPU support, if you have a GPU and want everything to run faster)
|
||||
```
|
||||
pip3 install tensorflow==1.12.0
|
||||
```
|
||||
or
|
||||
```
|
||||
pip3 install tensorflow-gpu==1.12.0
|
||||
```
|
||||
|
||||
Install other python packages:
|
||||
```
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
Download the model data
|
||||
```
|
||||
python3 download_model.py 124M
|
||||
python3 download_model.py 355M
|
||||
python3 download_model.py 774M
|
||||
python3 download_model.py 1558M
|
||||
```
|
||||
|
||||
## Docker Installation
|
||||
|
||||
Build the Dockerfile and tag the created image as `gpt-2`:
|
||||
```
|
||||
docker build --tag gpt-2 -f Dockerfile.gpu . # or Dockerfile.cpu
|
||||
```
|
||||
|
||||
Start an interactive bash session from the `gpt-2` docker image.
|
||||
|
||||
You can opt to use the `--runtime=nvidia` flag if you have access to a NVIDIA GPU
|
||||
and a valid install of [nvidia-docker 2.0](https://github.com/nvidia/nvidia-docker/wiki/Installation-(version-2.0)).
|
||||
```
|
||||
docker run --runtime=nvidia -it gpt-2 bash
|
||||
```
|
||||
|
||||
# Running
|
||||
|
||||
| WARNING: Samples are unfiltered and may contain offensive content. |
|
||||
| --- |
|
||||
|
||||
Some of the examples below may include Unicode text characters. Set the environment variable:
|
||||
```
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
```
|
||||
to override the standard stream settings in UTF-8 mode.
|
||||
|
||||
## Unconditional sample generation
|
||||
|
||||
To generate unconditional samples from the small model:
|
||||
```
|
||||
python3 src/generate_unconditional_samples.py | tee /tmp/samples
|
||||
```
|
||||
There are various flags for controlling the samples:
|
||||
```
|
||||
python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.7 | tee /tmp/samples
|
||||
```
|
||||
|
||||
To check flag descriptions, use:
|
||||
```
|
||||
python3 src/generate_unconditional_samples.py -- --help
|
||||
```
|
||||
|
||||
## Conditional sample generation
|
||||
|
||||
To give the model custom prompts, you can use:
|
||||
```
|
||||
python3 src/interactive_conditional_samples.py --top_k 40
|
||||
```
|
||||
|
||||
To check flag descriptions, use:
|
||||
```
|
||||
python3 src/interactive_conditional_samples.py -- --help
|
||||
```
|
11
Dockerfile.cpu
Normal file
11
Dockerfile.cpu
Normal file
@ -0,0 +1,11 @@
|
||||
FROM tensorflow/tensorflow:1.12.0-py3
|
||||
|
||||
ENV LANG=C.UTF-8
|
||||
RUN mkdir /gpt-2
|
||||
WORKDIR /gpt-2
|
||||
ADD . /gpt-2
|
||||
RUN pip3 install -r requirements.txt
|
||||
RUN python3 download_model.py 124M
|
||||
RUN python3 download_model.py 355M
|
||||
RUN python3 download_model.py 774M
|
||||
RUN python3 download_model.py 1558M
|
20
Dockerfile.gpu
Normal file
20
Dockerfile.gpu
Normal file
@ -0,0 +1,20 @@
|
||||
FROM tensorflow/tensorflow:1.12.0-gpu-py3
|
||||
|
||||
# nvidia-docker 1.0
|
||||
LABEL com.nvidia.volumes.needed="nvidia_driver"
|
||||
LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
|
||||
|
||||
# nvidia-container-runtime
|
||||
ENV NVIDIA_VISIBLE_DEVICES=all \
|
||||
NVIDIA_DRIVER_CAPABILITIES=compute,utility \
|
||||
NVIDIA_REQUIRE_CUDA="cuda>=8.0" \
|
||||
LANG=C.UTF-8
|
||||
|
||||
RUN mkdir /gpt-2
|
||||
WORKDIR /gpt-2
|
||||
ADD . /gpt-2
|
||||
RUN pip3 install -r requirements.txt
|
||||
RUN python3 download_model.py 124M
|
||||
RUN python3 download_model.py 355M
|
||||
RUN python3 download_model.py 774M
|
||||
RUN python3 download_model.py 1558M
|
24
LICENSE
Normal file
24
LICENSE
Normal file
@ -0,0 +1,24 @@
|
||||
Modified MIT License
|
||||
|
||||
Software Copyright (c) 2019 OpenAI
|
||||
|
||||
We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
|
||||
We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included
|
||||
in all copies or substantial portions of the Software.
|
||||
The above copyright notice and this permission notice need not be included
|
||||
with content created by the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
||||
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
||||
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
||||
OR OTHER DEALINGS IN THE SOFTWARE.
|
24
README.md
24
README.md
@ -0,0 +1,24 @@
|
||||
## Overview
|
||||
|
||||
A very quick and dirty implementation on Open AI's GPT-2 in a Discord Bot. This was hacked together in a day and is by no means the best or most efficient way to do it. I may come back to this project to make it prettier another time.
|
||||
|
||||
## Requirements
|
||||
|
||||
* General understanding of how to create and manage Discord bots
|
||||
* A trained GPT model
|
||||
* Basic understanding of Python
|
||||
* All the pre-requisites for using GPT 2 which this was forked from
|
||||
* A CUDA-Capable GPU (Recommended for better performance)
|
||||
|
||||
## Usage
|
||||
|
||||
You'll have to modify ``bot.py`` with your Discord bot token, as well as point the commands to the correct model. ``bot.py`` is a very basic Discord bot with Open AI's generation scripts pasted into it and modified to return a string rather than print to console. My model was trained off a group chat with friends, as such I've written the bot and its commands to reflect this. The data format which I trained my model on follows the regular expression ``^[A-z]{3} [0-9]{2}:[0-9]{2} [A-Z]{2} - .+: "(.*)"``, which looks like ``Sep 05:57 PM - Username: "Message content"``. Unless your model outputs text in this exact format, you will have to modify this bot to accommodate your needs
|
||||
|
||||
There are three commands to interact with this bot.
|
||||
* ``!g <prompt>``, which generates based off a prompt, or a random one if none is provided.
|
||||
* ``!r <prompt>``, which replies to a prompt. Example ``!r Hi, how are you?`` may respond ``I'm good!``. This command assumes the model to return a string following the regexp mentioned above, and cuts out the irrelevant information to mimic a response from the bot.
|
||||
* ``!c <prompt>``, which continues a prompt. Example ``!c My name is`` to which the bot may continue the prompt with ``My name is Jojgo``
|
||||
|
||||
## License
|
||||
|
||||
[Modified MIT](./LICENSE)
|
238
bot.py
Normal file
238
bot.py
Normal file
@ -0,0 +1,238 @@
|
||||
import os
|
||||
import discord
|
||||
from dotenv import load_dotenv
|
||||
import random
|
||||
import re
|
||||
|
||||
# GENERATE
|
||||
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import fire
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
import model, sample, encoder
|
||||
|
||||
def sample_model(
|
||||
model_name='h5',
|
||||
seed=None,
|
||||
nsamples=1,
|
||||
batch_size=1,
|
||||
length=250,
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
top_p=1,
|
||||
models_dir='models',
|
||||
):
|
||||
"""
|
||||
Run the sample_model
|
||||
:model_name=124M : String, which model to use
|
||||
:seed=None : Integer seed for random number generators, fix seed to
|
||||
reproduce results
|
||||
:nsamples=0 : Number of samples to return, if 0, continues to
|
||||
generate samples indefinately.
|
||||
:batch_size=1 : Number of batches (only affects speed/memory).
|
||||
:length=None : Number of tokens in generated text, if None (default), is
|
||||
determined by model hyperparameters
|
||||
:temperature=1 : Float value controlling randomness in boltzmann
|
||||
distribution. Lower temperature results in less random completions. As the
|
||||
temperature approaches zero, the model will become deterministic and
|
||||
repetitive. Higher temperature results in more random completions.
|
||||
:top_k=0 : Integer value controlling diversity. 1 means only 1 word is
|
||||
considered for each step (token), resulting in deterministic completions,
|
||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||
special setting meaning no restrictions. 40 generally is a good value.
|
||||
:models_dir : path to parent folder containing model subfolders
|
||||
(i.e. contains the <model_name> folder)
|
||||
"""
|
||||
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||
enc = encoder.get_encoder(model_name, models_dir)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||
hparams.override_from_dict(json.load(f))
|
||||
|
||||
if length is None:
|
||||
length = hparams.n_ctx
|
||||
elif length > hparams.n_ctx:
|
||||
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
output = sample.sample_sequence(
|
||||
hparams=hparams, length=length,
|
||||
start_token=enc.encoder['<|endoftext|>'],
|
||||
batch_size=batch_size,
|
||||
temperature=temperature, top_k=top_k, top_p=top_p
|
||||
)[:, 1:]
|
||||
|
||||
saver = tf.train.Saver()
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||
saver.restore(sess, ckpt)
|
||||
|
||||
generated = 0
|
||||
while nsamples == 0 or generated < nsamples:
|
||||
out = sess.run(output)
|
||||
for i in range(batch_size):
|
||||
generated += batch_size
|
||||
text = '```'
|
||||
text += enc.decode(out[i])
|
||||
text += '```'
|
||||
#print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(text)
|
||||
return text
|
||||
|
||||
|
||||
# RESPOND
|
||||
|
||||
def interact_model(
|
||||
model_name='h5',
|
||||
seed=None,
|
||||
nsamples=1,
|
||||
batch_size=1,
|
||||
length=100,
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
top_p=1,
|
||||
models_dir='models',
|
||||
raw_text='test',
|
||||
inline = True
|
||||
):
|
||||
"""
|
||||
Interactively run the model
|
||||
:model_name=124M : String, which model to use
|
||||
:seed=None : Integer seed for random number generators, fix seed to reproduce
|
||||
results
|
||||
:nsamples=1 : Number of samples to return total
|
||||
:batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples.
|
||||
:length=None : Number of tokens in generated text, if None (default), is
|
||||
determined by model hyperparameters
|
||||
:temperature=1 : Float value controlling randomness in boltzmann
|
||||
distribution. Lower temperature results in less random completions. As the
|
||||
temperature approaches zero, the model will become deterministic and
|
||||
repetitive. Higher temperature results in more random completions.
|
||||
:top_k=0 : Integer value controlling diversity. 1 means only 1 word is
|
||||
considered for each step (token), resulting in deterministic completions,
|
||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||
special setting meaning no restrictions. 40 generally is a good value.
|
||||
:models_dir : path to parent folder containing model subfolders
|
||||
(i.e. contains the <model_name> folder)
|
||||
"""
|
||||
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||
if batch_size is None:
|
||||
batch_size = 1
|
||||
assert nsamples % batch_size == 0
|
||||
|
||||
enc = encoder.get_encoder(model_name, models_dir)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||
hparams.override_from_dict(json.load(f))
|
||||
|
||||
if length is None:
|
||||
length = hparams.n_ctx // 2
|
||||
elif length > hparams.n_ctx:
|
||||
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
context = tf.placeholder(tf.int32, [batch_size, None])
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
output = sample.sample_sequence(
|
||||
hparams=hparams, length=length,
|
||||
context=context,
|
||||
batch_size=batch_size,
|
||||
temperature=temperature, top_k=top_k, top_p=top_p
|
||||
)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||
saver.restore(sess, ckpt)
|
||||
|
||||
# Generation Code
|
||||
|
||||
context_tokens = enc.encode(raw_text)
|
||||
generated = 0
|
||||
for _ in range(nsamples // batch_size):
|
||||
out = sess.run(output, feed_dict={
|
||||
context: [context_tokens for _ in range(batch_size)]
|
||||
})[:, len(context_tokens):]
|
||||
for i in range(batch_size):
|
||||
generated += 1
|
||||
text = ''
|
||||
if inline:
|
||||
text = '```'
|
||||
text += enc.decode(out[i])
|
||||
if inline:
|
||||
text += '```'
|
||||
return text
|
||||
print("=" * 80)
|
||||
|
||||
# DISCORD BOT
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
client = discord.Client(intents=intents)
|
||||
TOKEN = ('YOUR TOKEN HERE')
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
print("Logged in as a bot {0.user}".format(client))
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
|
||||
if message.author == client.user:
|
||||
return
|
||||
|
||||
if message.content == '!g':
|
||||
response = sample_model()
|
||||
await message.channel.send(response)
|
||||
elif "!g" in message.content[:2]:
|
||||
text = message.content[3:]
|
||||
content = interact_model(raw_text = text, inline = True, length = 250)
|
||||
index = content.find('\n')
|
||||
await message.channel.send(content[index:])
|
||||
elif "!r" in message.content[:2]:
|
||||
text = message.content[3:]
|
||||
|
||||
content = interact_model(raw_text = text, inline = False, length = 50)
|
||||
|
||||
regexp = re.compile('^[A-z]{3} [0-9]{2}:[0-9]{2} [A-Z]{2} - .+: "(.*)"')
|
||||
matched = regexp.match(content)
|
||||
|
||||
if matched:
|
||||
toReturn = matched.group(1)
|
||||
await message.channel.send(toReturn)
|
||||
else:
|
||||
index = (content.find('\n')) + 1
|
||||
temp = content[index:]
|
||||
|
||||
temp_match = regexp.match(temp)
|
||||
if temp_match:
|
||||
toReturn = temp_match.group(1)
|
||||
await message.channel.send(toReturn)
|
||||
else:
|
||||
await message.channel.send(temp)
|
||||
|
||||
elif "!c" in message.content[:2]:
|
||||
# remove anything after \n
|
||||
# add text to response
|
||||
text = message.content[3:]
|
||||
content = interact_model(raw_text = text, length = 20, inline = False)
|
||||
|
||||
sep = '\n'
|
||||
stripped = content.split(sep, 1)[0]
|
||||
|
||||
toReturn = text + "" + stripped
|
||||
|
||||
await message.channel.send(toReturn[:(len(toReturn) - 1)])
|
||||
elif message.content =='!h':
|
||||
response = '```!g <prompt> - Generates Conversation. If no prompt provided a random one will be used.\n!r <prompt> - Responds to prompt\n!c <prompt> - Continues prompt```'
|
||||
await message.channel.send(response)
|
||||
|
||||
|
||||
client.run(TOKEN)
|
1000
domains.txt
Normal file
1000
domains.txt
Normal file
File diff suppressed because it is too large
Load Diff
28
download_model.py
Normal file
28
download_model.py
Normal file
@ -0,0 +1,28 @@
|
||||
import os
|
||||
import sys
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
if len(sys.argv) != 2:
|
||||
print('You must enter the model name as a parameter, e.g.: download_model.py 124M')
|
||||
sys.exit(1)
|
||||
|
||||
model = sys.argv[1]
|
||||
|
||||
subdir = os.path.join('models', model)
|
||||
if not os.path.exists(subdir):
|
||||
os.makedirs(subdir)
|
||||
subdir = subdir.replace('\\','/') # needed for Windows
|
||||
|
||||
for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:
|
||||
|
||||
r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True)
|
||||
|
||||
with open(os.path.join(subdir, filename), 'wb') as f:
|
||||
file_size = int(r.headers["content-length"])
|
||||
chunk_size = 1000
|
||||
with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
|
||||
# 1k for chunk_size, since Ethernet packet size is around 1500 bytes
|
||||
for chunk in r.iter_content(chunk_size=chunk_size):
|
||||
f.write(chunk)
|
||||
pbar.update(chunk_size)
|
69
model_card.md
Normal file
69
model_card.md
Normal file
@ -0,0 +1,69 @@
|
||||
# GPT-2 model card
|
||||
|
||||
Last updated: November 2019
|
||||
|
||||
Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we’re providing some accompanying information about the GPT-2 family of models we're releasing.
|
||||
|
||||
## Model Details.
|
||||
|
||||
This model was developed by researchers at OpenAI to help us understand how the capabilities of language model capabilities scale as a function of the size of the models (by parameter count) combined with very large internet-scale datasets (WebText).
|
||||
|
||||
### Model date
|
||||
|
||||
February 2019, trained on data that cuts off at the end of 2017.
|
||||
|
||||
### Model type
|
||||
|
||||
Language model
|
||||
|
||||
### Model version
|
||||
|
||||
1.5 billion parameters: the fourth and largest GPT-2 version. We have also released 124 million, 355 million, and 774 million parameter models.
|
||||
|
||||
### Paper or other resource for more information
|
||||
[Blog post](https://openai.com/blog/better-language-models/) and [paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
|
||||
|
||||
### Where to send questions or comments about the model
|
||||
Please use this [Google Form](https://forms.gle/A7WBSbTY2EkKdroPA)
|
||||
|
||||
## Intended Uses:
|
||||
|
||||
### Primary intended uses
|
||||
|
||||
The primary intended users of these models are *AI researchers and practitioners*.
|
||||
|
||||
We primarily imagine these language models will be used by researchers to better understand the behaviors, capabilities, biases, and constraints of large-scale generative language models.
|
||||
|
||||
### Secondary uses
|
||||
|
||||
Here are some secondary use cases we believe are likely:
|
||||
|
||||
- **Writing assistance**: Grammar assistance, autocompletion (for normal prose or code)
|
||||
- **Creative writing and art**: exploring the generation of creative, fictional texts; aiding creation of poetry and other literary art.
|
||||
- **Entertainment**: Creation of games, chat bots, and amusing generations.
|
||||
|
||||
### Out-of-scope use cases
|
||||
|
||||
Because large-scale language models like GPT-2 do not distinguish fact from fiction, we don’t support use-cases that require the generated text to be true.
|
||||
|
||||
Additionally, language models like GPT-2 reflect the biases inherent to the systems they were trained on, so we do not recommend that they be deployed into systems that interact with humans unless the deployers first carry out a study of biases relevant to the intended use-case. We found no statistically significant difference in gender, race, and religious bias probes between 774M and 1.5B, implying all versions of GPT-2 should be approached with similar levels of caution around use cases that are sensitive to biases around human attributes.
|
||||
|
||||
## Evaluation Data
|
||||
|
||||
### Datasets
|
||||
|
||||
This model was trained on (and evaluated against) WebText, a dataset consisting of the text contents of 45 million links posted by users of the ‘Reddit’ social network. WebText is made of data derived from outbound links from Reddit and does not consist of data taken directly from Reddit itself. Before generating the dataset we used a blocklist to ensure we didn’t sample from a variety of subreddits which contain sexually explicit or otherwise offensive content.
|
||||
|
||||
To get a sense of the data that went into GPT-2, we’ve [published a list](domains.txt) of the top 1,000 domains present in WebText and their frequency. The top 15 domains by volume in WebText are: Google, Archive, Blogspot, GitHub, NYTimes, Wordpress, Washington Post, Wikia, BBC, The Guardian, eBay, Pastebin, CNN, Yahoo!, and the Huffington Post.
|
||||
|
||||
### Motivation
|
||||
|
||||
The motivation behind WebText was to create an Internet-scale, heterogeneous dataset that we could use to test large-scale language models against. WebText was (and is) intended to be primarily for research purposes rather than production purposes.
|
||||
|
||||
### Caveats and Recommendations
|
||||
|
||||
Because GPT-2 is an internet-scale language model, it’s currently difficult to know what disciplined testing procedures can be applied to it to fully understand its capabilities and how the data it is trained on influences its vast range of outputs. We recommend researchers investigate these aspects of the model and share their results.
|
||||
|
||||
Additionally, as indicated in our discussion of issues relating to potential misuse of the model, it remains unclear what the long-term dynamics are of detecting outputs from these models. We conducted [in-house automated ML-based detection research](https://github.com/openai/gpt-2-output-dataset/tree/master/detector) using simple classifiers, zero shot, and fine-tuning methods. Our fine-tuned detector model reached accuracy levels of approximately 95%. However, no one detection method is a panacea; automated ML-based detection, human detection, human-machine teaming, and metadata-based detection are all methods that can be combined for more confident classification. Developing better approaches to detection today will give us greater intuitions when thinking about future models and could help us understand ahead of time if detection methods will eventually become ineffective.
|
||||
|
||||
|
4
requirements.txt
Normal file
4
requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
fire>=0.1.3
|
||||
regex==2017.4.5
|
||||
requests==2.21.0
|
||||
tqdm==4.31.1
|
117
src/encoder.py
Normal file
117
src/encoder.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""Byte pair encoding utilities"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import regex as re
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, encoder, bpe_merges, errors='replace'):
|
||||
self.encoder = encoder
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
|
||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
def get_encoder(model_name, models_dir):
|
||||
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
|
||||
encoder = json.load(f)
|
||||
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
||||
bpe_data = f.read()
|
||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
||||
return Encoder(
|
||||
encoder=encoder,
|
||||
bpe_merges=bpe_merges,
|
||||
)
|
80
src/generate_unconditional_samples.py
Executable file
80
src/generate_unconditional_samples.py
Executable file
@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import fire
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import model, sample, encoder
|
||||
|
||||
def sample_model(
|
||||
model_name='124M',
|
||||
seed=None,
|
||||
nsamples=0,
|
||||
batch_size=1,
|
||||
length=None,
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
top_p=1,
|
||||
models_dir='models',
|
||||
):
|
||||
"""
|
||||
Run the sample_model
|
||||
:model_name=124M : String, which model to use
|
||||
:seed=None : Integer seed for random number generators, fix seed to
|
||||
reproduce results
|
||||
:nsamples=0 : Number of samples to return, if 0, continues to
|
||||
generate samples indefinately.
|
||||
:batch_size=1 : Number of batches (only affects speed/memory).
|
||||
:length=None : Number of tokens in generated text, if None (default), is
|
||||
determined by model hyperparameters
|
||||
:temperature=1 : Float value controlling randomness in boltzmann
|
||||
distribution. Lower temperature results in less random completions. As the
|
||||
temperature approaches zero, the model will become deterministic and
|
||||
repetitive. Higher temperature results in more random completions.
|
||||
:top_k=0 : Integer value controlling diversity. 1 means only 1 word is
|
||||
considered for each step (token), resulting in deterministic completions,
|
||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||
special setting meaning no restrictions. 40 generally is a good value.
|
||||
:models_dir : path to parent folder containing model subfolders
|
||||
(i.e. contains the <model_name> folder)
|
||||
"""
|
||||
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||
enc = encoder.get_encoder(model_name, models_dir)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||
hparams.override_from_dict(json.load(f))
|
||||
|
||||
if length is None:
|
||||
length = hparams.n_ctx
|
||||
elif length > hparams.n_ctx:
|
||||
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
output = sample.sample_sequence(
|
||||
hparams=hparams, length=length,
|
||||
start_token=enc.encoder['<|endoftext|>'],
|
||||
batch_size=batch_size,
|
||||
temperature=temperature, top_k=top_k, top_p=top_p
|
||||
)[:, 1:]
|
||||
|
||||
saver = tf.train.Saver()
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||
saver.restore(sess, ckpt)
|
||||
|
||||
generated = 0
|
||||
while nsamples == 0 or generated < nsamples:
|
||||
out = sess.run(output)
|
||||
for i in range(batch_size):
|
||||
generated += batch_size
|
||||
text = enc.decode(out[i])
|
||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(text)
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(sample_model)
|
||||
|
92
src/interactive_conditional_samples.py
Executable file
92
src/interactive_conditional_samples.py
Executable file
@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import fire
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import model, sample, encoder
|
||||
|
||||
def interact_model(
|
||||
model_name='124M',
|
||||
seed=None,
|
||||
nsamples=1,
|
||||
batch_size=1,
|
||||
length=None,
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
top_p=1,
|
||||
models_dir='models',
|
||||
):
|
||||
"""
|
||||
Interactively run the model
|
||||
:model_name=124M : String, which model to use
|
||||
:seed=None : Integer seed for random number generators, fix seed to reproduce
|
||||
results
|
||||
:nsamples=1 : Number of samples to return total
|
||||
:batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples.
|
||||
:length=None : Number of tokens in generated text, if None (default), is
|
||||
determined by model hyperparameters
|
||||
:temperature=1 : Float value controlling randomness in boltzmann
|
||||
distribution. Lower temperature results in less random completions. As the
|
||||
temperature approaches zero, the model will become deterministic and
|
||||
repetitive. Higher temperature results in more random completions.
|
||||
:top_k=0 : Integer value controlling diversity. 1 means only 1 word is
|
||||
considered for each step (token), resulting in deterministic completions,
|
||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||
special setting meaning no restrictions. 40 generally is a good value.
|
||||
:models_dir : path to parent folder containing model subfolders
|
||||
(i.e. contains the <model_name> folder)
|
||||
"""
|
||||
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||
if batch_size is None:
|
||||
batch_size = 1
|
||||
assert nsamples % batch_size == 0
|
||||
|
||||
enc = encoder.get_encoder(model_name, models_dir)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||
hparams.override_from_dict(json.load(f))
|
||||
|
||||
if length is None:
|
||||
length = hparams.n_ctx // 2
|
||||
elif length > hparams.n_ctx:
|
||||
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
context = tf.placeholder(tf.int32, [batch_size, None])
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
output = sample.sample_sequence(
|
||||
hparams=hparams, length=length,
|
||||
context=context,
|
||||
batch_size=batch_size,
|
||||
temperature=temperature, top_k=top_k, top_p=top_p
|
||||
)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||
saver.restore(sess, ckpt)
|
||||
|
||||
while True:
|
||||
raw_text = input("Model prompt >>> ")
|
||||
while not raw_text:
|
||||
print('Prompt should not be empty!')
|
||||
raw_text = input("Model prompt >>> ")
|
||||
context_tokens = enc.encode(raw_text)
|
||||
generated = 0
|
||||
for _ in range(nsamples // batch_size):
|
||||
out = sess.run(output, feed_dict={
|
||||
context: [context_tokens for _ in range(batch_size)]
|
||||
})[:, len(context_tokens):]
|
||||
for i in range(batch_size):
|
||||
generated += 1
|
||||
text = enc.decode(out[i])
|
||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(text)
|
||||
print("=" * 80)
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(interact_model)
|
||||
|
174
src/model.py
Normal file
174
src/model.py
Normal file
@ -0,0 +1,174 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.training import HParams
|
||||
|
||||
def default_hparams():
|
||||
return HParams(
|
||||
n_vocab=0,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_head=12,
|
||||
n_layer=12,
|
||||
)
|
||||
|
||||
def shape_list(x):
|
||||
"""Deal with dynamic shape in tensorflow cleanly."""
|
||||
static = x.shape.as_list()
|
||||
dynamic = tf.shape(x)
|
||||
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
|
||||
|
||||
def softmax(x, axis=-1):
|
||||
x = x - tf.reduce_max(x, axis=axis, keepdims=True)
|
||||
ex = tf.exp(x)
|
||||
return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
|
||||
|
||||
def gelu(x):
|
||||
return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
|
||||
|
||||
def norm(x, scope, *, axis=-1, epsilon=1e-5):
|
||||
"""Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
|
||||
with tf.variable_scope(scope):
|
||||
n_state = x.shape[-1].value
|
||||
g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
|
||||
b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
|
||||
u = tf.reduce_mean(x, axis=axis, keepdims=True)
|
||||
s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True)
|
||||
x = (x - u) * tf.rsqrt(s + epsilon)
|
||||
x = x*g + b
|
||||
return x
|
||||
|
||||
def split_states(x, n):
|
||||
"""Reshape the last dimension of x into [n, x.shape[-1]/n]."""
|
||||
*start, m = shape_list(x)
|
||||
return tf.reshape(x, start + [n, m//n])
|
||||
|
||||
def merge_states(x):
|
||||
"""Smash the last two dimensions of x into a single dimension."""
|
||||
*start, a, b = shape_list(x)
|
||||
return tf.reshape(x, start + [a*b])
|
||||
|
||||
def conv1d(x, scope, nf, *, w_init_stdev=0.02):
|
||||
with tf.variable_scope(scope):
|
||||
*start, nx = shape_list(x)
|
||||
w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev))
|
||||
b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
|
||||
c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
|
||||
return c
|
||||
|
||||
def attention_mask(nd, ns, *, dtype):
|
||||
"""1's in the lower triangle, counting from the lower right corner.
|
||||
|
||||
Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
|
||||
"""
|
||||
i = tf.range(nd)[:,None]
|
||||
j = tf.range(ns)
|
||||
m = i >= j - ns + nd
|
||||
return tf.cast(m, dtype)
|
||||
|
||||
|
||||
def attn(x, scope, n_state, *, past, hparams):
|
||||
assert x.shape.ndims == 3 # Should be [batch, sequence, features]
|
||||
assert n_state % hparams.n_head == 0
|
||||
if past is not None:
|
||||
assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
|
||||
|
||||
def split_heads(x):
|
||||
# From [batch, sequence, features] to [batch, heads, sequence, features]
|
||||
return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])
|
||||
|
||||
def merge_heads(x):
|
||||
# Reverse of split_heads
|
||||
return merge_states(tf.transpose(x, [0, 2, 1, 3]))
|
||||
|
||||
def mask_attn_weights(w):
|
||||
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
|
||||
_, _, nd, ns = shape_list(w)
|
||||
b = attention_mask(nd, ns, dtype=w.dtype)
|
||||
b = tf.reshape(b, [1, 1, nd, ns])
|
||||
w = w*b - tf.cast(1e10, w.dtype)*(1-b)
|
||||
return w
|
||||
|
||||
def multihead_attn(q, k, v):
|
||||
# q, k, v have shape [batch, heads, sequence, features]
|
||||
w = tf.matmul(q, k, transpose_b=True)
|
||||
w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
|
||||
|
||||
w = mask_attn_weights(w)
|
||||
w = softmax(w)
|
||||
a = tf.matmul(w, v)
|
||||
return a
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
c = conv1d(x, 'c_attn', n_state*3)
|
||||
q, k, v = map(split_heads, tf.split(c, 3, axis=2))
|
||||
present = tf.stack([k, v], axis=1)
|
||||
if past is not None:
|
||||
pk, pv = tf.unstack(past, axis=1)
|
||||
k = tf.concat([pk, k], axis=-2)
|
||||
v = tf.concat([pv, v], axis=-2)
|
||||
a = multihead_attn(q, k, v)
|
||||
a = merge_heads(a)
|
||||
a = conv1d(a, 'c_proj', n_state)
|
||||
return a, present
|
||||
|
||||
|
||||
def mlp(x, scope, n_state, *, hparams):
|
||||
with tf.variable_scope(scope):
|
||||
nx = x.shape[-1].value
|
||||
h = gelu(conv1d(x, 'c_fc', n_state))
|
||||
h2 = conv1d(h, 'c_proj', nx)
|
||||
return h2
|
||||
|
||||
|
||||
def block(x, scope, *, past, hparams):
|
||||
with tf.variable_scope(scope):
|
||||
nx = x.shape[-1].value
|
||||
a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
|
||||
x = x + a
|
||||
m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
|
||||
x = x + m
|
||||
return x, present
|
||||
|
||||
def past_shape(*, hparams, batch_size=None, sequence=None):
|
||||
return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]
|
||||
|
||||
def expand_tile(value, size):
|
||||
"""Add a new axis of given size."""
|
||||
value = tf.convert_to_tensor(value, name='value')
|
||||
ndims = value.shape.ndims
|
||||
return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)
|
||||
|
||||
def positions_for(tokens, past_length):
|
||||
batch_size = tf.shape(tokens)[0]
|
||||
nsteps = tf.shape(tokens)[1]
|
||||
return expand_tile(past_length + tf.range(nsteps), batch_size)
|
||||
|
||||
|
||||
def model(hparams, X, past=None, scope='model', reuse=False):
|
||||
with tf.variable_scope(scope, reuse=reuse):
|
||||
results = {}
|
||||
batch, sequence = shape_list(X)
|
||||
|
||||
wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
|
||||
initializer=tf.random_normal_initializer(stddev=0.01))
|
||||
wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
|
||||
initializer=tf.random_normal_initializer(stddev=0.02))
|
||||
past_length = 0 if past is None else tf.shape(past)[-2]
|
||||
h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
|
||||
|
||||
# Transformer
|
||||
presents = []
|
||||
pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
|
||||
assert len(pasts) == hparams.n_layer
|
||||
for layer, past in enumerate(pasts):
|
||||
h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
|
||||
presents.append(present)
|
||||
results['present'] = tf.stack(presents, axis=1)
|
||||
h = norm(h, 'ln_f')
|
||||
|
||||
# Language model loss. Do tokens <n predict token n?
|
||||
h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
|
||||
logits = tf.matmul(h_flat, wte, transpose_b=True)
|
||||
logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
|
||||
results['logits'] = logits
|
||||
return results
|
95
src/sample.py
Normal file
95
src/sample.py
Normal file
@ -0,0 +1,95 @@
|
||||
import tensorflow as tf
|
||||
|
||||
import model
|
||||
|
||||
def top_k_logits(logits, k):
|
||||
if k == 0:
|
||||
# no truncation
|
||||
return logits
|
||||
|
||||
def _top_k():
|
||||
values, _ = tf.nn.top_k(logits, k=k)
|
||||
min_values = values[:, -1, tf.newaxis]
|
||||
return tf.where(
|
||||
logits < min_values,
|
||||
tf.ones_like(logits, dtype=logits.dtype) * -1e10,
|
||||
logits,
|
||||
)
|
||||
return tf.cond(
|
||||
tf.equal(k, 0),
|
||||
lambda: logits,
|
||||
lambda: _top_k(),
|
||||
)
|
||||
|
||||
|
||||
def top_p_logits(logits, p):
|
||||
"""Nucleus sampling"""
|
||||
batch, _ = logits.shape.as_list()
|
||||
sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
|
||||
cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
|
||||
indices = tf.stack([
|
||||
tf.range(0, batch),
|
||||
# number of indices to include
|
||||
tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
|
||||
], axis=-1)
|
||||
min_values = tf.gather_nd(sorted_logits, indices)
|
||||
return tf.where(
|
||||
logits < min_values,
|
||||
tf.ones_like(logits) * -1e10,
|
||||
logits,
|
||||
)
|
||||
|
||||
|
||||
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
|
||||
if start_token is None:
|
||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||
else:
|
||||
assert context is None, 'Specify exactly one of start_token and context!'
|
||||
context = tf.fill([batch_size, 1], start_token)
|
||||
|
||||
def step(hparams, tokens, past=None):
|
||||
lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
|
||||
|
||||
logits = lm_output['logits'][:, :, :hparams.n_vocab]
|
||||
presents = lm_output['present']
|
||||
presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
|
||||
return {
|
||||
'logits': logits,
|
||||
'presents': presents,
|
||||
}
|
||||
|
||||
with tf.name_scope('sample_sequence'):
|
||||
def body(past, prev, output):
|
||||
next_outputs = step(hparams, prev, past=past)
|
||||
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
|
||||
logits = top_k_logits(logits, k=top_k)
|
||||
logits = top_p_logits(logits, p=top_p)
|
||||
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
|
||||
return [
|
||||
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
|
||||
samples,
|
||||
tf.concat([output, samples], axis=1)
|
||||
]
|
||||
|
||||
past, prev, output = body(None, context, context)
|
||||
|
||||
def cond(*args):
|
||||
return True
|
||||
|
||||
_, _, tokens = tf.while_loop(
|
||||
cond=cond, body=body,
|
||||
maximum_iterations=length - 1,
|
||||
loop_vars=[
|
||||
past,
|
||||
prev,
|
||||
output
|
||||
],
|
||||
shape_invariants=[
|
||||
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
|
||||
tf.TensorShape([batch_size, None]),
|
||||
tf.TensorShape([batch_size, None]),
|
||||
],
|
||||
back_prop=False,
|
||||
)
|
||||
|
||||
return tokens
|
Loading…
x
Reference in New Issue
Block a user