# Source code for rl4lms.envs.text_generation.caption_metrics.spice.spice

"""
Adapted from https://github.com/INK-USC/CommonGen/tree/master/evaluation/Traditional/eval_metrics/spice
"""

from __future__ import division
import os
import subprocess
import json
import numpy as np
import tempfile
import spacy

# Name of the SPICE jar. It is resolved relative to this module's directory,
# because the Java process is launched with that directory as its cwd.
# Change as needed.
SPICE_JAR = "spice-1.0.jar"
# Sub-directories of this module's directory that hold the temporary
# input/output JSON files and the SPICE cache, respectively.
TEMP_DIR = "tmp"
CACHE_DIR = "cache"


class Spice:
    """Compute the SPICE caption metric by running the SPICE jar.

    Sentences are tokenized with spaCy, and the tokens are re-joined with
    single spaces. The result is written to a temporary JSON file, which
    ``SPICE_JAR`` scores in a ``java`` subprocess.
    """

    def __init__(self) -> None:
        """Load the spaCy English pipeline used for tokenization."""
        self._nlp = spacy.load("en_core_web_sm")
        # Keep only the tagger. ``tokenize`` reads nothing but ``token.text``.
        # Skip components the installed model does not ship, because
        # ``remove_pipe`` raises for unknown names.
        for pipe in ["tok2vec", "parser", "ner", "attribute_ruler", "lemmatizer"]:
            if pipe in self._nlp.pipe_names:
                self._nlp.remove_pipe(pipe)

    def float_convert(self, obj):
        """Return ``float(obj)``, or ``np.nan`` if ``obj`` cannot be converted.

        This lets missing SPICE scores (e.g. ``None``) become NaN.
        """
        try:
            return float(obj)
        except (TypeError, ValueError, OverflowError):
            # TypeError: None or other non-numeric objects.
            # ValueError: unparsable strings.
            # OverflowError: ints too large for a float.
            return np.nan

    def tokenize(self, dict):
        """Tokenize every sentence and re-join the tokens with single spaces.

        Args:
            dict: Mapping from id to an iterable of sentences. Each sentence
                is passed through ``str`` before tokenization. The parameter
                keeps its original name, which shadows the builtin, so that
                keyword callers keep working.

        Returns:
            A new dict with the same keys, each mapped to a list of
            tokenized sentences. The input mapping is left unmodified.
        """
        return {
            key: [self._tokenize_sentence(sentence) for sentence in sentences]
            for key, sentences in dict.items()
        }

    def _tokenize_sentence(self, sentence):
        """Return ``sentence`` as its spaCy tokens joined by single spaces."""
        # rstrip drops trailing whitespace-only tokens (e.g. from repeated
        # spaces at the end of the sentence).
        return " ".join(token.text for token in self._nlp(str(sentence))).rstrip()

    def compute_score(self, gts, res):
        """Score candidate captions against reference captions with SPICE.

        Args:
            gts: Mapping from id to a list of one or more reference sentences.
            res: Mapping from id to a list holding exactly one candidate
                sentence. Must have the same keys as ``gts``.

        Returns:
            A tuple ``(average_score, individual_scores)``.
            ``individual_scores`` maps each id reported by SPICE to its
            overall ("All") F-score, or ``np.nan`` when SPICE gives none.
            ``average_score`` is the mean of those scores. For empty input,
            returns ``(np.nan, {})`` without launching Java.

        Raises:
            AssertionError: If the key sets differ or an entry has the wrong
                shape. The raw inputs are checked, before tokenization.
            subprocess.CalledProcessError: If the SPICE jar exits non-zero.
        """
        assert gts.keys() == res.keys(), "gts and res must have the same ids"
        img_ids = sorted(gts.keys())
        # Validate the raw inputs. After tokenization every value is a list,
        # so a plain string would pass a type check there, even though
        # tokenize would already have split it into characters.
        for img_id in img_ids:
            hypo = res[img_id]
            ref = gts[img_id]
            assert isinstance(hypo, list) and len(hypo) == 1, \
                "each res entry must be a list with exactly one sentence"
            assert isinstance(ref, list) and len(ref) >= 1, \
                "each gts entry must be a non-empty list of sentences"

        if not img_ids:
            # Nothing to score. Skip the JVM; the mean of no scores is NaN.
            return np.nan, {}

        gts = self.tokenize(gts)
        res = self.tokenize(res)
        input_data = [
            {"image_id": img_id, "test": res[img_id][0], "refs": gts[img_id]}
            for img_id in img_ids
        ]
        results = self._run_spice(input_data)

        spice_scores = [
            self.float_convert(item["scores"]["All"]["f"]) for item in results
        ]
        individual_scores = {
            item["image_id"]: score for item, score in zip(results, spice_scores)
        }
        average_score = np.mean(np.array(spice_scores))
        return average_score, individual_scores

    def _run_spice(self, input_data):
        """Run the SPICE jar on ``input_data`` and return its parsed JSON output.

        The temporary input and output files are removed even when the Java
        process fails.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        temp_dir = os.path.join(cwd, TEMP_DIR)
        cache_dir = os.path.join(cwd, CACHE_DIR)
        os.makedirs(temp_dir, exist_ok=True)
        os.makedirs(cache_dir, exist_ok=True)

        in_path = out_path = None
        try:
            with tempfile.NamedTemporaryFile(
                    mode="w", delete=False, dir=temp_dir) as in_file:
                in_path = in_file.name
                json.dump(input_data, in_file, indent=2)
            # Reserve a file name for the jar to write its results to.
            with tempfile.NamedTemporaryFile(
                    delete=False, dir=temp_dir) as out_file:
                out_path = out_file.name
            spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_path,
                         '-cache', cache_dir,
                         '-out', out_path,
                         '-subset',
                         '-silent']
            # SPICE_JAR is a relative path, so run from this module's directory.
            subprocess.check_call(spice_cmd, cwd=cwd)
            with open(out_path) as data_file:
                return json.load(data_file)
        finally:
            for path in (in_path, out_path):
                if path is not None and os.path.exists(path):
                    os.remove(path)

    def method(self):
        """Return the display name of this metric."""
        return "SPICE"
if __name__ == "__main__":
    # Manual smoke run. It needs the `java` executable, SPICE_JAR in this
    # module's directory, and the en_core_web_sm spaCy model.
    references = {
        "cat#dog#boy": [
            "The dog is the boy's cat.",
            "The dog eats the cat of the boy.",
        ],
        "apple#tree#boy": ["A boy is picking apples from trees."],
    }
    candidates = {
        "cat#dog#boy": ["The dog is the boy's cat."],
        "apple#tree#boy": [
            "A boy is picking apples from trees and put them into bags."
        ],
    }
    print(Spice().compute_score(references, candidates))