Stock Prediction with BERT (1)

This post shows how to predict the DJIA's adjusted closing prices using a pre-trained BERT model from MXNet.

Intro

When we want to predict a stock's price for the next day (or even the next week or month), the first thing we do is gather as much information about the company as we can and guess where the price is likely to go. In the past this was mostly done by hand, with little help from computers, and even when computers were used they helped little because of limited resources such as computing power.

However, as technology improves and computers get faster, we have started using them to help with prediction. In this post, I share how I predicted the DJIA's adjusted closing prices using news articles as input features.

The data comes from a Kaggle dataset uploaded by Aaron7sun. It contains the top 25 news articles for each day from 2008-08-08 to 2016-07-01, a total of 1989 days of samples.

The dataset has three CSV files, but I only used Combined_News_DJIA.csv because my model predicts from the news articles alone.

Common approaches have used RNN, GRU, LSTM or ARIMA models that rely on past values. My approach instead uses the same day's news articles and tries to estimate how much they affect the price relative to the day's opening value: if the effect is positive, the day should close higher.

Since the data is text rather than numbers, I used a pre-trained BERT model from MXNet's Model Zoo to convert each article into a vector of floating-point values.

What is BERT?

BERT is an encoder that converts a sequence of words (or phrases) into vectors of floating-point values. Unlike word2vec, which assigns one fixed vector to each word, BERT takes the surrounding sentence into account, so the same word can get different vectors in two sentences when its meaning or impact differs.

As an example, consider two sentences:

  1. I hate seeing you

  2. I hate leaving you

If we try to predict my feelings toward you with word2vec, the model effectively has only 'seeing' and 'leaving' to work with: 'I', 'hate' and 'you' appear in the same positions with the same vectors in both sentences, so the model gains little from them. With BERT, it is possible to capture that 'hate seeing' is negative while 'hate leaving' is positive, because 'hate' gets a different vector in each sentence.

Another example is predicting a restaurant rating. For the sentence 'Bob hates this restaurant', word2vec might assign the following values:

  1. Bob : 3

  2. hates : -7

  3. this : 0

  4. restaurant : 3

If we make a (naive) model that simply sums these values to predict whether Bob's rating is positive or negative, we get 3 - 7 + 0 + 3 = -1, a negative rating. But what happens if we replace 'hates' with 'dislikes', which has a value of -5? The sum becomes 3 - 5 + 0 + 3 = 1, so the model predicts a positive rating even though Bob still dislikes the restaurant.
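The arithmetic above as a tiny sketch. The scores are the illustrative numbers from the example, not real word2vec values, and naive_rating is a hypothetical helper:

```python
# Illustrative word scores from the example (not real word2vec output)
scores = {"Bob": 3, "hates": -7, "dislikes": -5, "this": 0, "restaurant": 3}


def naive_rating(sentence):
    """Sum the per-word scores; a positive total means a positive rating."""
    return sum(scores[word] for word in sentence.split())


print(naive_rating("Bob hates this restaurant"))     # -1
print(naive_rating("Bob dislikes this restaurant"))  # 1
```

Swapping a single word flips the sign of the prediction, which is exactly the fragility that context-free, summed word values run into.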

With word2vec, we would have to account for many possible word combinations by hand for the model to output the desired result.

This is where BERT differs from word2vec: it captures each word's impact in context. Since this post is not about BERT itself, I will skip the rest of the explanation and cover it in a later post.

Data Preprocessing

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

path = './data/'
combined_news_path = path + 'Combined_News_DJIA.csv'

news_djia = pd.read_csv(combined_news_path)
news_djia.shape  # (1989, 27)
news_djia.head(2)

Column | Row 0 | Row 1
Date | 2008-08-08 | 2008-08-11
Label | 0 | 1
Top1 | b"Georgia 'downs two Russian warplanes' as cou... | b'Why wont America and Nato help us? If they w...
Top2 | b'BREAKING: Musharraf to be impeached.' | b'Bush puts foot down on Georgian conflict'
Top3 | b'Russia Today: Columns of troops roll into So... | b"Jewish Georgian minister: Thanks to Israeli ...
Top4 | b'Russian tanks are moving towards the capital... | b'Georgian army flees in disarray as Russians ...
Top5 | b"Afghan children raped with 'impunity,' U.N. ... | b"Olympic opening ceremony fireworks 'faked'"
Top6 | b'150 Russian tanks have entered South Ossetia... | b'What were the Mossad with fraudulent New Zea...
Top7 | b"Breaking: Georgia invades South Ossetia, Rus... | b'Russia angered by Israeli military sale to G...
Top8 | b"The 'enemy combatent' trials are nothing but... | b'An American citizen living in S.Ossetia blam...
... | ... | ...
Top16 | b'Georgia Invades South Ossetia - if Russia ge... | b'Israel and the US behind the Georgian aggres...
Top17 | b'Al-Qaeda Faces Islamist Backlash' | b'"Do not believe TV, neither Russian nor Geor...
Top18 | b'Condoleezza Rice: "The US would not act to p... | b'Riots are still going on in Montreal (Canada...
Top19 | b'This is a busy day: The European Union has ... | b'China to overtake US as largest manufacturer'
Top20 | b"Georgia will withdraw 1,000 soldiers from Ir... | b'War in South Ossetia [PICS]'
Top21 | b'Why the Pentagon Thinks Attacking Iran is a ... | b'Israeli Physicians Group Condemns State Tort...
Top22 | b'Caucasus in crisis: Georgia invades South Os... | b' Russia has just beaten the United States ov...
Top23 | b'Indian shoe manufactory - And again in a se... | b'Perhaps *the* question about the Georgia - R...
Top24 | b'Visitors Suffering from Mental Illnesses Ban... | b'Russia is so much better at war'
Top25 | b"No Help for Mexico's Kidnapping Surge" | b"So this is what it's come to: trading sex fo...

2 rows × 27 columns

news_djia = news_djia.drop(labels='Label', axis=1)

# Some values (e.g. NaN for missing headlines) are not strings, so convert every cell to str
news_djia = news_djia.apply(lambda x: x.map(lambda y: str(y)), axis=1)

# Remove the b" and b' prefixes left over from Python bytes literals.
# Note: str.replace removes every occurrence, not only the leading one, and also drops all double quotes.
news_djia = news_djia.apply(lambda x: x.map(lambda y: y.replace('b"', '').replace("b'", '').replace('"', '')), axis=1)

# Wrap each headline string in a one-element list, because bert_embedding expects a list of sentences
news_djia.iloc[:, 1:] = news_djia.iloc[:, 1:].apply(lambda x: x.map(lambda y: [y]), axis=1)

# Move Date to the index
news_djia = news_djia.set_index(news_djia.iloc[:, 0]).drop('Date', axis=1)
news_djia.head(2)

Column | 2008-08-08 | 2008-08-11
Top1 | [Georgia 'downs two Russian warplanes' as coun... | [Why wont America and Nato help us? If they wo...
Top2 | [BREAKING: Musharraf to be impeached.'] | [Bush puts foot down on Georgian conflict']
Top3 | [Russia Today: Columns of troops roll into Sou... | [Jewish Georgian minister: Thanks to Israeli t...
Top4 | [Russian tanks are moving towards the capital ... | [Georgian army flees in disarray as Russians a...
Top5 | [Afghan children raped with 'impunity,' U.N. o... | [Olympic opening ceremony fireworks 'faked']
Top6 | [150 Russian tanks have entered South Ossetia ... | [What were the Mossad with fraudulent New Zeal...
Top7 | [Breaking: Georgia invades South Ossetia, Russ... | [Russia angered by Israeli military sale to Ge...
Top8 | [The 'enemy combatent' trials are nothing but ... | [An American citizen living in S.Ossetia blame...
Top9 | [Georgian troops retreat from S. Osettain capi... | [Welcome To World War IV! Now In High Definiti...
Top10 | [Did the U.S. Prep Georgia for War with Russia?'] | [Georgia's move, a mistake of monumental propo...
... | ... | ...
Top16 | [Georgia Invades South Ossetia - if Russia get... | [Israel and the US behind the Georgian aggress...
Top17 | [Al-Qaeda Faces Islamist Backlash'] | [Do not believe TV, neither Russian nor Georgi...
Top18 | [Condoleezza Rice: The US would not act to pre... | [Riots are still going on in Montreal (Canada)...
Top19 | [This is a busy day: The European Union has a... | [China to overtake US as largest manufacturer']
Top20 | [Georgia will withdraw 1,000 soldiers from Ira... | [War in South Ossetia [PICS]']
Top21 | [Why the Pentagon Thinks Attacking Iran is a B... | [Israeli Physicians Group Condemns State Tortu...
Top22 | [Caucasus in crisis: Georgia invades South Oss... | [ Russia has just beaten the United States ove...
Top23 | [Indian shoe manufactory - And again in a ser... | [Perhaps *the* question about the Georgia - Ru...
Top24 | [Visitors Suffering from Mental Illnesses Bann... | [Russia is so much better at war']
Top25 | [No Help for Mexico's Kidnapping Surge] | [So this is what it's come to: trading sex for...

I dropped the Label column, removed the leading b' and b" characters (leftovers of Python bytes literals, not actual words I need), and moved the dates to the index.

I wrapped each headline string in a one-element list because bert_embedding takes a list of sentences and returns a vector for each word (and punctuation mark) in each sentence.
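A small illustration of why the wrapping matters, assuming the embedder iterates over whatever it is given as its list of sentences. embed_sentences below is a hypothetical stand-in, not the BertEmbedding API: a bare string is iterated character by character, so each character would be treated as its own sentence.

```python
def embed_sentences(sentences):
    # Hypothetical stand-in for an embedder that loops over a list of sentences
    return [f"<embedding of {s!r}>" for s in sentences]


headline = "Bush puts foot down on Georgian conflict"

print(len(embed_sentences(headline)))    # 40: one "sentence" per character
print(len(embed_sentences([headline])))  # 1: the whole headline is one sentence
```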

Some news articles probably contain non-alphanumeric characters. I did not clean them, but doing so might improve the model.
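A possible cleaning step, sketched under the assumption that every wrapped headline starts with b' or b" and ends with the matching quote. Anchoring a regular expression at both ends removes only the bytes-literal wrapper, whereas str.replace("b'", "") also mangles words such as "Club's". clean_headline and its keep_punctuation flag are hypothetical, not part of the original code; the optional branch drops non-alphanumeric characters.

```python
import re


def clean_headline(text, keep_punctuation=True):
    """Strip a b'...' or b"..." wrapper left over from Python bytes literals.

    The pattern only matches at the start and end of the string, so quotes
    and the letter b inside the headline are left alone.
    """
    text = re.sub(r"^b(['\"])(.*)\1$", r"\2", text.strip())
    if not keep_punctuation:
        # Optionally replace everything except letters, digits and whitespace
        text = re.sub(r"[^A-Za-z0-9\s]", " ", text)
        text = re.sub(r"\s+", " ", text).strip()
    return text


print(clean_headline("b'BREAKING: Musharraf to be impeached.'"))
print(clean_headline("b'War in South Ossetia [PICS]'", keep_punctuation=False))
print("Club's night".replace("b'", ""))  # the unanchored replace mangles the word
```

It could be mapped over the Top1 to Top25 columns in place of the replace calls, before each headline is wrapped in a list.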

You can download the files needed to run BERT from the MXNet BERT page. You also need to install mxnet with pip.

import mxnet as mx

from mxnet import gluon
from bert.embedding import BertEmbedding

# Use the first GPU (requires a GPU-enabled mxnet build)
ctx = mx.gpu()

# Load the model on the GPU for faster inference
bert_embedding = BertEmbedding(model='bert_12_768_12', dataset_name='book_corpus_wiki_en_cased', ctx=ctx)

You can load a different model; this page lists the parameters for that. You can also change the dataset the model was pre-trained on. The model I loaded, bert_12_768_12 (12 layers, 768 hidden units, 12 attention heads), outputs a 768-dimensional vector for each token, as the middle number in its name indicates. A larger model produces larger feature vectors, which might improve the accuracy of the downstream model, so feel free to try other models as well.
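GluonNLP names its BERT models bert_<layers>_<hidden units>_<attention heads>. The hypothetical helper below only splits such a name to show which number is the embedding size; it is for illustration and is not part of MXNet or GluonNLP.

```python
def parse_bert_name(name):
    """Split a name like 'bert_12_768_12' into its parts.

    Assumes the pattern bert_<layers>_<hidden units>_<attention heads>.
    """
    _, layers, units, heads = name.split("_")
    return {"layers": int(layers), "hidden_units": int(units), "heads": int(heads)}


print(parse_bert_name("bert_12_768_12"))
# {'layers': 12, 'hidden_units': 768, 'heads': 12}
print(parse_bert_name("bert_24_1024_16")["hidden_units"])  # 1024
```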

Next is the result of passing the first two samples into BERT.

example_embedding = news_djia.iloc[:2, :].apply(lambda x: x.map(lambda y: bert_embedding(y)))
example_embedding

Column | 2008-08-08 | 2008-08-11
Top1 | [([Georgia, ', downs, two, Russian, warplanes,... | [([Why, wont, America, and, Nato, help, us, ?,...
Top2 | [([BREAKING, :, Musharraf, to, be, impeached, ... | [([Bush, puts, foot, down, on, Georgian, confl...
Top3 | [([Russia, Today, :, Columns, of, troops, roll... | [([Jewish, Georgian, minister, :, Thanks, to, ...
Top4 | [([Russian, tanks, are, moving, towards, the, ... | [([Georgian, army, flees, in, disarray, as, Ru...
Top5 | [([Afghan, children, raped, with, ', impunity,... | [([Olympic, opening, ceremony, fireworks, ', f...
Top6 | [([150, Russian, tanks, have, entered, South, ... | [([What, were, the, Mossad, with, fraudulent, ...
Top7 | [([Breaking, :, Georgia, invades, South, Osset... | [([Russia, angered, by, Israeli, military, sal...
Top8 | [([The, ', enemy, combatent, ', trials, are, n... | [([An, American, citizen, living, in, S, ., Os...
Top9 | [([Georgian, troops, retreat, from, S, ., Oset... | [([Welcome, To, World, War, IV, !, Now, In, Hi...
Top10 | [([Did, the, U, ., S, ., Prep, Georgia, for, W... | [([Georgia, ', s, move, ,, a, mistake, of, mon...
... | ... | ...
Top16 | [([Georgia, Invades, South, Ossetia, -, if, Ru... | [([Israel, and, the, US, behind, the, Georgian...
Top17 | [([Al, -, Qaeda, Faces, Islamist, Backlash, ']... | [([Do, not, believe, TV, ,, neither, Russian, ...
Top18 | [([Condoleezza, Rice, :, The, US, would, not, ... | [([Riots, are, still, going, on, in, Montreal,...
Top19 | [([This, is, a, busy, day, :, The, European, U... | [([China, to, overtake, US, as, largest, manuf...
Top20 | [([Georgia, will, withdraw, 1, ,, 000, soldier... | [([War, in, South, Ossetia, [, PICS, ], '], [[...
Top21 | [([Why, the, Pentagon, Thinks, Attacking, Iran... | [([Israeli, Physicians, Group, Condemns, State...
Top22 | [([Caucasus, in, crisis, :, Georgia, invades, ... | [([Russia, has, just, beaten, the, United, Sta...
Top23 | [([Indian, shoe, manufactory, -, And, again, i... | [([Perhaps, *, the, *, question, about, the, G...
Top24 | [([Visitors, Suffering, from, Mental, Illnesse... | [([Russia, is, so, much, better, at, war, '], ...
Top25 | [([No, Help, for, Mexico, ', s, Kidnapping, Su... | [([So, this, is, what, it, ', s, come, to, :, ...

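bert_embedding returns a list with one (tokens, vectors) tuple per sentence: the first entry holds the tokens and the second the floating-point vectors corresponding to them. Since each cell holds a one-sentence list, x[0][1] is that sentence's list of token vectors. I did not need the strings, so I extracted the numeric values and summed them over the tokens as follows.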

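def extract_features(x):
    # x is [(tokens, vectors)] for a single sentence; keep only the vectors.
    # Compact version:
    # return np.array(x[0][1]).sum(axis=0)

    features = np.array(x[0][1])     # shape: (number of tokens, 768)
    features = features.sum(axis=0)  # sum over tokens -> shape: (768,)

    return features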
example_embedding = example_embedding.apply(lambda x: x.map(extract_features))
example_embedding

Column | 2008-08-08 | 2008-08-11
Top1 | [3.43436, 0.89767, -2.27312, 2.62073, 0.792095... | [4.37261, -2.79286, -5.87916, 6.66576, 2.75987...
Top2 | [1.5841, 0.365171, -2.74592, -0.759266, 0.5633... | [1.53141, 0.333403, -0.260142, 0.128509, -1.66...
Top3 | [1.66644, 0.786982, -1.82106, 0.194207, 1.0017... | [1.40451, -3.75495, -3.26506, -0.551889, -0.06...
Top4 | [1.24759, -1.67627, -4.90526, -1.23436, -0.495... | [1.30285, -7.20199, -1.98171, -1.76504, 3.3063...
Top5 | [3.78023, -0.584944, -4.42655, -2.08476, 0.645... | [0.382488, 1.93044, -0.987913, -0.084965, -0.9...
Top6 | [6.28313, -2.16985, -4.36524, -0.282691, -3.88... | [2.7923, -1.36717, -0.413153, -2.22618, 0.3724...
Top7 | [2.52342, -4.29657, -1.48065, -0.826757, 4.262... | [2.81514, -2.69294, -3.19405, -0.943336, -2.15...
Top8 | [4.61817, 1.62902, -2.46977, 1.27654, 3.85087,... | [4.82897, -7.04651, 0.574161, -2.90421, 2.1165...
Top9 | [3.6748, -4.58673, -3.14633, -2.66208, 3.22038... | [-0.700163, 4.20339, 0.784252, 3.06448, 2.2954...
Top10 | [2.27328, -6.11834, -3.83456, -3.41502, 0.6939... | [1.78688, 0.486187, -0.594014, 1.10868, 1.0810...
... | ... | ...
Top16 | [4.8692, -3.09887, -3.37449, 1.15127, 1.20513,... | [2.89638, -1.57175, -2.83468, -1.60865, 0.3098...
Top17 | [0.677572, 0.407966, -1.27538, -0.430122, -1.8... | [7.5718, -2.25389, 0.510455, 0.161088, 5.25497...
Top18 | [3.98538, -1.94795, -5.36948, -2.99529, -3.041... | [3.07508, 1.98748, -1.68678, 4.93265, 2.72879,...
Top19 | [3.50761, 2.43183, -0.859873, 1.72922, 2.43535... | [1.08836, 1.68378, -1.16536, -0.757239, 0.2008...
Top20 | [3.2757, -2.02832, -3.94034, -5.52725, -3.2960... | [1.1617, -0.767918, -1.46226, -1.37308, -0.755...
Top21 | [2.09109, 1.17632, -6.94091, -1.05991, 4.66838... | [0.307321, -0.462073, -1.55613, -0.213022, -1....
Top22 | [1.05314, -1.48395, -0.210298, -0.0948398, 0.0... | [-1.54555, -0.837098, -1.98399, 0.317458, 1.12...
Top23 | [5.10344, -1.73553, -7.81096, 1.49675, 6.49278... | [1.98323, -0.86789, -6.02319, 1.49993, -1.8145...
Top24 | [0.960207, 3.25463, -2.08372, -0.192987, 0.030... | [0.65908, -1.02872, -2.99243, 0.0656184, 1.712...
Top25 | [0.967569, -0.0868791, -1.75524, 1.64098, -1.9... | [1.11363, 3.66216, -0.0606607, 4.42123, 3.8314...

With this function, each cell of the DataFrame now holds one 768-dimensional numeric vector, giving 25 columns of vectors per day. I then applied the same steps to the whole dataset.
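To make the nested structure concrete, here is a sketch that runs extract_features on a mocked output for one short sentence. The mock (made-up tokens, random vectors) only assumes the structure described above, a list holding one (tokens, vectors) tuple with one 768-dimensional vector per token; it is not real BERT output.

```python
import numpy as np

# Hypothetical mock of bert_embedding([sentence]) for ONE sentence:
# a list with a single (tokens, vectors) tuple, one 768-d vector per token.
rng = np.random.default_rng(0)
tokens = ["Bush", "puts", "foot", "down"]
mock_output = [(tokens, [rng.standard_normal(768).astype(np.float32) for _ in tokens])]


def extract_features(x):
    # x[0] is the tuple for the single sentence; x[0][1] is its list of token vectors
    token_vectors = np.array(x[0][1])  # shape: (n_tokens, 768)
    return token_vectors.sum(axis=0)   # shape: (768,), independent of n_tokens


features = extract_features(mock_output)
print(np.array(mock_output[0][1]).shape, features.shape)  # (4, 768) (768,)
```

Because the sum runs over the token axis, headlines of any length end up as vectors of the same size.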

Running the BERT model on a CPU took more than an hour, so I used Google Cloud Platform with one Tesla v4 GPU, which still took about 30 minutes.

news_embedding = news_djia.apply(lambda x: x.map(lambda y: bert_embedding(y)))

# Remove word and only keep numeric vectors
news_embedding = news_embedding.apply(lambda x: x.map(extract_features))

After that, I combined all 25 columns into a single column that holds, for each day, a list of the 25 headline vectors.

news_embedding['combined'] = news_embedding.values.tolist()

news_embedding = news_embedding[['combined']]

news_embedding.head()

Date | combined
2008-08-08 | [[3.43436, 0.89767, -2.27312, 2.62073, 0.79209...
2008-08-11 | [[4.37261, -2.79286, -5.87916, 6.66576, 2.7598...
2008-08-12 | [[3.6508, 2.65258, -2.76219, -0.521201, 1.1053...
2008-08-13 | [[3.9178, -2.75983, -2.82817, -4.31737, 0.1892...
2008-08-14 | [[0.886879, 0.293792, -2.50885, 1.20781, 0.921...
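The values.tolist() step can be checked on a toy frame. This sketch uses made-up 4-dimensional vectors and three columns instead of 768 dimensions and 25 columns; each row of the result is a list of that row's vectors, which np.array stacks into a (columns, dimensions) array.

```python
import numpy as np
import pandas as pd

# Toy frame: 2 days x 3 headline columns, each cell a 4-d vector
# (the real data has 25 columns of 768-d vectors)
toy = pd.DataFrame(
    {f"Top{i}": [np.full(4, float(i)), np.full(4, float(i) * 10)] for i in range(1, 4)},
    index=["2008-08-08", "2008-08-11"],
)

toy["combined"] = toy.values.tolist()  # each row becomes a list of 3 vectors
toy = toy[["combined"]]

day = np.array(toy.loc["2008-08-08", "combined"])
print(day.shape)  # (3, 4)
```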

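Each article has a different number of tokens, so the raw BERT output has a different shape for every article; that is why extract_features sums over the tokens, turning each article into a fixed 768-dimensional vector. Each day is therefore a 25 × 768 array, and to give the model one fixed-size input vector per day, I aggregated element-wise over the 25 articles with min, max, sum and mean.

Min and max keep the extreme values of each feature: for example, max keeps the strongest value of each feature among the day's articles, while sum and mean describe the day's articles as a whole.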

min_embedding = news_embedding['combined'].map(lambda x: np.min(x, axis=0)).to_frame()
max_embedding = news_embedding['combined'].map(lambda x: np.max(x, axis=0)).to_frame()
sum_embedding = news_embedding['combined'].map(lambda x: np.sum(x, axis=0)).to_frame()
mean_embedding = news_embedding['combined'].map(lambda x: np.mean(x, axis=0)).to_frame()
mean_embedding.head(2)

Date | combined
2008-08-08 | [2.29378, -1.15325, -3.16382, -0.287952, 0.937...
2008-08-11 | [1.84285, -1.09343, -1.54086, 0.335894, 1.0583...
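The element-wise aggregation can be checked by hand on a toy day with three articles and four features (made-up numbers standing in for the 25 × 768 array). Each statistic is taken down axis 0, so every output keeps one value per feature.

```python
import numpy as np

# Toy day: 3 articles (rows) x 4 features (columns); the real data is 25 x 768
day = np.array([
    [1.0, -2.0, 0.5, 3.0],
    [4.0, 0.0, -1.5, 1.0],
    [-2.0, 5.0, 1.0, 2.0],
])

agg = {
    "min": np.min(day, axis=0),    # most negative value of each feature
    "max": np.max(day, axis=0),    # strongest positive value of each feature
    "sum": np.sum(day, axis=0),    # grows with the number of articles
    "mean": np.mean(day, axis=0),  # scale-free summary of the day
}

for name, vec in agg.items():
    print(name, vec, vec.shape)  # every vector has shape (4,)
```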

# Save them for easier access later
import os

path = 'embedding_files/'
os.makedirs(path, exist_ok=True)  # to_json does not create missing directories

min_embedding.to_json(path+'min_embedding.json')
max_embedding.to_json(path+'max_embedding.json')
sum_embedding.to_json(path+'sum_embedding.json')
mean_embedding.to_json(path+'mean_embedding.json')
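Reading the files back is a matter of pd.read_json, with one catch: JSON has no array type, so each cell comes back as a plain list. A minimal sketch, using a toy frame serialized to a string instead of the real files, that turns the cells back into a NumPy feature matrix with one row per day:

```python
from io import StringIO

import numpy as np
import pandas as pd

# Toy stand-in for mean_embedding: one 'combined' column of vectors indexed by date
toy = pd.DataFrame(
    {"combined": [np.array([0.25, -1.5, 2.0]), np.array([1.0, 0.5, -0.75])]},
    index=["2008-08-08", "2008-08-11"],
)

json_text = toy.to_json()  # same default orient as the to_json calls above
loaded = pd.read_json(StringIO(json_text))

# Each cell is now a list; convert to arrays and stack into (days, features)
X = np.stack(loaded["combined"].map(np.asarray).to_list())
print(X.shape)  # (2, 3)
```

With the real files, pd.read_json(path + 'mean_embedding.json') would replace the StringIO round trip; pandas may also parse the date index into datetimes.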

The actual model implementation is in a separate post, because putting everything together made this one too long. You can find it here.
