# 1. 字符串相等
import re
def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """Score a solution by exact match, falling back to keyword overlap.

    Args:
        data_source: Not used by this scoring variant.
        solution: Candidate answer text.
        ground_truth: Reference answer text.
        extra_info: Not used by this scoring variant.

    Returns:
        1.0 when the two strings are identical, 0.5 when
        ``has_keyword_overlap`` reports enough shared keywords, else 0.0.
    """
    # Identical strings get full credit without tokenizing anything.
    if solution == ground_truth:
        return 1.0
    # Partial credit when the keyword sets overlap sufficiently.
    if has_keyword_overlap(solution, ground_truth):
        return 0.5
    # No meaningful match.
    return 0.0
def has_keyword_overlap(text1: str, text2: str, *, threshold: float = 0.3) -> bool:
    """Report whether two texts share a large enough fraction of keywords.

    Keywords are the distinct lowercase ``\\w+`` tokens of each text. The
    overlap ratio is the number of shared keywords divided by the size of
    the larger keyword set.

    Args:
        text1: First text.
        text2: Second text.
        threshold: Ratio that must be strictly exceeded for the texts to
            count as overlapping. Defaults to 0.3.

    Returns:
        True when the overlap ratio is strictly greater than ``threshold``;
        False otherwise, including when either text contains no keywords.
    """
    keywords1 = set(re.findall(r'\b\w+\b', text1.lower()))
    keywords2 = set(re.findall(r'\b\w+\b', text2.lower()))
    # Without keywords on one side nothing can be shared. Returning early
    # also avoids dividing by zero when both sets are empty.
    if not keywords1 or not keywords2:
        return False
    shared = keywords1 & keywords2
    overlap_ratio = len(shared) / max(len(keywords1), len(keywords2))
    return overlap_ratio > threshold
# 2. 字符串包含
import re
def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """Score a solution by the Jaccard similarity of its word tokens.

    Both texts are tokenized with ``tokenize`` and reduced to sets of
    distinct tokens. The score is ``|A & B| / |A | B|``.

    Args:
        data_source: Not used by this scoring variant.
        solution: Candidate answer text.
        ground_truth: Reference answer text.
        extra_info: Not used by this scoring variant.

    Returns:
        A value in [0.0, 1.0]; 1.0 when neither text contains any token.
    """
    set_solution = set(tokenize(solution))
    set_ground_truth = set(tokenize(ground_truth))
    # Two token-less texts are treated as identical. This also keeps the
    # division below away from an empty union.
    if not set_solution and not set_ground_truth:
        return 1.0
    # Compare token set against token set. Passing the raw ground_truth
    # string here would compare tokens with its individual characters.
    intersection = set_solution & set_ground_truth
    union = set_solution | set_ground_truth
    # The intersection is a subset of the non-empty union, so the ratio is
    # already within [0, 1] and needs no clamping.
    return len(intersection) / len(union)
def tokenize(text: str) -> list[str]:
    """Split text into lowercase word tokens.

    Args:
        text: Input text.

    Returns:
        Every ``\\w+`` run of the lowercased text, in order of appearance.
        Duplicates are kept; the list is empty if there are no such runs.
    """
    # Builtin generic in the annotation: typing.List is not imported in
    # this file, so referencing it would fail when the def is executed.
    return re.findall(r'\b\w+\b', text.lower())
# 3. 字符串相似度比较
def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """Score a solution by its normalized Levenshtein similarity.

    Args:
        data_source: Not used by this scoring variant.
        solution: Candidate answer text.
        ground_truth: Reference answer text.
        extra_info: Not used by this scoring variant.

    Returns:
        ``1 - distance / max(len(solution), len(ground_truth))`` clamped to
        [0.0, 1.0]; 1.0 when both strings are empty.
    """
    edit_distance = levenshtein_distance(solution, ground_truth)
    longest = max(len(solution), len(ground_truth))
    # Two empty strings are identical; this also avoids dividing by zero.
    ratio = 1.0 if longest == 0 else 1 - (edit_distance / longest)
    # Clamp defensively to the valid score range.
    return max(0.0, min(ratio, 1.0))
def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein edit distance between two strings.

    Uses a two-row dynamic-programming table: only the previous and the
    current row are held at any time.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The minimum number of single-character insertions, deletions and
        substitutions that turn one string into the other.
    """
    # Iterate rows over the longer string so each row has
    # len(shorter) + 1 cells.
    longer, shorter = (s2, s1) if len(s1) < len(s2) else (s1, s2)

    # Row 0: distance from the empty prefix of `longer` to each prefix of
    # `shorter`.
    prev_row = list(range(len(shorter) + 1))
    for row_index, ch_long in enumerate(longer, start=1):
        # Column 0: distance from this prefix of `longer` to the empty string.
        cur_row = [row_index]
        for col, ch_short in enumerate(shorter):
            from_above = prev_row[col + 1] + 1
            from_left = cur_row[col] + 1
            # The diagonal step is free when the characters match.
            from_diagonal = prev_row[col] + (0 if ch_long == ch_short else 1)
            cur_row.append(min(from_above, from_left, from_diagonal))
        prev_row = cur_row
    return prev_row[-1]
# 在这篇文章中: (In this article:)