Skip to content

Commit

Permalink
llama: define architecture for small granite models
Browse files Browse the repository at this point in the history
It works only for the small models (3B and 8B).

There are enough differences from the base llama arch that it is
worth defining a new architecture.

To create the .gguf files, GraniteSmallForCausalLM must be listed in
the `architectures` field of the HF model's config.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
  • Loading branch information
giuseppe committed May 23, 2024
1 parent d52b4d8 commit 1fb9186
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
28 changes: 28 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,34 @@ def set_vocab(self, *args, **kwargs):
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)

@Model.register("GraniteSmallForCausalLM")
class GraniteModel(Model):
    """HF -> GGUF converter for the small Granite checkpoints.

    Registered for HF models whose architecture is ``GraniteSmallForCausalLM``;
    the resulting file is tagged with ``gguf.MODEL_ARCH.GRANITE_SMALL``.
    The constructor is inherited unchanged from ``Model``.
    """

    model_arch = gguf.MODEL_ARCH.GRANITE_SMALL

    def set_vocab(self):
        """Emit a "gpt2" (BPE) vocabulary using the "starcoder" pre-tokenizer."""
        token_list, token_kinds, _ = self.get_vocab_base()
        writer = self.gguf_writer
        writer.add_tokenizer_model("gpt2")
        writer.add_tokenizer_pre("starcoder")
        writer.add_token_list(token_list)
        writer.add_token_types(token_kinds)
        # SpecialVocab reads from the model directory; load_merges=True also
        # pulls the BPE merges into the GGUF metadata.
        gguf.SpecialVocab(self.dir_model, load_merges=True).add_to_gguf(writer)

    def set_gguf_parameters(self):
        """Write the base hyper-parameters, then the Granite-specific overrides."""
        super().set_gguf_parameters()
        config = self.hparams
        writer = self.gguf_writer
        # NOTE(review): if super().set_gguf_parameters() already writes
        # general.name, this emits a second copy of that key - confirm.
        writer.add_name("GraniteSmall")
        writer.add_vocab_size(config["vocab_size"])
        # Rotary dimension = hidden_size / num_attention_heads (whole head).
        head_dim = config["hidden_size"] // config["num_attention_heads"]
        writer.add_rope_dimension_count(head_dim)
        # Ask the loader not to prepend a BOS token automatically.
        writer.add_add_bos_token(False)
        # No RoPE scaling for these checkpoints.
        writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename one HF tensor to its GGUF name; the data is passed through as-is.

        NOTE(review): no Q/K permutation happens here; this appears to pair
        with the NEOX rope type llama.cpp assigns to granite-small - confirm.
        """
        gguf_name = self.map_tensor_name(name)
        return [(gguf_name, data_torch)]


###### CONVERSION LOGIC ######

Expand Down
26 changes: 26 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum):
COMMAND_R = auto()
DBRX = auto()
OLMO = auto()
GRANITE_SMALL = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -218,6 +219,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.GRANITE_SMALL: "granite-small",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -732,6 +734,26 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GRANITE_SMALL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

Expand Down Expand Up @@ -765,6 +787,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.GRANITE_SMALL: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}

#
Expand Down
44 changes: 44 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ enum llm_arch {
LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX,
LLM_ARCH_OLMO,
LLM_ARCH_GRANITE_SMALL,
LLM_ARCH_UNKNOWN,
};

Expand Down Expand Up @@ -257,6 +258,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_DBRX, "dbrx" },
{ LLM_ARCH_OLMO, "olmo" },
{ LLM_ARCH_GRANITE_SMALL, "granite-small" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -1032,6 +1034,32 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GRANITE_SMALL,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down Expand Up @@ -4344,6 +4372,16 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_GRANITE_SMALL:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_3B; break;
case 36: model.type = e_model::MODEL_8B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}

Expand Down Expand Up @@ -4453,6 +4491,9 @@ static void llm_load_vocab(
} else {
if (tokenizer_model == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE;
if (model.arch == LLM_ARCH_GRANITE_SMALL) {
vocab.add_space_prefix = false;
}
} else {
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str());
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
Expand Down Expand Up @@ -5023,6 +5064,7 @@ static bool llm_load_tensors(
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE_SMALL:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

Expand Down Expand Up @@ -10893,6 +10935,7 @@ static struct ggml_cgraph * llama_build_graph(

switch (model.arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_GRANITE_SMALL:
{
result = llm.build_llama();
} break;
Expand Down Expand Up @@ -16038,6 +16081,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_GEMMA:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_GRANITE_SMALL:
return LLAMA_ROPE_TYPE_NEOX;

// all model arches should be listed explicitly here
Expand Down

0 comments on commit 1fb9186

Please sign in to comment.