• <xmp id="om0om">
  • <table id="om0om"><noscript id="om0om"></noscript></table>
  • Conversational AI / NLP

    SteerLM: ?? ?? LLM? ?? ??? ? ?? ???? ???? ??

    Reading Time: 6 minutes

    ?? ? ? ?? GPT-3, Megatron-Turing, Chinchilla, PaLM-2, Falcon, Llama 2? ?? ??? ?? ??(LLM)? ???? ??? ??? ??? ?? ??? ???????. ??? ??? ??? ???? ??? ? ?? ???? ???? ?? LLM? ??? ???? ?? ???? ??? ??? ???? ?? ? ????.

    ?? LLM? ???? ?? ?? ??? ??? ??? ?? ?? ?? ??(SFT)? ??? ???? ?? ?? ??(RLHF)? ?????. RLHF? ??? ???? ? ???, ??? ??? ? ??? ?? ?? ? ? ?? ??? ????.

    ??? ??? ???? ?? NVIDIA ???? ???? ??? ??? ?? ?? ??? ???? ??? ? ?? ??? LLM ??????? ????? ??? 4?? ??? SteerLM? ???? NVIDIA NeMo? ??? ??????. ? ?????? SteerLM? ?? ??, ? ??? ??? ??, SteerLM ??? ???? ??? ?? ??? ?????.

    ?? ??? ???? ??? ??

    ??? ??? ???? ?? ?? ??? ?? LLM? ???? ?? ??? ?? ??? ?????. ????? ??, ????, ??? ?? ? ??? ??? ??(NLP) ??? LLM? ????? ??????. ??? ??? ??? ???? ??? ??? ??? ??? ????? ?????? ???? ???? ???? ??? ????. ??? ???? ?? ???? LLM? ??? ???? ? ??????.

    ?? ?? ??? ??

    SFT? ?? ??? ????? ??? ?? ??????. RLHF? ??(alternative)?? ??? ???? ??? ?????? ??? ??????. ??? RLHF? ?? ??? ???? ???? ???? ??? ???? ??? ???? ???.

    SteerLM ??

    SteerLM? ?? ?? ??? ??? ? ?? ?? ?? ?? ??? ?????. ? ??? ?? ?? ??? ??? ????, ? ?? ?? ??? ?????:

    1. ??? ??? ? ??? ??? ?? ?? ?? ??? ???? ??, ??, ??? ? ??? ??? ?? ?? ??? ?????.
    2. 1??? ??? ???? ?? ??? ???? ??? ??? ??? ??? ?? ??? ??? ? ?? ???? ???? ?????.
    3. ???? ???? ?? ? ???? ?? ??? ?? ??? ?? ??? ??? ????? LLM? ???? ?? ??? SFT? ?????.
    4. ?? ??? ?? ??? ??? ???? ?? ???? ?? ????? ??? ??? ??(?? 1, 4a), ?? ?? ???? ??? ?? ?????(?? 1, 4b).
    ?? 1. SteerLM? 4??

    ?? ?? ??? ???? ???? SteerLM? RLHF? ?? ?????? ??????. ?? ??? ??? ??? ? ??? ?? ???? ?? ??? AI? ?????. ??? ?? ??? ?? ??? ???? ?? ?? ??? ?? ???? ??????? ??? ?? ??? ??? ? ????.

    ??? ??? ?? ??? AI ??

    ???? ??? ??? ? ??? ? ??? ??(?: ?? ?? ? ?? ???)? ??? ? ??? ?? ?? SteerLM? ?? ?????. ???? SteerLM?? ??? ??? ??? ???? ??? ?? ??? ?? ?? ??? ???? ??? ??? ? ????.

    SteerLM? ??? ?? ??? ??????? ?????:

    ??? ???? ?? ??? ??? ?? ??? ??? AI ???? ??? ? ????.

    ???? ??? ?? ??? ??????? ???

    ?? ?? ?????? ??? ??? ???? ???? ???? ?, SteerLM? ??? ?? ??? ???? ??? ?????? ??? ? ?? ??? ? ??? ????. ? ??? ??? ??? ??? ?? ?? ??? ?? ??? ???? ??? ?? ??? ?????.

    SFT? ?? ?? ??? ???? ???? ????? ???? ?? ??? ???? ? ????. ??? ??????? ???? ?? ???? ??? ?? ? ????.

    ?????, ?? ?? ??? ??? LLM? ?? ?? ???? ???? ?? ???? ?????. ?? ??, SteerLM 43B? Vicuna ?????? LLaMA 30B RLHF? ?? ?? RLHF ??? ???? ??? ??? ??????. ??, Vicuna ?? ???? ?? 655.75?? ??? ??, Guanaco 65B? 646.25?, LLaMA 30B RLHF? 612.75?? ??????.

    ??? ??? SteerLM? ??? ???? ????? ?? ? ??? RLHF ??? ??? ???? ??? LLM? ??? ? ??? ?? ?????. ???? ????? ?? ?? ??? ???? ???? ?? ?? ??? ????? ?? ??????? ???? ? ????.

    ??? ??? ??? ??? SteerLM: RLHF? (??? ?? ???) ?????? ?? ??? SFT? ?????. ?? SteerLM ??? ???? ??????? ?? 2 13B ??? ???? ??? ?? ??? ?? ? ????.

    SteerLM ??? ???? ??

    ? ????? Llama 2 7B LLM ??? ???? OASST ????? ?? SteerLM ?????? ???? ??? ???? ???? ??????. ???? ??? ?????:

    • ??? ?? ? ???
    • ?? ?? ??(? ??)
    • ?? ??? SFT(SteerLM ??) ??
    • ?? ?? ?? SteerLM ??? ?? ??

    1??: ?? ?? ??

    ?? ??? Python ?????? ?????:

    pip install fire langchain==0.0.133

    NeMo? ??????.

    2??: ??? ???? ? ?? ??

    ? ??????? OASST ??? ??? ?? ?? ??? ?????. OASST?? 13?? ?? ??? ?? ??? ??? ?? ?? ??? ??? ???? ????.

    ?? ???? ?????? ?? ???? ????:

    mkdir -p data
    cd data
    
    wget https://huggingface.co/datasets/OpenAssistant/oasst1/resolve/main/2023-04-12_oasst_all.trees.jsonl.gz
    
    gunzip -f 2023-04-12_oasst_all.trees.jsonl.gz
    
    mv 2023-04-12_oasst_all.trees.jsonl data.jsonl
    
    head -5000 data.jsonl > subset_data.jsonl
    
    cd -

    3??: Llama 2 LLM ?? ? ??? ?? ???? ? ??

    Llama 2 7B LLM ??? ?????? models ??? ???????.

    ?? ?? Llama 2 LLM? .nemo ???? ?????:

    # Run from the NeMo repository root, like every other command in this guide
    # (they all use scripts/..., examples/... and data/... relative to that root).
    # The later training steps restore from /model/llama7b.nemo; point --out-file
    # there or adjust model.restore_from_path accordingly.
    python scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py --in-file /path/to/llama --out-file /output_path/llama7b.nemo

    .nemo ??? ?? ???? NeMo ??? ????? ????:

    # A .nemo checkpoint is a tar archive: extract it, then give the SentencePiece
    # tokenizer the fixed name (tokenizer.model) that the data_clean steps expect.
    # NOTE(review): the hash prefix of *_tokenizer.model presumably differs per
    # conversion — use the file name that tar actually extracts.
    tar -xvf <path-to-model>/llama7b.nemo
    mv ba4632640484461f8ae9d61f6dfe0d0b_tokenizer.model tokenizer.model

    ????? ???? ??? ? ??? ? ????. ?? ??? ??? ? ??? ???? ??? ????? ?????.

    4??: OASST ??? ???

    NeMo ??? ????? ???? ???? ??????. ?? ?? ??? ???-? ? ?-??? ??? ?????:

    python scripts/nlp_language_modeling/sft/preprocessing.py \
        --input_file=data/subset_data.jsonl \
        --output_file_prefix=data/subset_data_output \
        --mask_role=User \
        --type=TEXT_TO_VALUE \
        --split_ratio=0.95 \
        --seed=10
    
    python scripts/nlp_language_modeling/sft/preprocessing.py \
        --input_file=data/subset_data.jsonl \
        --output_file_prefix=data/subset_data_output_v2t \
        --mask_role=User \
        --type=VALUE_TO_TEXT \
        --split_ratio=0.95 \
        --seed=10

    5??: ???-? ??? ??

    ??? ??? ?? ???? ?? ?? ??? ???? ?? ?? ????? ???? ???? ?????.

    python scripts/nlp_language_modeling/sft/data_clean.py \
        --dataset_file=data/subset_data_output_train.jsonl \
        --output_file=data/subset_data_output_train_clean.jsonl \
        --library sentencepiece \
        --model_file tokenizer.model \
        --seq_len 4096
    
    python scripts/nlp_language_modeling/sft/data_clean.py \
        --dataset_file=data/subset_data_output_val.jsonl \
        --output_file=data/subset_data_output_val_clean.jsonl \
        --library sentencepiece \
        --model_file tokenizer.model \
        --seq_len 4096

    6??: ??? OASST ???? ? ?? ????

    ? ??????? 1K ??? ?? ? ??? ?????. ?? ?? ??? ???? ? ?? ???? ?? ?? ? ?? ???? ?? ????.

    # Fine-tune on the cleaned OASST text-to-value data; checkpoints go to
    # /home/value_model/. Every line but the last must end in " \" and no blank
    # line may appear inside the command, or the shell ends it early.
    python examples/nlp/language_modeling/tuning/megatron_gpt_sft.py \
        ++trainer.limit_val_batches=10 \
        trainer.num_nodes=1 \
        trainer.devices=2 \
        trainer.max_epochs=null \
        trainer.max_steps=1000 \
        trainer.val_check_interval=100 \
        trainer.precision=bf16 \
        model.megatron_amp_O2=False \
        model.restore_from_path=/model/llama7b.nemo \
        model.tensor_model_parallel_size=2 \
        model.pipeline_model_parallel_size=1 \
        model.optim.lr=5e-6 \
        model.optim.name=distributed_fused_adam \
        model.optim.weight_decay=0.01 \
        model.answer_only_loss=True \
        model.activations_checkpoint_granularity=selective \
        model.activations_checkpoint_method=uniform \
        model.data.chat=True \
        model.data.train_ds.max_seq_length=4096 \
        model.data.train_ds.micro_batch_size=1 \
        model.data.train_ds.global_batch_size=1 \
        model.data.train_ds.file_names=[data/subset_data_output_train_clean.jsonl] \
        model.data.train_ds.concat_sampling_probabilities=[1.0] \
        model.data.train_ds.num_workers=0 \
        model.data.train_ds.hf_dataset=True \
        model.data.train_ds.prompt_template='\{input\}\{output\}' \
        model.data.train_ds.add_eos=False \
        model.data.validation_ds.max_seq_length=4096 \
        model.data.validation_ds.file_names=[data/subset_data_output_val_clean.jsonl] \
        model.data.validation_ds.names=["oasst"] \
        model.data.validation_ds.micro_batch_size=1 \
        model.data.validation_ds.global_batch_size=1 \
        model.data.validation_ds.num_workers=0 \
        model.data.validation_ds.metric.name=loss \
        model.data.validation_ds.index_mapping_dir=/indexmap_dir \
        model.data.validation_ds.hf_dataset=True \
        model.data.validation_ds.prompt_template='\{input\}\{output\}' \
        model.data.validation_ds.add_eos=False \
        model.data.test_ds.max_seq_length=4096 \
        model.data.test_ds.file_names=[data/subset_data_output_val_clean.jsonl] \
        model.data.test_ds.names=["oasst"] \
        model.data.test_ds.micro_batch_size=1 \
        model.data.test_ds.global_batch_size=1 \
        model.data.test_ds.num_workers=0 \
        model.data.test_ds.metric.name=loss \
        model.data.test_ds.hf_dataset=True \
        model.data.test_ds.prompt_template='\{input\}\{output\}' \
        model.data.test_ds.add_eos=False \
        exp_manager.explicit_log_dir="/home/value_model/" \
        exp_manager.create_checkpoint_callback=True \
        exp_manager.checkpoint_callback_params.monitor=val_loss \
        exp_manager.checkpoint_callback_params.mode=min

    7??: ?? ??

    ??? ????? ??????? ?? ??? ???? ?? ??? ?????:

    python examples/nlp/language_modeling/megatron_gpt_eval.py \
            gpt_model_file=/models/<TRAINED_ATTR_PREDICTION_MODEL.nemo> \
            pipeline_model_parallel_split_rank=0 \
            server=True \
            tensor_model_parallel_size=1 \
            pipeline_model_parallel_size=1 \
            trainer.precision=bf16 \
            trainer.devices=1 \
            trainer.num_nodes=1 \
            web_server=False \
            port=1424

    ?? ??? ?????:

    python scripts/nlp_language_modeling/sft/attribute_annotate.py  --batch_size=1 --host=localhost --input_file_name=data/subset_data_output_v2t_train.jsonl --output_file_name=data/subset_data_v2t_train_value_output.jsonl --port_num=1424
    
    python scripts/nlp_language_modeling/sft/attribute_annotate.py  --batch_size=1 --host=localhost --input_file_name=data/subset_data_output_v2t_val.jsonl --output_file_name=data/subset_data_v2t_val_value_output.jsonl --port_num=1424

    8??: ?-??? ??? ????
    ??? ??? ?? ??? ? ?? ??? ???? ?? ???? ?????:

    # Drop annotated value-to-text samples that do not fit in the 4096-token
    # context. Same script and options as the text-to-value cleaning step.
    python scripts/nlp_language_modeling/sft/data_clean.py \
        --dataset_file=data/subset_data_v2t_train_value_output.jsonl \
        --output_file=data/subset_data_v2t_train_value_output_clean.jsonl \
        --library sentencepiece \
        --model_file tokenizer.model \
        --seq_len 4096

    python scripts/nlp_language_modeling/sft/data_clean.py \
        --dataset_file=data/subset_data_v2t_val_value_output.jsonl \
        --output_file=data/subset_data_v2t_val_value_output_clean.jsonl \
        --library sentencepiece \
        --model_file tokenizer.model \
        --seq_len 4096

    9??: SteerLM ?? ????

    ? ????? ??? ??, SteerLM ??? 1K ??? ?????. ? ??? ??? ???? ?? ? ?? ???? ? ?? ???? ?? ????.

    # Train the SteerLM model on the cleaned, attribute-annotated data;
    # checkpoints go to /home/steerlm_model/. Each continued line must end in
    # a backslash immediately followed by the newline (not "\ " mid-line).
    python examples/nlp/language_modeling/tuning/megatron_gpt_sft.py \
        ++trainer.limit_val_batches=10 \
        trainer.num_nodes=1 \
        trainer.devices=2 \
        trainer.max_epochs=null \
        trainer.max_steps=1000 \
        trainer.val_check_interval=100 \
        trainer.precision=bf16 \
        model.megatron_amp_O2=False \
        model.restore_from_path=/model/llama7b.nemo \
        model.tensor_model_parallel_size=2 \
        model.pipeline_model_parallel_size=1 \
        model.optim.lr=5e-6 \
        model.optim.name=distributed_fused_adam \
        model.optim.weight_decay=0.01 \
        model.answer_only_loss=True \
        model.activations_checkpoint_granularity=selective \
        model.activations_checkpoint_method=uniform \
        model.data.chat=True \
        model.data.train_ds.max_seq_length=4096 \
        model.data.train_ds.micro_batch_size=1 \
        model.data.train_ds.global_batch_size=1 \
        model.data.train_ds.file_names=[data/subset_data_v2t_train_value_output_clean.jsonl] \
        model.data.train_ds.concat_sampling_probabilities=[1.0] \
        model.data.train_ds.num_workers=0 \
        model.data.train_ds.prompt_template='\{input\}\{output\}' \
        model.data.train_ds.add_eos=False \
        model.data.validation_ds.max_seq_length=4096 \
        model.data.validation_ds.file_names=[data/subset_data_v2t_val_value_output_clean.jsonl] \
        model.data.validation_ds.names=["oasst"] \
        model.data.validation_ds.micro_batch_size=1 \
        model.data.validation_ds.global_batch_size=1 \
        model.data.validation_ds.num_workers=0 \
        model.data.validation_ds.metric.name=loss \
        model.data.validation_ds.index_mapping_dir=/indexmap_dir \
        model.data.validation_ds.prompt_template='\{input\}\{output\}' \
        model.data.validation_ds.add_eos=False \
        model.data.test_ds.max_seq_length=4096 \
        model.data.test_ds.file_names=[data/subset_data_v2t_val_value_output_clean.jsonl] \
        model.data.test_ds.names=["oasst"] \
        model.data.test_ds.micro_batch_size=1 \
        model.data.test_ds.global_batch_size=1 \
        model.data.test_ds.num_workers=0 \
        model.data.test_ds.metric.name=loss \
        model.data.test_ds.prompt_template='\{input\}\{output\}' \
        model.data.test_ds.add_eos=False \
        exp_manager.explicit_log_dir="/home/steerlm_model/" \
        exp_manager.create_checkpoint_callback=True \
        exp_manager.checkpoint_callback_params.monitor=val_loss \
        exp_manager.checkpoint_callback_params.mode=min

    10??: ??

    ??? ????? ?? ??? ???? ??????? ?? ??? ?????:

    python examples/nlp/language_modeling/megatron_gpt_eval.py \
            gpt_model_file=/models/<TRAINED_STEERLM_MODEL.nemo> \
            pipeline_model_parallel_split_rank=0 \
            server=True \
            tensor_model_parallel_size=1 \
            pipeline_model_parallel_size=1 \
            trainer.precision=bf16 \
            trainer.devices=1 \
            trainer.num_nodes=1 \
            web_server=False \
            port=1427

    ?????, Python ?? ??? ?????.

    def get_answer(question, max_tokens, values, eval_port='1427'):
        """Send one steered chat turn to the SteerLM server and return the reply.

        Requires ``import requests`` (third-party) in the calling script.

        Args:
            question: Text of the single user turn.
            max_tokens: Number of tokens the server is asked to generate.
            values: Attribute label string, as produced by ``encode_labels``
                (e.g. ``"quality:4,toxicity:0,..."``), placed after ``<extra_id_2>``.
            eval_port: Port of the ``megatron_gpt_eval.py`` server started with
                ``server=True`` (the inference step uses 1427).

        Returns:
            The generated text with the echoed prompt stripped off.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        # Build the prompt from explicit "\n" pieces so the chat template does not
        # depend on how this snippet is indented when pasted: system turn, one
        # blank line, user turn, then the assistant turn opened by the label.
        prompt = (
            "<extra_id_0>System\n"
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
            "\n"
            f"<extra_id_1>User\n{question}\n"
            "<extra_id_1>Assistant\n"
            f"<extra_id_2>{values}\n"
        )

        data = {
            "sentences": [prompt],
            "tokens_to_generate": max_tokens,
            "top_k": 1,
            'greedy': True,
            # Generation stops at any of these; "<extra_id_1>" opens the next turn.
            'end_strings': ["<extra_id_1>", "quality:", "quality:4", "quality:0"]
        }

        url = f"http://localhost:{eval_port}/generate"
        response = requests.put(url, json=data)
        # Surface server errors directly instead of a confusing JSON/KeyError below.
        response.raise_for_status()
        json_response = response.json()

        # The server returns prompt + completion; keep only the completion.
        return json_response['sentences'][0][len(prompt):]
    def encode_labels(labels):
        """Serialize steering attributes into a comma-separated ``key:value`` string.

        Pairs follow the mapping's iteration order, e.g.
        ``{'quality': 4, 'humor': 0}`` -> ``'quality:4,humor:0'``.
        An empty mapping yields ``''``.
        """
        return ','.join(f'{key}:{labels[key]}' for key in labels)
    

    ?? ?? ?? ?? ???? ?? ??? ?????:

    # OrderedDict was used here without an import; bring it into scope.
    from collections import OrderedDict

    # Target attribute values that steer the answer. encode_labels turns them
    # into the "key:value,..." string placed after <extra_id_2> in the prompt.
    # NOTE(review): values appear to be integers on a 0-4 scale (quality and
    # helpfulness at 4, the rest at 0) — confirm against the annotated data.
    steering_attributes = OrderedDict([
        ('quality', 4),
        ('toxicity', 0),
        ('humor', 0),
        ('creativity', 0),
        ('violence', 0),
        ('helpfulness', 4),
        ('not_appropriate', 0),
        ('hate_speech', 0),
        ('sexual_content', 0),
        ('fails_task', 0),
        ('political_content', 0),
        ('moral_judgement', 0),
    ])
    values = encode_labels(steering_attributes)

    ????? ??? ?? ??? ?????:

    # Ask one question, steered by the attribute string built above.
    question = "Where and when did techno music originate?"
    answer = get_answer(question, 4096, values)
    print(answer)

    ? ????? ??? ???? ? ????? ???? ?? ?? ??? ??? ? ????. ? ??? ??? ?????? ?? ???? ?? ????? ? ??? ? ? ????.

    SteerLM? ?? AI? ??

    SteerLM? ??? ???? ?? ?? ??? ???? ??? AI ???? ??? ? ?? ??? ??? ?????. ??? ???, ?? ??, ?????? ??? ???? ??? ? ?? AI? ??? ???? ?????. SteerLM? ?? ?? ?? ?????? ????, NVIDIA/NeMo GitHub ?????? ?? ???? ? ????. ?? SteerLM ??? ???? ??????? ?? 2 13B ??? ???? ??? ?? ??? ?? ? ????.

    ??? ?????? ?? ? ??? ?? SteerLM? ??? ?? AI ??? ??, ?????? ? ??? ? ?? ??? ?????? NVIDIA NeMo? ??? ?????. SteerLM ??? ?? 2(Llama 2), ??(Falcon) LLM, MPT ? ?????? ??? ?? ?? ?? ??? LLM? ???? NeMo?? ???? ?? ???? ?????. ??? ??? ???? ???? ?? ??? ??? ???? ??? ???? ?? ?? ??? ???? ??? ????. AI? ??? SteerLM?? ??? ? ?????.

    ??? ?

    ? ???? SteerLM? ??? ??? ?? Xianchao Wu? Oleksii Kuchaiev?? ??? ??? ????.

    ?? ???

    Discuss (0)
    +1

    Tags

    人人超碰97caoporen国产