class LlamaForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
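To see these two classes in action, here is a minimal sketch, assuming the `transformers` library is installed. The config values are illustrative toy sizes, not a released checkpoint; the point is only that `LlamaForCausalLM` wraps a `LlamaModel` plus the `lm_head` and maps token ids to per-position vocabulary logits.

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny, randomly initialized model; all sizes are illustrative.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))  # (batch, seq_len)
outputs = model(input_ids=input_ids)
print(outputs.logits.shape)  # torch.Size([1, 8, 1000])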
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
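A quick check of the math, reusing the class above: RMSNorm divides each hidden vector by its root mean square (no mean subtraction, unlike LayerNorm) and scales by a learned weight. The shapes and values below are illustrative.

import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 8)  # (batch, seq_len, hidden_size), float32
norm = LlamaRMSNorm(hidden_size=8)

out = norm(x)
# Manual equivalent: x * rsqrt(mean(x^2) + eps), scaled by the weight
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
manual = norm.weight * (x * rms)
print(torch.allclose(out, manual, atol=1e-6))  # True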
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
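The cos/sin caches that feed `apply_rotary_pos_emb` are built elsewhere (in the rotary embedding module). Here is a self-contained sketch of one way to construct them, mirroring the standard RoPE recipe: inverse frequencies 1/base^(2i/dim), duplicated across both halves of the head dimension. The `[1, 1, seq_len, dim]` shape matches what the squeezes above expect; the base value and sizes are illustrative.

import torch

head_dim, seq_len = 8, 6
base = 10000.0  # standard RoPE base, an assumption here
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.outer(t, inv_freq)          # [seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)   # [seq_len, dim]
cos = emb.cos()[None, None, :, :]         # [1, 1, seq_len, dim]
sin = emb.sin()[None, None, :, :]

q = torch.randn(1, 2, seq_len, head_dim)  # (bs, num_heads, seq_len, head_dim)
k = torch.randn(1, 2, seq_len, head_dim)
position_ids = torch.arange(seq_len)[None, :]  # [bs, seq_len]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape)  # torch.Size([1, 2, 6, 8])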
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.pretraining_tp = config.pretraining_tp
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()
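A worked example of the projection arithmetic, using the grouped-query attention settings reported for Llama-2-70B (hidden_size 8192, 64 query heads, 8 KV heads); treat the numbers as an illustration rather than a spec:

hidden_size, num_heads, num_key_value_heads = 8192, 64, 8
head_dim = hidden_size // num_heads                       # 128
num_key_value_groups = num_heads // num_key_value_heads   # 8 query heads share each KV head

# q_proj maps 8192 -> 64 * 128 = 8192
# k_proj / v_proj map 8192 -> 8 * 128 = 1024, so the KV cache is 8x smaller
print(head_dim, num_key_value_groups)  # 128 8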
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
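A quick sanity check that `repeat_kv` matches `torch.repeat_interleave` on the head dimension, as the docstring claims (the `expand` avoids materializing copies until the final reshape). Shapes are illustrative.

import torch

kv = torch.randn(2, 4, 6, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=3)
reference = torch.repeat_interleave(kv, repeats=3, dim=1)
print(expanded.shape)                    # torch.Size([2, 12, 6, 8])
print(torch.equal(expanded, reference))  # True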
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
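A self-contained sketch of this upcast pattern: scores held in a low-precision compute dtype (bfloat16 here, standing in for fp16/bf16 inference, and assuming a torch build with bfloat16 CPU kernels) are softmaxed in float32 for numerical stability, then cast back before the value matmul. All tensors are dummies.

import torch
import torch.nn as nn

bs, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query_states = torch.randn(bs, heads, q_len, head_dim, dtype=torch.bfloat16)
value_states = torch.randn(bs, heads, kv_len, head_dim, dtype=torch.bfloat16)
attn_weights = torch.randn(bs, heads, q_len, kv_len, dtype=torch.bfloat16)

# normalize in fp32, then return to the compute dtype for the matmul
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
print(attn_weights.dtype, attn_output.dtype)  # torch.bfloat16 torch.bfloat16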
loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
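A tiny worked example of the shift: position i's logits are scored against token i+1, so the last position has no target and the first token is never predicted. Vocabulary size and label values below are illustrative.

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 5, vocab_size)           # (batch, seq_len, vocab)
labels = torch.tensor([[3, 7, 2, 9, 4]])         # the input ids themselves

shift_logits = logits[..., :-1, :].contiguous()  # predictions at positions 0..3
shift_labels = labels[..., 1:].contiguous()      # targets are tokens 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(shift_logits.shape, shift_labels.shape, loss.item())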