[Feature Request] Support input embedding in LLM.generate() #416

@KimmiShi


Hi, I am using an LLM as part of a multimodal model. The model needs to pass an input embedding tensor directly to generate, and it also needs access to the language model's embed_tokens member: it first computes the text embeddings, then processes them (splicing in the vision features), and finally sends the result to generate. Here is a demo:

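        # Embed the text token ids with the language model's own embedding layer (embed_tokens)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Split the text embeddings at self.offset and splice the projected vision
        # features (language_model_inputs) in between, along the sequence dimension
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        # .....
        # Concatenate the attention mask in the same order as the embeddings
        attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        # Generate directly from embeddings instead of token ids
        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )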
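The snippet above depends on the surrounding model (self.offset, language_model_inputs and the masks are defined elsewhere). As a minimal, self-contained sketch of the same splice, the code below uses toy sizes and a plain torch.nn.Embedding as a stand-in for the language model's embedding layer; all sizes and names here are hypothetical. It shows the shape bookkeeping: the embeddings and the attention mask must be concatenated in the same order so their sequence lengths match.

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes, chosen only for illustration
vocab_size, hidden_size = 100, 8
num_vision_tokens, offset = 4, 2

# Stand-in for language_model.get_input_embeddings() (embed_tokens)
embed_tokens = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([[1, 5, 7, 9, 11]])    # (batch=1, text_len=5)
text_mask = torch.ones_like(input_ids)          # every text position is a real token
# Stand-in for the projected vision features
language_model_inputs = torch.randn(1, num_vision_tokens, hidden_size)
vision_mask = torch.ones(1, num_vision_tokens, dtype=text_mask.dtype)

with torch.no_grad():
    inputs_embeds = embed_tokens(input_ids)     # (1, 5, hidden_size)

# Splice the vision features in at position `offset` along the sequence dimension
inputs_embeds = torch.cat(
    [inputs_embeds[:, :offset, :], language_model_inputs, inputs_embeds[:, offset:, :]],
    dim=1,
)
# Build the attention mask in exactly the same order so positions line up
attention_mask = torch.cat(
    [text_mask[:, :offset], vision_mask, text_mask[:, offset:]],
    dim=-1,
)

print(inputs_embeds.shape)   # torch.Size([1, 9, 8])
print(attention_mask.shape)  # torch.Size([1, 9])
```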

I read the vLLM code, and it seems I would need to add two interfaces to vLLM: LLM.get_input_embeddings and LLM.generate(inputs_embeds=inputs_embeds, ...).

Do you think this would work? And would you consider supporting this feature?
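To make the two proposed interfaces concrete, here is a toy, pure-Python sketch of the contract they might have. SketchLLM, its deterministic embedding table, and both signatures are hypothetical illustrations, not vLLM's actual API; generate here only validates and returns the first-layer input instead of decoding tokens. The property it checks is that passing token ids and passing the embeddings of those same ids should give the model identical input.

```python
from typing import Callable, List, Optional, Sequence

Embedding = List[float]  # one hidden_size-dim vector per token (toy representation)


class SketchLLM:
    """Toy stand-in for the two proposed entry points.

    Hypothetical: neither the class nor these signatures are vLLM's API.
    A real engine would run the transformer layers where this sketch
    just returns the prepared input.
    """

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        # Deterministic toy embedding table (row t is token t's embedding)
        self._table = [
            [float(t * hidden_size + d) for d in range(hidden_size)]
            for t in range(vocab_size)
        ]
        self.hidden_size = hidden_size

    def get_input_embeddings(self) -> Callable[[Sequence[int]], List[Embedding]]:
        # Proposed interface 1: expose the embedding lookup (embed_tokens)
        return lambda ids: [list(self._table[i]) for i in ids]

    def generate(
        self,
        prompt_token_ids: Optional[Sequence[int]] = None,
        inputs_embeds: Optional[List[Embedding]] = None,
    ) -> List[Embedding]:
        # Proposed interface 2: accept exactly one of token ids or embeddings
        if (prompt_token_ids is None) == (inputs_embeds is None):
            raise ValueError("pass exactly one of prompt_token_ids or inputs_embeds")
        if inputs_embeds is None:
            # Token-id path: look up the embeddings as usual
            inputs_embeds = self.get_input_embeddings()(prompt_token_ids)
        # Embedding path skips the lookup, but the hidden size must still match
        if any(len(vec) != self.hidden_size for vec in inputs_embeds):
            raise ValueError("every embedding must have length hidden_size")
        # Placeholder: a real implementation would feed these into the first
        # decoder layer and decode new tokens; here they are returned unchanged.
        return inputs_embeds


llm = SketchLLM(vocab_size=10, hidden_size=3)
ids = [1, 4, 2]
from_ids = llm.generate(prompt_token_ids=ids)
from_embeds = llm.generate(inputs_embeds=llm.get_input_embeddings()(ids))
print(from_ids == from_embeds)  # True
```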
