From 24cb406af94a83ee486c5341fa6185fa328858c9 Mon Sep 17 00:00:00 2001
From: ahaas25
Date: Sun, 23 Feb 2025 23:27:46 -0500
Subject: [PATCH] init

---
 .gitattributes                         |    6 +
 .gitignore                             |    3 +
 CONTRIBUTORS.md                        |   17 +
 DEVELOPERS.md                          |   88 +++
 Dockerfile.cpu                         |   11 +
 Dockerfile.gpu                         |   20 +
 LICENSE                                |   24 +
 README.md                              |   24 +
 bot.py                                 |  238 ++++++
 domains.txt                            | 1000 ++++++++++++++++++++++++
 download_model.py                      |   28 +
 model_card.md                          |   69 ++
 requirements.txt                       |    4 +
 src/encoder.py                         |  117 +++
 src/generate_unconditional_samples.py  |   80 ++
 src/interactive_conditional_samples.py |   92 +++
 src/model.py                           |  174 +++++
 src/sample.py                          |   95 +++
 18 files changed, 2090 insertions(+)
 create mode 100644 .gitattributes
 create mode 100644 .gitignore
 create mode 100644 CONTRIBUTORS.md
 create mode 100644 DEVELOPERS.md
 create mode 100644 Dockerfile.cpu
 create mode 100644 Dockerfile.gpu
 create mode 100644 LICENSE
 create mode 100644 bot.py
 create mode 100644 domains.txt
 create mode 100644 download_model.py
 create mode 100644 model_card.md
 create mode 100644 requirements.txt
 create mode 100644 src/encoder.py
 create mode 100755 src/generate_unconditional_samples.py
 create mode 100755 src/interactive_conditional_samples.py
 create mode 100644 src/model.py
 create mode 100644 src/sample.py

diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..7c3a822
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,6 @@
+# convert to OS line endings on checkout, back to LF on commit
+* text=auto
+
+# ensure anything copied to the container has unix style line endings
+*.sh text eol=lf
+requirements.txt text eol=lf
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5b1de5a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+__pycache__
+.mypy_cache/
+models/
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
new file mode 100644
index 0000000..eab7132
--- /dev/null
+++ b/CONTRIBUTORS.md
@@ -0,0 +1,17 @@
+# Contributors (alphabetically)
+
+* **[madisonmay](https://github.com/madisonmay)**
+
+    Added Dockerfiles
+
+* **[Margaret Mitchell et al.](https://arxiv.org/abs/1810.03993)**
+
+    Our [usage](./README.md#usage) writeup was loosely inspired by the paper
+    [Model Cards for Model Reporting](https://arxiv.org/abs/1810.03993)
+    and related conversations with some of the authors.
+
+* **[webproduktion01](https://github.com/webproduktion01)**
+
+    Ported the download script to Python.
+
+**[Full code contributors list](https://github.com/openai/gpt-2/contributors).**
diff --git a/DEVELOPERS.md b/DEVELOPERS.md
new file mode 100644
index 0000000..d23c9d0
--- /dev/null
+++ b/DEVELOPERS.md
@@ -0,0 +1,88 @@
+# Installation
+
+Clone this repository and `cd` into it for the remaining commands:
+```
+git clone https://github.com/openai/gpt-2.git && cd gpt-2
+```
+
+Then, follow instructions for either native or Docker installation.
+
+## Native Installation
+
+All steps can optionally be done in a virtual environment using tools such as `virtualenv` or `conda`.
+ +Install tensorflow 1.12 (with GPU support, if you have a GPU and want everything to run faster) +``` +pip3 install tensorflow==1.12.0 +``` +or +``` +pip3 install tensorflow-gpu==1.12.0 +``` + +Install other python packages: +``` +pip3 install -r requirements.txt +``` + +Download the model data +``` +python3 download_model.py 124M +python3 download_model.py 355M +python3 download_model.py 774M +python3 download_model.py 1558M +``` + +## Docker Installation + +Build the Dockerfile and tag the created image as `gpt-2`: +``` +docker build --tag gpt-2 -f Dockerfile.gpu . # or Dockerfile.cpu +``` + +Start an interactive bash session from the `gpt-2` docker image. + +You can opt to use the `--runtime=nvidia` flag if you have access to a NVIDIA GPU +and a valid install of [nvidia-docker 2.0](https://github.com/nvidia/nvidia-docker/wiki/Installation-(version-2.0)). +``` +docker run --runtime=nvidia -it gpt-2 bash +``` + +# Running + +| WARNING: Samples are unfiltered and may contain offensive content. | +| --- | + +Some of the examples below may include Unicode text characters. Set the environment variable: +``` +export PYTHONIOENCODING=UTF-8 +``` +to override the standard stream settings in UTF-8 mode. + +## Unconditional sample generation + +To generate unconditional samples from the small model: +``` +python3 src/generate_unconditional_samples.py | tee /tmp/samples +``` +There are various flags for controlling the samples: +``` +python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.7 | tee /tmp/samples +``` + +To check flag descriptions, use: +``` +python3 src/generate_unconditional_samples.py -- --help +``` + +## Conditional sample generation + +To give the model custom prompts, you can use: +``` +python3 src/interactive_conditional_samples.py --top_k 40 +``` + +To check flag descriptions, use: +``` +python3 src/interactive_conditional_samples.py -- --help +``` diff --git a/Dockerfile.cpu b/Dockerfile.cpu new file mode 100644 index 0000000..b6e4f94 --- /dev/null +++ b/Dockerfile.cpu @@ -0,0 +1,11 @@ +FROM tensorflow/tensorflow:1.12.0-py3 + +ENV LANG=C.UTF-8 +RUN mkdir /gpt-2 +WORKDIR /gpt-2 +ADD . /gpt-2 +RUN pip3 install -r requirements.txt +RUN python3 download_model.py 124M +RUN python3 download_model.py 355M +RUN python3 download_model.py 774M +RUN python3 download_model.py 1558M diff --git a/Dockerfile.gpu b/Dockerfile.gpu new file mode 100644 index 0000000..5ac049a --- /dev/null +++ b/Dockerfile.gpu @@ -0,0 +1,20 @@ +FROM tensorflow/tensorflow:1.12.0-gpu-py3 + +# nvidia-docker 1.0 +LABEL com.nvidia.volumes.needed="nvidia_driver" +LABEL com.nvidia.cuda.version="${CUDA_VERSION}" + +# nvidia-container-runtime +ENV NVIDIA_VISIBLE_DEVICES=all \ + NVIDIA_DRIVER_CAPABILITIES=compute,utility \ + NVIDIA_REQUIRE_CUDA="cuda>=8.0" \ + LANG=C.UTF-8 + +RUN mkdir /gpt-2 +WORKDIR /gpt-2 +ADD . /gpt-2 +RUN pip3 install -r requirements.txt +RUN python3 download_model.py 124M +RUN python3 download_model.py 355M +RUN python3 download_model.py 774M +RUN python3 download_model.py 1558M diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f56abfe --- /dev/null +++ b/LICENSE @@ -0,0 +1,24 @@ +Modified MIT License + +Software Copyright (c) 2019 OpenAI + +We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please. +We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2. 
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+associated documentation files (the "Software"), to deal in the Software without restriction,
+including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included
+in all copies or substantial portions of the Software.
+The above copyright notice and this permission notice need not be included
+with content created by the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
index e69de29..81ae197 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,24 @@
+## Overview
+
+A very quick and dirty integration of OpenAI's GPT-2 into a Discord bot. This was hacked together in a day and is by no means the best or most efficient way to do it. I may come back to this project to clean it up another time.
+
+## Requirements
+
+* General understanding of how to create and manage Discord bots
+* A trained GPT-2 model
+* Basic understanding of Python
+* All the prerequisites for GPT-2, which this repository was forked from
+* A CUDA-capable GPU (recommended for better performance)
+
+## Usage
+
+You'll have to modify ``bot.py`` with your Discord bot token, as well as point the commands to the correct model. ``bot.py`` is a very basic Discord bot with OpenAI's generation scripts pasted into it and modified to return a string rather than print to the console. My model was trained on a group chat with friends, so I've written the bot and its commands to reflect this. The data I trained my model on follows the regular expression ``^[A-z]{3} [0-9]{2}:[0-9]{2} [A-Z]{2} - .+: "(.*)"``, which looks like ``Sep 05:57 PM - Username: "Message content"``. Unless your model outputs text in this exact format, you will have to modify this bot to accommodate your needs.
+
+There are three commands to interact with this bot:
+* ``!g <prompt>``, which generates a conversation based on a prompt, or a random one if none is provided.
+* ``!r <prompt>``, which replies to a prompt. For example, ``!r Hi, how are you?`` might get the response ``I'm good!``. This command assumes the model returns a string matching the regular expression above, and cuts out the irrelevant information so the reply mimics a response from the bot (see the parsing sketch below).
+* ``!c <prompt>``, which continues a prompt. For example, given ``!c My name is``, the bot might continue the prompt with ``My name is Jojgo``.
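+
+As a minimal sketch of how the ``!r`` command recovers a reply, here is the
+parsing step in isolation (the sample text is hypothetical model output in
+the format above):
+
+```
+import re
+
+# Matches lines like: Sep 05:57 PM - Username: "Message content"
+PATTERN = re.compile(r'^[A-z]{3} [0-9]{2}:[0-9]{2} [A-Z]{2} - .+: "(.*)"')
+
+sample_output = 'Sep 05:57 PM - Friend: "I\'m good!"'
+match = PATTERN.match(sample_output)
+if match:
+    print(match.group(1))  # -> I'm good!
+```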
+
+## License
+
+[Modified MIT](./LICENSE)
diff --git a/bot.py b/bot.py
new file mode 100644
index 0000000..e1c0410
--- /dev/null
+++ b/bot.py
@@ -0,0 +1,238 @@
+import json
+import os
+import re
+
+import discord
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+import model, sample, encoder
+
+# GENERATE
+
+def sample_model(
+    model_name='h5',
+    seed=None,
+    nsamples=1,
+    batch_size=1,
+    length=250,
+    temperature=1,
+    top_k=0,
+    top_p=1,
+    models_dir='models',
+):
+    """
+    Run the sample_model
+    :model_name=h5 : String, which model to use
+    :seed=None : Integer seed for random number generators, fix seed to
+     reproduce results
+    :nsamples=1 : Number of samples to return; if 0, continues to
+     generate samples indefinitely.
+    :batch_size=1 : Number of batches (only affects speed/memory).
+    :length=250 : Number of tokens in generated text; if None, is
+     determined by model hyperparameters
+    :temperature=1 : Float value controlling randomness in Boltzmann
+     distribution. Lower temperature results in less random completions. As the
+     temperature approaches zero, the model will become deterministic and
+     repetitive. Higher temperature results in more random completions.
+    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
+     considered for each step (token), resulting in deterministic completions,
+     while 40 means 40 words are considered at each step. 0 (default) is a
+     special setting meaning no restrictions. 40 generally is a good value.
+    :top_p=1 : Float value controlling nucleus sampling; 1 (default) disables it.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
+    """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    enc = encoder.get_encoder(model_name, models_dir)
+    hparams = model.default_hparams()
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if length is None:
+        length = hparams.n_ctx
+    elif length > hparams.n_ctx:
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
+
+    with tf.Session(graph=tf.Graph()) as sess:
+        np.random.seed(seed)
+        tf.set_random_seed(seed)
+
+        output = sample.sample_sequence(
+            hparams=hparams, length=length,
+            start_token=enc.encoder['<|endoftext|>'],
+            batch_size=batch_size,
+            temperature=temperature, top_k=top_k, top_p=top_p
+        )[:, 1:]
+
+        saver = tf.train.Saver()
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
+        saver.restore(sess, ckpt)
+
+        generated = 0
+        while nsamples == 0 or generated < nsamples:
+            out = sess.run(output)
+            for i in range(batch_size):
+                generated += batch_size
+                # Wrap the sample in a Discord code block and return after
+                # the first sample; the bot only needs one.
+                text = '```'
+                text += enc.decode(out[i])
+                text += '```'
+                print(text)
+                return text
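+
+# NOTE: sample_model and interact_model each build a fresh TF graph and
+# reload the checkpoint from disk on every call, so most of a command's
+# latency is model loading. Caching the session between calls would be the
+# natural optimization if this bot outgrows "hacked together in a day".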
+
+# RESPOND
+
+def interact_model(
+    model_name='h5',
+    seed=None,
+    nsamples=1,
+    batch_size=1,
+    length=100,
+    temperature=1,
+    top_k=0,
+    top_p=1,
+    models_dir='models',
+    raw_text='test',
+    inline=True
+):
+    """
+    Interactively run the model
+    :model_name=h5 : String, which model to use
+    :seed=None : Integer seed for random number generators, fix seed to reproduce
+     results
+    :nsamples=1 : Number of samples to return total
+    :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples.
+    :length=100 : Number of tokens in generated text; if None, is
+     determined by model hyperparameters
+    :temperature=1 : Float value controlling randomness in Boltzmann
+     distribution. Lower temperature results in less random completions. As the
+     temperature approaches zero, the model will become deterministic and
+     repetitive. Higher temperature results in more random completions.
+    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
+     considered for each step (token), resulting in deterministic completions,
+     while 40 means 40 words are considered at each step. 0 (default) is a
+     special setting meaning no restrictions. 40 generally is a good value.
+    :top_p=1 : Float value controlling nucleus sampling; 1 (default) disables it.
+    :raw_text='test' : Prompt to condition the generation on
+    :inline=True : If True, wrap the generated text in a Discord code block
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
+    """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    if batch_size is None:
+        batch_size = 1
+    assert nsamples % batch_size == 0
+
+    enc = encoder.get_encoder(model_name, models_dir)
+    hparams = model.default_hparams()
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if length is None:
+        length = hparams.n_ctx // 2
+    elif length > hparams.n_ctx:
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
+
+    with tf.Session(graph=tf.Graph()) as sess:
+        context = tf.placeholder(tf.int32, [batch_size, None])
+        np.random.seed(seed)
+        tf.set_random_seed(seed)
+        output = sample.sample_sequence(
+            hparams=hparams, length=length,
+            context=context,
+            batch_size=batch_size,
+            temperature=temperature, top_k=top_k, top_p=top_p
+        )
+
+        saver = tf.train.Saver()
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
+        saver.restore(sess, ckpt)
+
+        # Generation Code
+
+        context_tokens = enc.encode(raw_text)
+        generated = 0
+        for _ in range(nsamples // batch_size):
+            out = sess.run(output, feed_dict={
+                context: [context_tokens for _ in range(batch_size)]
+            })[:, len(context_tokens):]
+            for i in range(batch_size):
+                generated += 1
+                text = ''
+                if inline:
+                    text = '```'
+                text += enc.decode(out[i])
+                if inline:
+                    text += '```'
+                # Return the first sample; the remaining loop iterations
+                # are never reached.
+                return text
+
+# DISCORD BOT
+intents = discord.Intents.default()
+intents.message_content = True
+client = discord.Client(intents=intents)
+# Paste your Discord bot token here (see README).
+TOKEN = 'YOUR TOKEN HERE'
+
+
+@client.event
+async def on_ready():
+    print("Logged in as a bot {0.user}".format(client))
+
+@client.event
+async def on_message(message):
+
+    if message.author == client.user:
+        return
+
+    if message.content == '!g':
+        response = sample_model()
+        await message.channel.send(response)
+    elif message.content.startswith('!g'):
+        text = message.content[3:]
+        content = interact_model(raw_text=text, inline=True, length=250)
+        # Drop everything up to the first newline so the reply starts cleanly.
+        index = content.find('\n')
+        await message.channel.send(content[index:])
+    elif message.content.startswith('!r'):
+        text = message.content[3:]
+
+        content = interact_model(raw_text=text, inline=False, length=50)
+
+        # Pull just the quoted message body out of the generated chat line.
+        regexp = re.compile('^[A-z]{3} [0-9]{2}:[0-9]{2} [A-Z]{2} - .+: "(.*)"')
+        matched = regexp.match(content)
+
+        if matched:
+            toReturn = matched.group(1)
+            await message.channel.send(toReturn)
+        else:
+            # Retry on the second line in case the first line is a partial message.
+            index = (content.find('\n')) + 1
+            temp = content[index:]
+
+            temp_match = regexp.match(temp)
+            if temp_match:
+                toReturn = temp_match.group(1)
+                await message.channel.send(toReturn)
+            else:
+                await message.channel.send(temp)
+
+    elif message.content.startswith('!c'):
+        # Keep only the model's first line and append it to the prompt.
+        text = message.content[3:]
+        content = interact_model(raw_text=text, length=20, inline=False)
+
+        sep = '\n'
+        stripped = content.split(sep, 1)[0]
+
+        toReturn = text + stripped
+
+        # Note: drops the final character of the combined string.
+        await message.channel.send(toReturn[:(len(toReturn) - 1)])
+    elif message.content == '!h':
+        response = '```!g - Generates Conversation. If no prompt provided a random one will be used.\n!r - Responds to prompt\n!c - Continues prompt```'
+        await message.channel.send(response)
+
+
+client.run(TOKEN)
\ No newline at end of file
diff --git a/domains.txt b/domains.txt
new file mode 100644
index 0000000..04bdac4
--- /dev/null
+++ b/domains.txt
@@ -0,0 +1,1000 @@
+1542261 google
+596207 archive
+456344 blogspot
+414695 github
+333160 nytimes
+321622 wordpress
+315368 washingtonpost
+313137 wikia
+311917 bbc
+246303 theguardian
+210714 ebay
+209416 pastebin
+199360 cnn
+196124 yahoo
+186668 huffingtonpost
+186137 go
+183592 reuters
+183080 imdb
+160553 goo
+139965 nih
+135562 cbc
+128011 apple
+125615 medium
+118676 dailymail
+108012 steampowered
+106417 independent
+105239 etsy
+98941 craigslist
+93048 businessinsider
+92712 telegraph
+90262 wizards
+83266 usatoday
+80384 thehill
+79655 nhl
+79494 foxnews
+79167 taobao
+78070 bloomberg
+77515 npr
+77407 mlb
+77172 latimes
+75676 megalodon
+72525 espn
+72523 kickstarter
+71743 breitbart
+69334 abc
+68009 newegg
+67008 wwe
+66278 myanimelist
+65520 microsoft
+64723 buzzfeed
+63162 vice
+62911 indiatimes
+61845 forbes
+61772 tappedout
+60889 wsj
+60240 vid
+60239 battle
+59996 adf
+58706 politico
+58345 redditgifts
+56769 nexusmods
+56469 goodreads
+54866 magiccards
+53973 nbcnews
+53060 gamepedia
+52110 mediafire
+50567 time
+50144 cbsnews
+49203 ppy
+48442 gstatic
+48042 nfl
+47460 steamusercontent
+47046 thestar
+46603 bugguide
+46340 fanfiction
+45505 mturk
+45458 cbslocal
+44729 theglobeandmail
+44134 nydailynews
+42992 theatlantic
+42941 netflix
+42328 theverge
+41952 smh
+40694 nbcsports
+40613 cnbc
+40469 slate
+40071 ign
+39655 dotabuff
+38968 wired
+38779 chicagotribune
+38590 urbandictionary
+38575 rt
+38092 wuxiaworld
+38065 wowhead
+37954 wolframalpha
+37749 guardian
+37594 xboxdvr
+36841 nypost
+36741 ravelry
+36321 thedailybeast
+36298 nba
+36188 yelp
+36008 arstechnica
+35485 csgo
+35365 flic
+35269 stackexchange
+35124 vidble
+35024 googleusercontent
+34311 msn
+34121 gizmodo
+34120 boardgamegeek
+33867 aljazeera
+33598 rawstory
+33516 scryfall
+33467 bleacherreport
+33419 bit
+33395 thinkprogress
+33170 dailycaller
+32843 ap
+32433 fangraphs
+31742 salon
+31728 mirror
+31496 nintendo
+31294 nationalpost
+31278 nasa
+31110 oddshot
+31057 hltv
+30952 amzn
+30877 quora
+30586 engadget
+30397 stackoverflow
+30201 aliexpress
+29710 cnet
+28850 leagueoflegends
+28822 surveymonkey
+28704 ctvnews
+28650 walmart
+28644 plays
+28536 sfgate
+28375 cbssports
+28210 globo
+27992 discogs
+27630 wiktionary
+27588 ibb
+27544 stuff
+27349 nature
+27112 news
+27020 biblegateway
+26801 subtletv
+26427 change
+26355 zippyshare
+26311 guildwars2
+26231 vox
+26205 zkillboard
+26174 techcrunch
+25993 economist
+25964 globalnews
+25621 washingtontimes
+25610 hollywoodreporter
+25351 archiveofourown
+25336 ibtimes
+25257 newsweek
+25139 zerohedge
+25074 fav
+25050 sciencedirect
+24894 bestbuy
+24870 spiegel
+24869 247sports
+24866 smmry
+24764 xda-developers
+24726 tvtropes
+24698 phys
+24663 teamliquid
+24619 state
+23953 gleam
+23676 sbnation
+23644 asahi
+23620 foxsports
+23240 ndtv
+23189 si
+23183 alternet
+23009 redbubble
+22846 metro
+22845 theonion
+22835 playstation
+22808 washingtonexaminer
+22682 thehindu +22557 espncricinfo +22482 mozilla +22219 op +22038 t +21984 nj +21921 indianexpress +21707 apnews +21603 dw +21422 nationalgeographic +21399 pinterest +21368 ft +21319 wiley +21254 about +21074 skysports +21033 gamespot +21014 dailykos +21009 goal +20858 patheos +20842 irishtimes +20664 variety +20592 kotaku +20584 mashable +20575 scientificamerican +20448 basketball-reference +20262 yle +20218 theage +20176 usnews +20133 animenewsnetwork +20092 livejournal +20068 +20024 pbs +19802 nhk +19741 newyorker +19727 seattletimes +19672 mlssoccer +19619 meetup +19543 nzherald +19509 philly +19496 uol +19470 patreon +19429 wikileaks +19400 gravitytales +19294 oregonlive +19267 xbox +19216 linkedin +19202 crunchyroll +19045 target +19021 ew +18922 redditpoll +18875 homedepot +18867 qz +18865 donmai +18653 baseball-reference +18646 talkingpointsmemo +18576 pathofexile +18536 makeameme +18489 postimg +18308 clyp +18175 scribd +18120 thegatewaypundit +18097 removeddit +18063 deadspin +18049 sciencedaily +18019 huffpost +17987 dallasnews +17956 europa +17878 merriam-webster +17816 haaretz +17746 deadline +17637 msnbc +17579 hindustantimes +17531 nymag +17429 gph +17208 typepad +17204 express +17098 naver +17085 bizjournals +17084 mlive +16834 rollingstone +16793 motherjones +16704 okcupid +16441 tinyurl +16410 espnfc +16397 bostonglobe +16374 thingiverse +16351 denverpost +16332 bitcointalk +16256 timesofisrael +16209 xnxx +16202 wikihow +16051 neopets +16043 indiegogo +16033 al +16032 chron +16004 avclub +15970 marketwatch +15933 mercurynews +15675 startribune +15646 pro-football-reference +15568 d20pfsrd +15545 pcgamer +15451 reason +15422 uesp +15356 lds +15152 polygon +15132 humblebundle +14962 tradingview +14931 baltimoresun +14914 strava +14912 firstpost +14856 commondreams +14801 sky +14739 eventbrite +14722 nicovideo +14697 fortune +14693 knowyourmeme +14666 robertsspaceindustries +14471 pitchfork +14466 psychologytoday +14435 combodeck +14392 mixcloud +14372 lemonde +14290 sciencemag +14060 jpost +13926 miamiherald +13902 patch +13850 nationalreview +13849 gofundme +13798 thelocal +13763 derpibooru +13726 techdirt +13658 townhall +13596 mtg +13588 gettyimages +13530 mit +13436 challonge +13369 mediaite +13357 tsn +13350 pokemonshowdown +13176 neogaf +13130 publico +13126 snopes +13092 scmp +13082 cleveland +13044 thesun +13025 mtggoldfish +12994 freep +12984 grailed +12948 standard +12923 theconversation +12913 upi +12870 bing +12778 blockchain +12774 people +12771 arxiv +12760 hearthpwn +12668 reference +12626 edhrec +12611 sputniknews +12551 nordstrom +12550 lapresse +12496 metacritic +12447 last +12395 ajc +12355 mangadex +12349 ycombinator +12345 csmonitor +12240 sportsnet +12229 cornell +12205 smithsonianmag +12201 sephora +12194 bulbagarden +12181 japantimes +12171 zdnet +12152 comicbook +12139 whitehouse +12109 theregister +12089 libsyn +12052 asos +12016 neatclip +12001 imirhil +12000 boston +11973 behance +11966 eveonline +11954 androidpolice +11935 livescience +11843 instructables +11817 hs +11788 infowars +11712 ca +11704 runescape +11699 suntimes +11697 eurogamer +11654 roblox +11622 genius +11602 stltoday +11499 elpais +11494 motorsport +11461 ceddit +11426 france24 +11373 bungie +11371 youtubedoubler +11362 openload +11348 jstor +11328 thefreedictionary +11307 inquisitr +11215 nhentai +11204 zeit +11198 ikea +11114 springer +11108 tripadvisor +11082 thescore +11036 kerbalspaceprogram +11007 cdc +10995 dailywire +10965 gawker +10953 a +10950 brooksbaseball +10940 
dn +10927 sltrib +10867 brickset +10823 dictionary +10821 squarespace +10819 battlefield +10807 harvard +10786 afpbb +10734 steemit +10730 billboard +10707 tampabay +10654 nola +10621 stanford +10602 sbs +10524 cc +10520 dailydot +10510 straitstimes +10493 itch +10490 foreignpolicy +10465 vancouversun +10440 rottentomatoes +10419 dnainfo +10389 digi24 +10348 dropboxusercontent +10332 complex +10330 scp-wiki +10327 prnt +10313 ottawacitizen +10304 anandtech +10269 thenation +10253 fivethirtyeight +10244 newscientist +10240 svt +10240 inquirer +10236 coindesk +10227 codepen +10208 lichess +10204 sankei +10189 ted +10181 roosterteeth +10170 livemint +10161 teamfortress +10141 sourceforge +10119 sapo +10113 countle +10086 mtv +10075 sacbee +10066 fimfiction +10057 hentai-foundry +10054 gamesplanet +10044 io9 +10032 lifehacker +10007 cracked +9991 mainichi +9984 itmedia +9966 warthunder +9936 nos +9935 boingboing +9925 vulture +9904 lanacion +9892 qualtrics +9884 muthead +9856 jcrew +9814 jsonline +9787 spacebattles +9748 worldstarhiphop +9734 jalopnik +9721 welt +9717 curbed +9708 dbr +9705 mmafighting +9697 bigcartel +9682 transfermarkt +9680 vlive +9659 vanityfair +9658 dawn +9621 dnaindia +9601 theblaze +9599 allrecipes +9576 thejournal +9572 dailystar +9521 minecraftforum +9505 theweek +9502 kansascity +9494 anilist +9443 gog +9420 bato +9401 oxforddictionaries +9400 soompi +9394 sagepub +9389 wikiwand +9382 lolking +9322 torontosun +9319 mangapanda +9316 politifact +9306 realclearpolitics +9278 tagpro +9261 webmd +9206 app +9202 hotnews +9184 9news +9174 bhphotovideo +9147 giantbomb +9132 gamestop +9073 azcentral +9053 noaa +9040 repubblica +9021 mangaupdates +8998 space +8998 researchgate +8971 bitcoin +8957 sueddeutsche +8898 rightwingwatch +8892 mediacru +8890 afl +8862 fasttech +8858 tmz +8841 orlandosentinel +8832 tomshardware +8828 altomfotball +8822 mtgprice +8821 haskell +8816 discovery +8810 destinytracker +8808 massdrop +8800 csgolounge +8791 weather +8778 daddyleagues +8720 govtrack +8678 mentalfloss +8678 justice +8663 frontier +8655 youporn +8641 paradoxplaza +8640 rockstargames +8632 derstandard +8622 pinknews +8619 macrumors +8598 gamefaqs +8587 thepiratebay +8586 4chan +8582 post-gazette +8573 faz +8563 e-hentai +8530 jiji +8525 quoracdn +8519 fullmatchesandshows +8516 sun-sentinel +8513 xboxclips +8488 financialpost +8476 audible +8439 investopedia +8425 loc +8418 venturebeat +8414 amazonaws +8368 ubi +8345 etymonline +8326 wsws +8316 jezebel +8300 americanthinker +8284 wikidot +8269 digitaltrends +8260 nrk +8232 weebly +8228 thenextweb +8225 snahp +8223 gematsu +8210 daum +8206 ea +8189 liverpoolecho +8186 freebeacon +8178 thetimes +8168 naturalcrit +8153 warframe +8150 1drv +8143 gap +8131 seriouseats +8119 myfigurecollection +8109 gov +8086 eporner +8080 hulu +8077 senate +8046 esquire +8015 gosugamers +8000 radionz +7997 eater +7982 politicususa +7978 rte +7956 marvel +7942 metronews +7917 starcitygames +7917 hotair +7914 marca +7872 eurekalert +7840 screenrant +7834 dota2 +7797 truth-out +7784 dell +7783 eldiario +7782 pcworld +7782 doi +7780 comicbookresources +7765 dr +7729 howstuffworks +7727 gocomics +7715 worldoftanks +7707 tandfonline +7690 examiner +7688 newrepublic +7682 curseforge +7680 findlaw +7673 nikkei +7665 heraldsun +7652 podbean +7645 aftonbladet +7638 duckduckgo +7633 ynetnews +7629 timesofindia +7628 freshphase +7591 westeros +7576 youjizz +7574 spectator +7548 justia +7537 antiwar +7536 mmajunkie +7516 yomiuri +7485 newstatesman +7481 
greenmangaming +7475 joystiq +7444 jsfiddle +7424 anime-planet +7415 counterpunch +7410 autosport +7395 archlinux +7384 berkeley +7383 smbc-comics +7374 rockpapershotgun +7372 pjmedia +7367 estadao +7365 intoday +7361 newsmax +7346 newsbusters +7337 grantland +7329 voanews +7292 myshopify +7286 wnd +7265 9to5mac +7257 hurriyetdailynews +7229 bleedingcool +7225 indiewire +7222 radio-canada +7216 viewsync +7211 cambridge +7204 drsd +7197 house +7185 uproxx +7152 mlbtraderumors +7145 gamasutra +7134 bricklink +7122 foodnetwork +7122 presstv +7119 opensecrets +7118 canada +7116 bgr +7097 democracynow +7091 businessweek +7085 smash +7080 usda +7078 cloudfront +7044 psu +7028 detroitnews +7028 explosm +7013 woobox +7011 football-italia +7005 academia +6948 channelnewsasia +6927 siliconera +6923 rei +6917 deseretnews +6916 supload +6914 mises +6905 rotoworld +6886 gsmarena +6878 rappler +6876 kijiji +6866 metal-archives +6826 theaustralian +6823 mediamatters +6823 wa +6818 bodybuilding +6811 memedad +6803 ucsd +6802 barnesandnoble +6791 india +6780 readability +6777 today +6726 indystar +6720 scotsman +6694 impress +6689 torrentfreak +6675 heise +6668 sportingnews +6658 pnas +6650 chzbgr +6650 milb +6631 business-standard +6630 bustle +6623 square-enix +6622 madison +6615 moddb +6613 uniqlo +6599 zillow +6577 tribune +6556 airliners +6552 svd +6547 gameinformer +6536 brisbanetimes +6536 ocregister +6533 swtor +6526 calgaryherald +6521 c-span +6518 slashdot +6505 belfasttelegraph +6499 hiyo +6494 news24 +6484 theintercept +6479 technologyreview +6455 gutenberg +6449 cinemablend +6438 dailytelegraph +6424 globalresearch +6411 lefigaro +6405 tenor +6381 redstate +6374 aclu +6361 bloodyelbow +6357 axios +6353 thewrap +6349 redditmetrics +6345 evike +6339 aol +6327 ulta +6326 plos +6324 periscope +6312 drivethrurpg +6308 infobae +6300 debian +6298 congress +6289 warcraftlogs +6284 gothamist +6281 mangastream +6276 newgrounds +6275 berniesanders +6263 lolesports +6262 mayoclinic +6242 sfchronicle +6235 edmontonjournal +6200 dhgate +6194 cincinnati +6180 history +6176 xtube +6169 nike +6160 kiji +6147 tube8 +6140 vdare +6133 unity3d +6130 twincities +6127 escapistmagazine +6126 komonews +6104 openneo +6090 oup +6082 dispatch +6079 newsobserver +6060 ballotpedia +6058 indiegala +6054 index +6050 charlotteobserver +6048 androidcentral +6032 webtoons +6028 tcgplayer +6018 zappos +6004 intel +5998 seattlepi +5996 profootballfocus +5990 ksl +5989 macleans +5984 atlasobscura +5981 yugiohprices +5980 ubuntu +5964 gq +5952 myvidster +5941 tv2 +5930 paizo +5926 montrealgazette +5919 al-monitor +5919 herokuapp +5918 volarenovels +5909 usgs +5906 nme +5906 society6 +5905 vg247 +5902 popsci +5895 lowes +5893 thefederalist +5878 amiami +5862 nyti +5848 steamdb +5841 crooksandliars +5833 popularmechanics +5832 slashfilm +5826 woot +5818 ev +5807 illinois +5792 nps +5791 destructoid +5790 mysanantonio +5772 sbtl +5742 smashboards +5700 biblehub +5696 euronews +5694 urbanoutfitters +5687 itv +5685 fastcompany +5684 techpowerup +5674 hearthhead +5656 mic +5649 autoblog +5646 futbin +5638 voat +5636 statesman +5626 zap2it +5623 userbenchmark +5623 legaliq +5622 mspaintadventures +5622 familysearch +5616 themoscowtimes +5606 theprovince +5604 allkpop +5594 Omegle +5570 activistpost +5565 thefreethoughtproject +5565 in +5559 sandiegouniontribune +5556 consumerist +5554 eff +5532 lego +5520 translationnations +5515 clickhole +5498 etherscan +5491 live +5486 vndb +5484 poll-maker +5481 mtgsalvation +5481 computerworld +5475 
comicvine +5470 python +5469 digitalspy +5468 citylab +5458 expressen +5455 oxfordjournals +5451 collider +5447 statista +5437 apa +5434 g +5430 thenational +5430 eslgaming +5425 politiken +5421 ktla +5420 webmshare +5408 bostonherald +5407 comixology +5400 ustream +5399 sony +5396 tennessean +5377 scout +5374 drop +5372 ieee +5359 sverigesradio +5356 sherdog +5353 viooz +5353 marxists +5353 adobe +5349 myfitnesspal +5342 seahawks +5339 rferl +5338 thediplomat +5335 storeparser +5332 prnewswire +5330 midwayusa +5327 liverpoolfc +5326 cisco +5326 windowsphone +5323 toysrus +5321 archivesofnethys +5317 eluniversal +5309 gmanetwork +5303 asus +5297 android +5297 finalfantasyxiv +5296 cyclingnews +5293 worldbank +5288 boxingscene +5285 ticketmaster +5279 grooveshark +5277 khl +5276 gallup +5268 britannica +5263 abc7 +5260 penny-arcade +5257 hsreplay +5257 oculus +5256 bt +5250 theroot +5246 makeagif +5246 cnsnews +5243 nbc +5243 rbc +5243 fextralife +5234 legislation +5225 sendvid +5221 sciencealert +5214 wbur +5212 myfonts +5207 picsarus +5206 phoronix +5204 nerdist +5203 eonline +5195 advocate +5191 king5 +5189 xkcd +5183 kitsu +5182 weibo +5181 mangareader +5178 palmbeachpost +5176 go1dfish +5175 livestrong +5174 truthdig +5173 lgbtqnation +5172 nikkansports +5167 slickdeals +5166 streamja +5164 irs +5158 readms +5152 microcenter +5137 telesurtv +5135 lastwordonsports +5129 alarabiya +5117 cointelegraph +5114 iltalehti +5112 fc2 +5108 wral +5108 thinkgeek +5102 bitbucket +5101 letterboxd +5098 ehow +5092 abc13 +5083 beeradvocate +5077 umich +5067 macys +5064 factorio +5063 comicbookmovie +5042 telegram +5039 scroll +5034 setlist +5028 dailyherald +5019 games-workshop +5015 irishexaminer +5008 fbi +5007 heraldscotland +5001 jellyneo +4999 yale +4996 cbr +4994 masslive +4984 thestranger +4982 bundlestars +4981 alibaba +4977 filedropper +4974 monoprice +4968 forward +4964 parliament +4960 theringer +4950 hobbyking +4950 manchestereveningnews +4949 bmj +4948 thewire +4947 ff2ebook +4938 ashemaletube +4937 Twitch +4933 sketchtoy +4932 mcclatchydc +4931 memory-alpha +4925 newsok +4911 desmoinesregister +4901 puzzledragonx +4889 memecrunch diff --git a/download_model.py b/download_model.py new file mode 100644 index 0000000..54e4bb6 --- /dev/null +++ b/download_model.py @@ -0,0 +1,28 @@ +import os +import sys +import requests +from tqdm import tqdm + +if len(sys.argv) != 2: + print('You must enter the model name as a parameter, e.g.: download_model.py 124M') + sys.exit(1) + +model = sys.argv[1] + +subdir = os.path.join('models', model) +if not os.path.exists(subdir): + os.makedirs(subdir) +subdir = subdir.replace('\\','/') # needed for Windows + +for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']: + + r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True) + + with open(os.path.join(subdir, filename), 'wb') as f: + file_size = int(r.headers["content-length"]) + chunk_size = 1000 + with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar: + # 1k for chunk_size, since Ethernet packet size is around 1500 bytes + for chunk in r.iter_content(chunk_size=chunk_size): + f.write(chunk) + pbar.update(chunk_size) diff --git a/model_card.md b/model_card.md new file mode 100644 index 0000000..38246ee --- /dev/null +++ b/model_card.md @@ -0,0 +1,69 @@ +# GPT-2 model card + +Last updated: November 2019 + +Inspired by [Model Cards 
for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we’re providing some accompanying information about the GPT-2 family of models we're releasing.
+
+## Model Details
+
+This model was developed by researchers at OpenAI to help us understand how the capabilities of language models scale as a function of model size (by parameter count) combined with very large internet-scale datasets (WebText).
+
+### Model date
+
+February 2019, trained on data that cuts off at the end of 2017.
+
+### Model type
+
+Language model
+
+### Model version
+
+1.5 billion parameters: the fourth and largest GPT-2 version. We have also released 124 million, 355 million, and 774 million parameter models.
+
+### Paper or other resource for more information
+[Blog post](https://openai.com/blog/better-language-models/) and [paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
+
+### Where to send questions or comments about the model
+Please use this [Google Form](https://forms.gle/A7WBSbTY2EkKdroPA)
+
+## Intended Uses
+
+### Primary intended uses
+
+The primary intended users of these models are *AI researchers and practitioners*.
+
+We primarily imagine these language models will be used by researchers to better understand the behaviors, capabilities, biases, and constraints of large-scale generative language models.
+
+### Secondary uses
+
+Here are some secondary use cases we believe are likely:
+
+- **Writing assistance**: grammar assistance, autocompletion (for normal prose or code)
+- **Creative writing and art**: exploring the generation of creative, fictional texts; aiding creation of poetry and other literary art
+- **Entertainment**: creation of games, chat bots, and amusing generations
+
+### Out-of-scope use cases
+
+Because large-scale language models like GPT-2 do not distinguish fact from fiction, we don’t support use cases that require the generated text to be true.
+
+Additionally, language models like GPT-2 reflect the biases inherent to the systems they were trained on, so we do not recommend that they be deployed into systems that interact with humans unless the deployers first carry out a study of biases relevant to the intended use case. We found no statistically significant difference in gender, race, and religious bias probes between 774M and 1.5B, implying all versions of GPT-2 should be approached with similar levels of caution around use cases that are sensitive to biases around human attributes.
+
+## Evaluation Data
+
+### Datasets
+
+This model was trained on (and evaluated against) WebText, a dataset consisting of the text contents of 45 million links posted by users of the ‘Reddit’ social network. WebText is made of data derived from outbound links from Reddit and does not consist of data taken directly from Reddit itself. Before generating the dataset we used a blocklist to ensure we didn’t sample from a variety of subreddits which contain sexually explicit or otherwise offensive content.
+
+To get a sense of the data that went into GPT-2, we’ve [published a list](domains.txt) of the top 1,000 domains present in WebText and their frequency. The top 15 domains by volume in WebText are: Google, Archive, Blogspot, GitHub, NYTimes, WordPress, Washington Post, Wikia, BBC, The Guardian, eBay, Pastebin, CNN, Yahoo!, and the Huffington Post.
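+
+For example, a short script along these lines (assuming the file's two-column
+`<count> <domain>` format) reproduces that top-15 list:
+
+```
+counts = []
+with open("domains.txt") as f:
+    for line in f:
+        parts = line.split()
+        if len(parts) == 2:  # skip any malformed rows
+            counts.append((int(parts[0]), parts[1]))
+
+for count, domain in sorted(counts, reverse=True)[:15]:
+    print(f"{count:>8}  {domain}")
+```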
+
+### Motivation
+
+The motivation behind WebText was to create an Internet-scale, heterogeneous dataset that we could use to test large-scale language models against. WebText was (and is) intended to be primarily for research purposes rather than production purposes.
+
+### Caveats and Recommendations
+
+Because GPT-2 is an internet-scale language model, it’s currently difficult to know what disciplined testing procedures can be applied to it to fully understand its capabilities and how the data it is trained on influences its vast range of outputs. We recommend researchers investigate these aspects of the model and share their results.
+
+Additionally, as indicated in our discussion of issues relating to potential misuse of the model, it remains unclear what the long-term dynamics are of detecting outputs from these models. We conducted [in-house automated ML-based detection research](https://github.com/openai/gpt-2-output-dataset/tree/master/detector) using simple classifiers, zero-shot, and fine-tuning methods. Our fine-tuned detector model reached accuracy levels of approximately 95%. However, no one detection method is a panacea; automated ML-based detection, human detection, human-machine teaming, and metadata-based detection are all methods that can be combined for more confident classification. Developing better approaches to detection today will give us greater intuitions when thinking about future models and could help us understand ahead of time if detection methods will eventually become ineffective.
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2cc521d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+fire>=0.1.3
+regex==2017.4.5
+requests==2.21.0
+tqdm==4.31.1
diff --git a/src/encoder.py b/src/encoder.py
new file mode 100644
index 0000000..5f52e72
--- /dev/null
+++ b/src/encoder.py
@@ -0,0 +1,117 @@
+"""Byte pair encoding utilities"""
+
+import os
+import json
+import regex as re
+from functools import lru_cache
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a mapping from utf-8 bytes to unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    The mapping also avoids whitespace/control characters that the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
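+
+# Example (illustrative): get_pairs(('h', 'e', 'l', 'l', 'o')) returns
+# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')} -- repeated pairs
+# collapse because the result is a set.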
+
+class Encoder:
+    def __init__(self, encoder, bpe_merges, errors='replace'):
+        self.encoder = encoder
+        self.decoder = {v:k for k,v in self.encoder.items()}
+        self.errors = errors # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            # Merge the lowest-ranked (earliest-learned) pair first.
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        return text
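+
+# Example round-trip (illustrative, with the released GPT-2 vocab; exact ids
+# depend on the model files actually loaded):
+#   enc = get_encoder('124M', 'models')
+#   enc.encode('Hello world')   # -> [15496, 995]
+#   enc.decode([15496, 995])    # -> 'Hello world'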
+
+def get_encoder(model_name, models_dir):
+    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
+        encoder = json.load(f)
+    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
+        bpe_data = f.read()
+    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+    return Encoder(
+        encoder=encoder,
+        bpe_merges=bpe_merges,
+    )
diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py
new file mode 100755
index 0000000..eaf9a63
--- /dev/null
+++ b/src/generate_unconditional_samples.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+
+import fire
+import json
+import os
+import numpy as np
+import tensorflow as tf
+
+import model, sample, encoder
+
+def sample_model(
+    model_name='124M',
+    seed=None,
+    nsamples=0,
+    batch_size=1,
+    length=None,
+    temperature=1,
+    top_k=0,
+    top_p=1,
+    models_dir='models',
+):
+    """
+    Run the sample_model
+    :model_name=124M : String, which model to use
+    :seed=None : Integer seed for random number generators, fix seed to
+     reproduce results
+    :nsamples=0 : Number of samples to return, if 0, continues to
+     generate samples indefinitely.
+    :batch_size=1 : Number of batches (only affects speed/memory).
+    :length=None : Number of tokens in generated text, if None (default), is
+     determined by model hyperparameters
+    :temperature=1 : Float value controlling randomness in Boltzmann
+     distribution. Lower temperature results in less random completions. As the
+     temperature approaches zero, the model will become deterministic and
+     repetitive. Higher temperature results in more random completions.
+    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
+     considered for each step (token), resulting in deterministic completions,
+     while 40 means 40 words are considered at each step. 0 (default) is a
+     special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
+    """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    enc = encoder.get_encoder(model_name, models_dir)
+    hparams = model.default_hparams()
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if length is None:
+        length = hparams.n_ctx
+    elif length > hparams.n_ctx:
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
+
+    with tf.Session(graph=tf.Graph()) as sess:
+        np.random.seed(seed)
+        tf.set_random_seed(seed)
+
+        output = sample.sample_sequence(
+            hparams=hparams, length=length,
+            start_token=enc.encoder['<|endoftext|>'],
+            batch_size=batch_size,
+            temperature=temperature, top_k=top_k, top_p=top_p
+        )[:, 1:]
+
+        saver = tf.train.Saver()
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
+        saver.restore(sess, ckpt)
+
+        generated = 0
+        while nsamples == 0 or generated < nsamples:
+            out = sess.run(output)
+            for i in range(batch_size):
+                generated += batch_size
+                text = enc.decode(out[i])
+                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+                print(text)
+
+if __name__ == '__main__':
+    fire.Fire(sample_model)
+
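+# Typical invocation (per DEVELOPERS.md):
+#   python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.7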
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
new file mode 100755
index 0000000..8b66000
--- /dev/null
+++ b/src/interactive_conditional_samples.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+
+import fire
+import json
+import os
+import numpy as np
+import tensorflow as tf
+
+import model, sample, encoder
+
+def interact_model(
+    model_name='124M',
+    seed=None,
+    nsamples=1,
+    batch_size=1,
+    length=None,
+    temperature=1,
+    top_k=0,
+    top_p=1,
+    models_dir='models',
+):
+    """
+    Interactively run the model
+    :model_name=124M : String, which model to use
+    :seed=None : Integer seed for random number generators, fix seed to reproduce
+     results
+    :nsamples=1 : Number of samples to return total
+    :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples.
+    :length=None : Number of tokens in generated text, if None (default), is
+     determined by model hyperparameters
+    :temperature=1 : Float value controlling randomness in Boltzmann
+     distribution. Lower temperature results in less random completions. As the
+     temperature approaches zero, the model will become deterministic and
+     repetitive. Higher temperature results in more random completions.
+    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
+     considered for each step (token), resulting in deterministic completions,
+     while 40 means 40 words are considered at each step. 0 (default) is a
+     special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
+    """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    if batch_size is None:
+        batch_size = 1
+    assert nsamples % batch_size == 0
+
+    enc = encoder.get_encoder(model_name, models_dir)
+    hparams = model.default_hparams()
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if length is None:
+        length = hparams.n_ctx // 2
+    elif length > hparams.n_ctx:
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
+
+    with tf.Session(graph=tf.Graph()) as sess:
+        context = tf.placeholder(tf.int32, [batch_size, None])
+        np.random.seed(seed)
+        tf.set_random_seed(seed)
+        output = sample.sample_sequence(
+            hparams=hparams, length=length,
+            context=context,
+            batch_size=batch_size,
+            temperature=temperature, top_k=top_k, top_p=top_p
+        )
+
+        saver = tf.train.Saver()
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
+        saver.restore(sess, ckpt)
+
+        while True:
+            raw_text = input("Model prompt >>> ")
+            while not raw_text:
+                print('Prompt should not be empty!')
+                raw_text = input("Model prompt >>> ")
+            context_tokens = enc.encode(raw_text)
+            generated = 0
+            for _ in range(nsamples // batch_size):
+                out = sess.run(output, feed_dict={
+                    context: [context_tokens for _ in range(batch_size)]
+                })[:, len(context_tokens):]
+                for i in range(batch_size):
+                    generated += 1
+                    text = enc.decode(out[i])
+                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+                    print(text)
+            print("=" * 80)
+
+if __name__ == '__main__':
+    fire.Fire(interact_model)
+
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000..230b83c
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,174 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.training import HParams
+
+def default_hparams():
+    return HParams(
+        n_vocab=0,
+        n_ctx=1024,
+        n_embd=768,
+        n_head=12,
+        n_layer=12,
+    )
+
+def shape_list(x):
+    """Deal with dynamic shape in tensorflow cleanly."""
+    static = x.shape.as_list()
+    dynamic = tf.shape(x)
+    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
+
+def softmax(x, axis=-1):
+    x = x - tf.reduce_max(x, axis=axis, keepdims=True)
+    ex = tf.exp(x)
+    return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
+
+def gelu(x):
+    return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
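+
+# gelu above is the tanh approximation of the Gaussian Error Linear Unit
+# from Hendrycks & Gimpel (2016), not the exact erf formulation.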
+
+def norm(x, scope, *, axis=-1, epsilon=1e-5):
+    """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
+    with tf.variable_scope(scope):
+        n_state = x.shape[-1].value
+        g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
+        b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
+        u = tf.reduce_mean(x, axis=axis, keepdims=True)
+        s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True)
+        x = (x - u) * tf.rsqrt(s + epsilon)
+        x = x*g + b
+        return x
+
+def split_states(x, n):
+    """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
+    *start, m = shape_list(x)
+    return tf.reshape(x, start + [n, m//n])
+
+def merge_states(x):
+    """Smash the last two dimensions of x into a single dimension."""
+    *start, a, b = shape_list(x)
+    return tf.reshape(x, start + [a*b])
+
+def conv1d(x, scope, nf, *, w_init_stdev=0.02):
+    with tf.variable_scope(scope):
+        *start, nx = shape_list(x)
+        w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev))
+        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
+        c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
+        return c
+
+def attention_mask(nd, ns, *, dtype):
+    """1's in the lower triangle, counting from the lower right corner.
+
+    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
+    """
+    i = tf.range(nd)[:,None]
+    j = tf.range(ns)
+    m = i >= j - ns + nd
+    return tf.cast(m, dtype)
+
+
+def attn(x, scope, n_state, *, past, hparams):
+    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
+    assert n_state % hparams.n_head == 0
+    if past is not None:
+        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
+
+    def split_heads(x):
+        # From [batch, sequence, features] to [batch, heads, sequence, features]
+        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])
+
+    def merge_heads(x):
+        # Reverse of split_heads
+        return merge_states(tf.transpose(x, [0, 2, 1, 3]))
+
+    def mask_attn_weights(w):
+        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
+        _, _, nd, ns = shape_list(w)
+        b = attention_mask(nd, ns, dtype=w.dtype)
+        b = tf.reshape(b, [1, 1, nd, ns])
+        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
+        return w
+
+    def multihead_attn(q, k, v):
+        # q, k, v have shape [batch, heads, sequence, features]
+        w = tf.matmul(q, k, transpose_b=True)
+        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
+
+        w = mask_attn_weights(w)
+        w = softmax(w)
+        a = tf.matmul(w, v)
+        return a
+
+    with tf.variable_scope(scope):
+        c = conv1d(x, 'c_attn', n_state*3)
+        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
+        present = tf.stack([k, v], axis=1)
+        if past is not None:
+            pk, pv = tf.unstack(past, axis=1)
+            k = tf.concat([pk, k], axis=-2)
+            v = tf.concat([pv, v], axis=-2)
+        a = multihead_attn(q, k, v)
+        a = merge_heads(a)
+        a = conv1d(a, 'c_proj', n_state)
+        return a, present
+
+
+def mlp(x, scope, n_state, *, hparams):
+    with tf.variable_scope(scope):
+        nx = x.shape[-1].value
+        h = gelu(conv1d(x, 'c_fc', n_state))
+        h2 = conv1d(h, 'c_proj', nx)
+        return h2
+
+
+def block(x, scope, *, past, hparams):
+    with tf.variable_scope(scope):
+        nx = x.shape[-1].value
+        a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
+        x = x + a
+        m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
+        x = x + m
+        return x, present
+
+def past_shape(*, hparams, batch_size=None, sequence=None):
+    return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]
+
+def expand_tile(value, size):
+    """Add a new axis of given size."""
+    value = tf.convert_to_tensor(value, name='value')
+    ndims = value.shape.ndims
+    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)
+
+def positions_for(tokens, past_length):
+    batch_size = tf.shape(tokens)[0]
+    nsteps = tf.shape(tokens)[1]
+    return expand_tile(past_length + tf.range(nsteps), batch_size)
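+
+# Shape conventions used by model() below (per past_shape above):
+#   X:    [batch, sequence] int32 token ids
+#   past: [batch, n_layer, 2, n_head, past_sequence, n_embd // n_head]
+#   h:    [batch, sequence, n_embd] hidden states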
+
+def model(hparams, X, past=None, scope='model', reuse=False):
+    with tf.variable_scope(scope, reuse=reuse):
+        results = {}
+        batch, sequence = shape_list(X)
+
+        wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
+                              initializer=tf.random_normal_initializer(stddev=0.01))
+        wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
+                              initializer=tf.random_normal_initializer(stddev=0.02))
+        past_length = 0 if past is None else tf.shape(past)[-2]
+        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
+
+        # Transformer
+        presents = []
+        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
+        assert len(pasts) == hparams.n_layer
+        for layer, past in enumerate(pasts):
+            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
+            presents.append(present)
+        results['present'] = tf.stack(presents, axis=1)
+        h = norm(h, 'ln_f')
+
+        # Language model loss. Do tokens