forked from borisdayma/dalle-mini
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
66 lines (54 loc) · 2.75 KB
/
model.py
File metadata and controls
66 lines (54 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import jax
import flax.linen as nn
from transformers.models.bart.modeling_flax_bart import (
FlaxBartModule,
FlaxBartForConditionalGenerationModule,
FlaxBartForConditionalGeneration,
FlaxBartEncoder,
FlaxBartDecoder
)
from transformers import BartConfig
# Model hyperparameters, for convenience.
# The decoder vocabulary is the 16384 encoded image-token ids followed by one
# extra id reserved for BOS, so BOS is the last id of the output vocabulary.
BOS_TOKEN_ID = 16384
OUTPUT_VOCAB_SIZE = BOS_TOKEN_ID + 1  # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
BASE_MODEL = 'facebook/bart-large-cnn'  # we currently have issues with bart-large
class CustomFlaxBartModule(FlaxBartModule):
    """BART encoder-decoder whose decoder has its own vocabulary and length.

    The encoder reuses the ``shared`` embedding of size ``config.vocab_size``
    (kept so pre-trained weights can be loaded).  The decoder gets a separate
    embedding table of size ``config.vocab_size_output`` and a decoder config
    whose ``max_position_embeddings`` is ``config.max_position_embeddings_decoder``.
    """

    def setup(self):
        # check config is valid, otherwise set default values
        # NOTE: the defaults are written back onto the (shared) config object.
        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
        self.config.max_position_embeddings_decoder = getattr(
            self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH
        )

        # Both embedding tables use the same normal initializer.
        embedding_init = jax.nn.initializers.normal(self.config.init_std, self.dtype)

        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=embedding_init,
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            self.config.vocab_size_output,
            self.config.d_model,
            embedding_init=embedding_init,
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # The decoder has a different config: copy every field of the model
        # config, then override the vocabulary size and the maximum length.
        # The dict must be expanded into keyword arguments (from_dict does this).
        # Passing it positionally, as in BartConfig(d), binds it to
        # ``vocab_size``, which silently resets every other setting (layers,
        # d_model, dropout, init_std, ...) to BartConfig's defaults.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
        decoder_config.vocab_size = self.config.vocab_size_output
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation module whose output layer spans the decoder vocabulary.

    Wraps ``CustomFlaxBartModule`` and sizes both the LM head and the
    ``final_logits_bias`` parameter by ``config.vocab_size_output`` rather than
    ``config.vocab_size``.
    """

    def setup(self):
        cfg = self.config

        # Default the output vocabulary size when the config does not define it
        # (the value is stored back on the config).
        cfg.vocab_size_output = getattr(cfg, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
        n_out = cfg.vocab_size_output

        self.model = CustomFlaxBartModule(config=cfg, dtype=self.dtype)

        # NOTE(review): if the inherited __call__ ties lm_head to the shared
        # encoder embedding when cfg.tie_word_embeddings is set, that table has
        # cfg.vocab_size rows, not n_out. Confirm that tying is disabled for
        # this model.
        self.lm_head = nn.Dense(
            n_out,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(cfg.init_std, self.dtype),
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, n_out))
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """Model wrapper that builds a ``CustomFlaxBartForConditionalGenerationModule``.

    Only the wrapped module class differs; everything else is inherited
    unchanged from ``FlaxBartForConditionalGeneration``.
    """

    # Flax module class instantiated by the inherited wrapper machinery.
    module_class = CustomFlaxBartForConditionalGenerationModule