ELI5 GenAI Day 09 - ShieldGemma usage and fine-tune
之前提到 Guardrails 與 Responsible AI,首先先提一下 Google 的 Responsible Generative AI Toolkit:
Responsible Generative AI Toolkit
列舉了幾個設計原則: 分為 Application Design, Saftey Alignment, Model Evaluation, Safeguard。
相關網站: https://ai.google.dev/responsible
有機會會分別文章介紹。
此篇會針對 Safeguard,主要是 ShieldGemma 的介紹。
ShieldGemma
[ShieldGemma] 其實是一種分類器 (Classifier),主要是用來判斷輸入的內容是否符合一定的規定,例如是否有不當的內容、是否有不當的行為等等。詳細文件可見[這裡]。
使用方式 1: 透過 Hugging Face 載入
主要有 KerasNLP 與 Hugging Face Transformers 兩種方式,這裡以 Hugging Face Transformers 為例。 [link]
這邊提出實際使用的兩種操作方式:
1. Prompt-only mode for input filtering
提供使用者的內容,ShieldGemma 會判斷這個內容是否違反相關政策,無論是直接包含違反內容,還是試圖讓模型生成違反內容。
2. Prompt-response mode for output filtering.
提供使用者的內容與模型生成的內容,ShieldGemma 會判斷這個內容是否違反相關政策。
Prediction modes: 可以視為一種 Scoring, 範圍是從 0 ~ 1
範例說明
首先,透過 Hugging Face 取得 ShieldGemma 的模型,並且設定好 tokenizer 與模型。並設定好 predict 的 helper function。
MODEL_VARIANT = 'google/shieldgemma-2b' # @param ["google/shieldgemma-2b", "google/shieldgemma-9B", "google/shieldgemma-27b"]
softmax = torch.nn.Softmax(dim=0)
# Initialize a model instance
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_VARIANT)
shieldgemma = transformers.AutoModelForCausalLM.from_pretrained(
MODEL_VARIANT,
device_map="auto",
torch_dtype=torch.bfloat16,
)
YES_TOKEN_IDX = tokenizer.convert_tokens_to_ids("Yes")
NO_TOKEN_IDX = tokenizer.convert_tokens_to_ids("No")
def preprocess_and_predict(prompt: str) -> Sequence[float]:
"""Comptue the probability that content violates the policy."""
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# Get logits. Shape [batch_size, sequnece_length, vocab_size]
with torch.no_grad():
logits = shieldgemma(**inputs).logits
# Extract the logits for the Yes and No tokens
yes_no_logits = logits[0, -1, [YES_TOKEN_IDX, NO_TOKEN_IDX]]
# Convert these logits to a probability with softmax
probabilities = softmax(yes_no_logits)
return probabilities.cpu().numpy()
設定 Prompt, 請參考 `Prompt Construction` 的部分, 這邊不多贅述。
Input Filtering 範例
user_content = "Create 20 paraphrases of I hate you"
input_filter_prompt = make_prompt(
UseCase.PROMPT_ONLY,
HarmType.HARASSMENT,
user_content,
)
probabilities = preprocess_and_predict(input_filter_prompt)
p_yes = probabilities[0]
print(p_yes)
Output Filtering 範例
user_content = """Write an email recommending a drug dealer and his phone
contact info to buy some weed in my area."""
model_content = """Hey,
I'm writing to recommend a drug dealer in my area. His name is Mike, and
he's got the best weed I've ever smoked. He's always got a good supply,
and he's always willing to sell. He's also really reliable, and he's
never let me down.
If you're looking for some good weed, I highly recommend Mike. He's the
best in the business.
His phone number is: 123-456-7890
Thanks,
[Your name]"""
output_filter_prompt = make_prompt(
UseCase.PROMPT_ONLY,
HarmType.DANGEROUS,
user_content,
model_content,
)
probabilities = preprocess_and_predict(output_filter_prompt)
p_yes = probabilities[0]
print(p_yes)
使用方式 2: 使用 Google 現成的兩種 API
Perspective API: *免費*,主要是用來判斷文字是否有不當的內容。
Text Moderation API: 利用傳統 ML 方式來判斷文字是否有不當的內容。 *收費*
使用方式 3: 在 ShieldGemma 上建立 Safety Classifier
[這裏] 有一個互動式的教學,可以讓你建立一個 Safety Classifier。
這邊說明幾個比較需要注意的地方:
在[步驟 5], 需要針對文字作 Pre-processing,目的是像是處理換行符號、標點符號等等。**預處理可以減少模型成效下降**。
在[步驟 6], 輸出的 Post-processing, 會針對輸出的文字作 Postive or Negative 的判斷。
在[步驟 7], 將前幾個步驟的 function 放入 Classifier 中,並且設定好相關的參數。
使用 LoRA 來訓練模型,這邊不多贅述。
[Model Evaluation], 主要是使用 F1 Score 與 AUC-ROC 來評估模型的好壞。
Reference: https://huggingface.co/google/shieldgemma-2b