Update phishing_email_detection_gpt2.py
david-thrower committed Jun 13, 2024
1 parent 0340045 commit 1566908
Showing 1 changed file with 4 additions and 44 deletions.
phishing_email_detection_gpt2.py
@@ -30,6 +30,8 @@
     import zero_7_exp_decay, zero_95_exp_decay, simple_sigmoid
 from ast import literal_eval
 
+from custom.custom import GPT2Layer
+
 #
 # Load the email data
 #
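With the class now imported rather than defined inline, a minimal usage sketch may help show what the script still expects from it. This assumes custom/custom.py exposes GPT2Layer with the same (max_seq_length, **kwargs) constructor as the code removed in the hunk below; the Dense head is illustrative, chosen to match INPUT_SHAPES = [()] and OUTPUT_SHAPES = [1] from the surrounding script:

import tensorflow as tf

from custom.custom import GPT2Layer  # assumes the class moved here unchanged

# One raw email string per example, matching INPUT_SHAPES = [()]
inp = tf.keras.Input(shape=(), dtype=tf.string)

# Mean-pooled GPT2 embedding of the email text (see the removed class below)
emb = GPT2Layer(max_seq_length=96)(inp)

# Illustrative binary head, matching OUTPUT_SHAPES = [1]
out = tf.keras.layers.Dense(1, activation="sigmoid")(emb)

model = tf.keras.Model(inp, out)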
@@ -73,50 +75,8 @@
 INPUT_SHAPES = [()]
 OUTPUT_SHAPES = [1]
 
-"""### A custom GPT2 encoder layer for text embedding"""
-
-class GPT2Layer(tf.keras.layers.Layer):
-
-    def __init__(self, max_seq_length, **kwargs):
-        #
-        super(GPT2Layer, self).__init__(**kwargs)
-        #
-        # Load the GPT2 tokenizer, preprocessor and model
-        self.tokenizer = GPT2Tokenizer.from_preset("gpt2_base_en")
-        self.preprocessor = GPT2Preprocessor(self.tokenizer,
-                                             sequence_length=max_seq_length)
-        self.encoder = GPT2Backbone.from_preset("gpt2_base_en")
-        #
-        # Set whether the GPT2 model's layers are trainable
-        #self.encoder.trainable = False
-        for layer in self.encoder.layers:
-            layer.trainable = False
-        #
-        self.encoder.layers[-2].trainable = True
-        #
-        # Set the maximum sequence length for tokenization
-        self.max_seq_length = max_seq_length
-
-    def call(self, inputs):
-        #
-        # Output the GPT2 embedding
-        prep = self.preprocessor([inputs])
-        embedding = self.encoder(prep)
-        avg_pool = tf.reduce_mean(embedding, axis=1)
-        #
-        return avg_pool
-
-    def get_config(self):
-        #
-        config = super(GPT2Layer, self).get_config()
-        config.update({'max_seq_length': self.max_seq_length})
-        #
-        return config
-
-    @classmethod
-    def from_config(cls, config):
-        #
-        return cls(max_seq_length=config['max_seq_length'])
+
+
 
 # GPT2 configurables
 max_seq_length = 96
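A note on why this refactor matters beyond tidiness: Keras can only rebuild a saved model containing a custom layer if the layer class is importable at load time, which a class defined inline in a training script does not guarantee. A sketch of reloading a saved model under that assumption — the phishing_model.keras path is a placeholder, not a file this commit creates:

import tensorflow as tf

from custom.custom import GPT2Layer

# "phishing_model.keras" is a hypothetical path; custom_objects maps the
# serialized layer name back to the now-importable class.
model = tf.keras.models.load_model(
    "phishing_model.keras",
    custom_objects={"GPT2Layer": GPT2Layer},
)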
