Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| # import os | |
| # os.chdir('modeling') | |
| import tensorflow as tf, tf_keras | |
| import tensorflow_hub as hub | |
| from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM | |
| from official.projects.movinet.modeling import movinet | |
| from official.projects.movinet.modeling import movinet_model_a2_modified as movinet_model_modified | |
# --- Model loading ------------------------------------------------------------
# Both models are used for inference only, so each one is frozen right after
# it is loaded.

# Fine-tuned MoViNet (A2 variant, per the checkpoint name) video encoder,
# restored from a local checkpoint path.
movinet_path = 'movinet_checkpoints_a2_epoch9'
movinet_model = tf_keras.models.load_model(movinet_path)
movinet_model.trainable = False

# T5 tokenizer plus the fine-tuned seq2seq model that decodes the video
# embeddings into English text.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
t5_model = TFAutoModelForSeq2SeqLM.from_pretrained(
    "deanna-emery/ASL_t5_movinet_sentence"
)
t5_model.trainable = False
def crop_center_square(frame):
    """Return the largest centered square region of ``frame``.

    Landscape frames (wider than tall) are cropped horizontally and portrait
    frames (taller than wide) vertically. Square frames keep all their pixels.
    The previous version only handled landscape input, so portrait frames were
    later stretched to a square by ``cv2.resize`` instead of being cropped.

    Args:
        frame: image array whose first two axes are (height, width).

    Returns:
        A view of ``frame`` with shape (side, side, ...), where
        ``side = min(height, width)``.
    """
    height, width = frame.shape[:2]
    side = min(height, width)
    # Integer halving matches the original landscape offsets exactly.
    top = (height - side) // 2
    left = (width - side) // 2
    return frame[top:top + side, left:left + side]
def preprocess(filename, max_frames=0, resize=(224, 224)):
    """Decode a video file into a normalized RGB frame batch.

    Each decoded frame goes through three steps:
    - it is center-cropped with ``crop_center_square``;
    - it is resized with ``cv2.resize``;
    - its channels are reordered from OpenCV's BGR order to RGB.

    Args:
        filename: path of the video to decode (passed to ``cv2.VideoCapture``).
        max_frames: stop after this many frames. ``0`` (or any non-positive
            value) reads the whole video.
        resize: target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        Array of shape ``(1, num_frames, height, width, channels)`` holding the
        frames scaled by ``1 / 255``. The leading axis is a batch of one.

    Raises:
        ValueError: if no frame could be decoded from ``filename``.
    """
    video_capture = cv2.VideoCapture(filename)
    frames = []
    try:
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            # OpenCV returns channels in BGR order; reorder them to RGB.
            frame = frame[:, :, [2, 1, 0]]
            frames.append(frame)
            # A non-positive max_frames means "no limit".
            if max_frames > 0 and len(frames) >= max_frames:
                break
    finally:
        # Release the capture handle even if decoding raised.
        video_capture.release()

    if not frames:
        # Without this check an unreadable or empty video becomes an array of
        # shape (1, 0), which is not a valid frame batch for the model.
        raise ValueError(f"No frames could be decoded from video file {filename!r}.")

    video = np.array(frames) / 255.0
    return np.expand_dims(video, axis=0)
def translate(video_file, true_caption=None):
    """Translate an ASL video into English text.

    The video is preprocessed into a single-clip batch and embedded by the
    MoViNet model. Text is then generated from those embeddings with T5.

    Args:
        video_file: video path as supplied by the Gradio Video input.
        true_caption: reference caption carried by the example rows. It is
            accepted so the Interface can pass it, but it is not used.

    Returns:
        ``{"translation": [...]}``, where the list holds the decoded strings
        from ``tokenizer.batch_decode``.
    """
    clip = preprocess(video_file, max_frames=0, resize=(224, 224))
    video_embeddings = movinet_model(clip)['vid_embedding']

    # Sampling is enabled (do_sample=True), so repeated calls on the same
    # video may produce different text.
    generation_settings = dict(
        max_new_tokens=128,
        temperature=0.1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=80,
        top_p=0.90,
    )
    output_ids = t5_model.generate(inputs_embeds=video_embeddings, **generation_settings)

    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return {"translation": decoded}
# Gradio App config
# Text shown at the top of the app. Gradio renders the description as
# Markdown/HTML, so the links are written as HTML anchors with quoted hrefs.
title = "American Sign Language Translation: An Approach Combining MoViNets and T5"
description = """
This application hosts a model for translation of American Sign Language (ASL).
The model consists of a fine-tuned MoViNet CNN model to generate video embeddings and a T5 encoder-decoder model
to generate translations from the video embeddings. This model architecture achieves a BLEU score of 1.98
and an average cosine similarity score of 0.21 when trained and evaluated on the YouTube-ASL dataset.
More information about the model training and instructions to download the models
can be found in our <a href="https://github.com/deanna-emery/ASL-Translator">GitHub repository</a>.
You can also find an overview of the project approach
<a href="https://www.ischool.berkeley.edu/projects/2023/signsense-american-sign-language-translation">here</a>.
A limitation of this architecture is the size of the MoViNets model, making it especially slow during inference on a CPU.
We do not recommend uploading videos longer than 4 seconds as the video embedding generation may take some time.
The application does not accept videos that are longer than 10 seconds.
We have provided some pre-cached videos with their original captions and translations as examples.
"""
# Example rows for the Interface: [video path, reference caption]. The caption
# fills the hidden "Caption" textbox input.
examples = [
    [video_path, caption]
    for video_path, caption in (
        ("videos/My_second_ASL_professors_name_was_Will_White.mp4", "My second ASL professor's name was Will White"),
        ("videos/You_are_my_sunshine.mp4", "You are my sunshine"),
        ("videos/scrub_your_hands_for_at_least_20_seconds.mp4", "scrub your hands for at least 20 seconds"),
        ("videos/no.mp4", "no"),
        ("videos/i_feel_rejuvenated_by_this_beautiful_weather.mp4", "I feel rejuvenated by this beautiful weather"),
        ("videos/north_dakota_they_dont_need.mp4", "... north dakota they don't need ..."),
    )
]
# Gradio App interface
# Two inputs feed translate(video_file, true_caption):
# - a Video upload capped at 10 seconds;
# - a hidden, read-only textbox that only carries each example's reference
#   caption.
demo = gr.Interface(
    fn=translate,
    inputs=[
        gr.Video(label='Video', show_label=True, max_length=10, sources='upload'),
        gr.Textbox(label='Caption', show_label=True, interactive=False, visible=False),
    ],
    outputs="text",
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
)
demo.launch()